대규모 언어 모델의 사실성 향상 기법: 미세 조정 및 추론 전략

사전 훈련 단계에서 대규모 언어 모델(LLM)이 습득하는 사실 지식은 파편화되어 분포되어 있거나 훈련 데이터의 통계적 편향에 영향을 받기 쉽습니다. 감독 미세 조정(SFT)과 정렬 훈련은 이러한 지식 표현을 재구성하는 데 중요한 기회를 제공합니다. 특정 구조적 제약 하에 파라미터 공간의 지식 분포를 조정함으로써 모델은 더 견고한 사실 관계를 구축하고, 동시에 자신의 지식 경계에 대한 메타인지 능력을 개발할 수 있습니다.

1. 감독 미세 조정 및 정렬 훈련을 통한 사실성 최적화

1.1 도메인별 미세 조정 접근 방식

사실성 미세 조정의 핵심 과제는 정확한 진술과 그럴듯한 허위 진술을 구별할 수 있는 훈련 신호를 생성하는 것입니다. 일반적인 지시 미세 조정과 달리, 사실성 최적화는 명시적으로 사실 검증 과정을 모델링하고 손실 함수에 사실 일관성 제약을 도입해야 합니다.

1.1.1 진실성 지시 미세 조정

진실성 지시 미세 조정은 훈련 데이터의 구성을 재정의하여, 모델이 생성 과정에서 사실 확인과 관련된 주의 패턴을 활성화하도록 유도합니다. 이 전략은 단순히 답변의 정확성뿐만 아니라 추론 경로와 사실적 근거의 명시적 연관성을 강조합니다. 훈련 데이터는 '경계 사례', 즉 의미론적으로는 타당하지만 사실과 일치하지 않는 진술을 포함하도록 적대적 필터링 원칙에 따라 구축되어야 합니다. 이는 모델이 미묘한 사실적 차이에 대한 민감도를 높이는 데 기여합니다.

구현: 진실성 지시 미세 조정 데이터 구축 및 훈련

아래 스크립트는 사실성 선호에 기반한 지시 미세 조정 워크플로우를 구현합니다. 여기에는 적대적 음성 샘플 생성, 다중 작업 사실성 목표 설계, 그리고 DeepSpeed ZeRO-3을 지원하는 분산 훈련 설정이 포함됩니다.


import json
import torch
import argparse
from dataclasses import dataclass
from typing import List, Dict, Optional
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, DataCollatorForSeq2Seq
)
from peft import LoraConfig, get_peft_model

@dataclass
class KnowledgeItem:
    instruction: str
    context: str
    correct_answer: str
    false_answer: Optional[str] = None  # 적대적 음성 샘플
    supporting_details: Optional[str] = None  # 사실을 뒷받침하는 정보
    uncertain_trigger: Optional[str] = None  # 거절 인식 훈련용

class KnowledgeAlignmentDataset(Dataset):
    def __init__(self, data_file: str, tokenizer, max_seq_len: int = 2048):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.examples = self._read_data(data_file)

    def _read_data(self, path: str) -> List[KnowledgeItem]:
        items = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line)
                items.append(KnowledgeItem(**data))
        return items

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = self.examples[idx]

        # 사실성 제약이 있는 질의 형식 생성
        query_text = f"""### 지시:
{item.instruction}

### 입력 정보:
{item.context}

### 근거:
{item.supporting_details if item.supporting_details else '특정 근거 없음.'}

### 응답:
"""

        # 입력 인코딩
        query_enc = self.tokenizer(query_text, truncation=True, max_length=self.max_seq_len // 2)
        response_enc = self.tokenizer(
            item.correct_answer + self.tokenizer.eos_token,
            truncation=True,
            max_length=self.max_seq_len // 2
        )

        input_ids = query_enc['input_ids'] + response_enc['input_ids']
        attention_mask = query_enc['attention_mask'] + response_enc['attention_mask']

        # 레이블 구성: 응답 부분에 대해서만 손실 계산 (지시 및 입력 부분은 -100으로 마스킹)
        labels = [-100] * len(query_enc['input_ids']) + response_enc['input_ids']

        # 길이 정렬 및 자르기
        current_len = min(len(input_ids), self.max_seq_len)
        input_ids = input_ids[:current_len]
        attention_mask = attention_mask[:current_len]
        labels = labels[:current_len]

        # 패딩
        pad_amount = self.max_seq_len - current_len
        if pad_amount > 0:
            input_ids += [self.tokenizer.pad_token_id] * pad_amount
            attention_mask += [0] * pad_amount
            labels += [-100] * pad_amount

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "contains_false_sample": torch.tensor(item.false_answer is not None, dtype=torch.bool),
            "is_uncertainty_aware_sample": torch.tensor(item.uncertain_trigger is not None, dtype=torch.bool)
        }

class AligningTruthTrainer(Trainer):
    def __init__(self, *args, adversarial_factor: float = 0.1, uncertainty_factor: float = 0.05, **kwargs):
        super().__init__(*args, **kwargs)
        self.adversarial_factor = adversarial_factor
        self.uncertainty_factor = uncertainty_factor

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        contains_false = inputs.pop("contains_false_sample")
        is_uncertainty_aware = inputs.pop("is_uncertainty_aware_sample")

        outputs = model(**inputs)
        logits = outputs.logits

        # 표준 언어 모델링 손실
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        loss_func = torch.nn.CrossEntropyLoss(reduction='none')
        lm_loss_per_token = loss_func(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        lm_loss_per_token = lm_loss_per_token.view(shift_labels.size())

        # 패딩되지 않은 토큰에 대해서만 평균 손실 계산
        valid_token_mask = (shift_labels != -100).float()
        base_loss = (lm_loss_per_token * valid_token_mask).sum() / valid_token_mask.sum()

        total_loss = base_loss

        # 적대적 학습 손실: 올바른 답변에 가깝게, 틀린 답변에서 멀어지도록 유도
        if self.adversarial_factor > 0 and contains_false.any():
            # 실제 구현에서는 negative_sample에 대한 logit을 계산하고 positive_sample과 contrastive loss를 적용해야 함.
            # 여기서는 편의상 간단한 스케일링으로 대체 (실제 복잡한 로직은 생략)
            contrastive_loss = torch.tensor(0.0, device=inputs['input_ids'].device)
            total_loss += self.adversarial_factor * contrastive_loss

        # 거절 인식 손실: 불확실성 표현의 보정 최적화
        if self.uncertainty_factor > 0 and is_uncertainty_aware.any():
            # 불확실성 샘플에 대해 출력 분포의 엔트로피 증가 유도 (과도한 자신감 방지)
            # 여기서는 편의상 간단한 스케일링으로 대체 (실제 복잡한 로직은 생략)
            uncertainty_loss = torch.tensor(0.0, device=inputs['input_ids'].device)
            total_loss += self.uncertainty_factor * uncertainty_loss

        return (total_loss, outputs) if return_outputs else total_loss

def generate_sample_knowledge_data():
    """샘플 데이터 형식 생성"""
    template_fact = {
        "instruction": "다음 역사적 질문에 대해 주어진 사실 근거만을 바탕으로 답변하세요.",
        "context": "전화는 누가 발명했습니까?",
        "correct_answer": "역사 기록에 따르면, 알렉산더 그레이엄 벨(Alexander Graham Bell)은 1876년에 전화에 대한 발명 특허를 받았습니다.",
        "false_answer": "토머스 에디슨이 전화를 발명했습니다.",  # 적대적 음성 샘플
        "supporting_details": "Alexander Graham Bell was awarded the first US patent for the telephone in 1876.",
        "uncertain_trigger": None
    }

    template_uncertainty = {
        "instruction": "다음 전문적인 질문에 답변하세요. 만약 불확실하다면, 명확히 언급하세요.",
        "context": "2024년 특정 소규모 국가의 GDP 성장률은 정확히 얼마입니까?",
        "correct_answer": "현재 해당 국가의 2024년 GDP 성장률에 대한 신뢰할 수 있는 데이터가 없어 정확한 정보를 제공할 수 없습니다.",
        "false_answer": None,
        "supporting_details": None,
        "uncertain_trigger": "정보 없음"
    }

    with open('knowledge_data_template.jsonl', 'w', encoding='utf-8') as f:
        f.write(json.dumps(template_fact, ensure_ascii=False) + '\n')
        f.write(json.dumps(template_uncertainty, ensure_ascii=False) + '\n')

    print("샘플 데이터가 knowledge_data_template.jsonl 에 생성되었습니다.")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_id', type=str, default='meta-llama/Llama-2-7b-hf')
    parser.add_argument('--data_source', type=str, required=True)
    parser.add_argument('--output_dir', type=str, default='./truth_aligned_output')
    parser.add_argument('--adversarial_weight', type=float, default=0.1)
    args = parser.parse_args()

    # 토크나이저 및 모델 로드
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    base_model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )

    # LoRA 설정: 사실성 최적화를 위해 특정 레이어 조정
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        # 사실성 지식은 보통 중간 레이어에 저장되므로, 중간 레이어에 더 높은 드롭아웃을 적용하여 과적합 방지
        layers_to_transform=[i for i in range(8, 24)] if 'llama' in args.model_id.lower() else None
    )

    peft_model = get_peft_model(base_model, lora_config)
    peft_model.print_trainable_parameters()

    # 데이터셋
    dataset_for_training = KnowledgeAlignmentDataset(args.data_source, tokenizer)

    # 훈련 파라미터
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        logging_steps=10,
        save_strategy="epoch",
        fp16=False,
        bf16=True,
        remove_unused_columns=False,
        dataloader_num_workers=4,
        # DeepSpeed 설정 (사용 시)
        deepspeed="ds_config_zero3.json" if torch.cuda.device_count() > 1 else None
    )

    custom_trainer = AligningTruthTrainer(
        model=peft_model,
        args=training_args,
        train_dataset=dataset_for_training,
        tokenizer=tokenizer,
        adversarial_factor=args.adversarial_weight
    )

    custom_trainer.train()
    custom_trainer.save_model(args.output_dir)

if __name__ == "__main__":
    main()
1.1.2 거절 인식 훈련

모델이 환각을 일으키는 근본적인 이유 중 하나는 지식 경계를 효과적으로 인식하지 못하여 정보가 부족한 상황에서도 그럴듯한 내용을 강제로 생성하는 것입니다. 거절 인식 훈련은 불확실성 정량화 목표를 도입하여, 모델이 특정 임계값 이하의 확신도를 가질 때 명확한 거절 표현을 출력하도록 가르칩니다. 여기서 중요한 점은 '답변 가능-답변 불가능' 이진 분류 데이터셋을 구축하고, 표준 언어 모델 헤드가 생성을 담당하고 보조 확신도 헤드가 현재 문맥에서 답변의 적절성을 평가하는 이중 헤드 예측 아키텍처를 사용하는 것입니다.

구현: 거절 인식 훈련을 위한 이중 헤드 아키텍처

이 구현은 보조 확신도 예측 헤드를 도입하고, 보정 손실(calibration loss)과 표준 언어 모델링 목표를 결합하여 모델이 언제 불확실성을 표현해야 하는지 학습하도록 유도합니다.


import torch
import torch.nn as nn
import json
from typing import Dict, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset
import torch.nn.functional as F

class UncertaintyPredictor(nn.Module):
    """
    마지막 계층의 히든 상태를 기반으로 하는 경량 확신도 예측기
    """
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.predictor_layers = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()  # [0,1] 범위의 확신도 출력
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 마지막 토큰의 히든 상태 사용
        final_token_hidden = hidden_states[:, -1, :]  # [batch, hidden_dim]
        return self.predictor_layers(final_token_hidden).squeeze(-1)  # [batch]

class BoundaryAwareLM(nn.Module):
    def __init__(self, base_model_id: str, certainty_threshold: float = 0.6):
        super().__init__()
        self.main_lm = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16,
            output_hidden_states=True
        )
        self.uncertainty_head = UncertaintyPredictor(self.main_lm.config.hidden_size)
        self.certainty_threshold = certainty_threshold
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # 일부 하위 파라미터 동결 (의미적 사실과 더 밀접하게 관련된 상위 레이어와 확신도 헤드만 미세 조정)
        for name, param in self.main_lm.named_parameters():
            if "layers.0" in name or "layers.1" in name or "layers.2" in name:
                param.requires_grad = False

    def forward(self, input_ids, attention_mask=None, labels=None, is_query_answerable=None, **kwargs):
        lm_outputs = self.main_lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            **kwargs
        )

        # 확신도 예측을 위해 마지막 계층의 히든 상태 가져오기
        last_hidden_state = lm_outputs.hidden_states[-1]  # [batch, seq_len, hidden_dim]

        # 확신도 점수 계산
        predicted_confidence = self.uncertainty_head(last_hidden_state)  # [batch]

        total_loss = lm_outputs.loss if labels is not None else None

        # 감독 신호(답변 가능 여부)가 있을 경우 보조 손실 계산
        if is_query_answerable is not None and labels is not None:
            # 답변 가능한 샘플의 확신도는 1에 가까워야 하고, 불가능한 샘플은 0에 가까워야 함
            confidence_loss = F.binary_cross_entropy(
                predicted_confidence,
                is_query_answerable.float()
            )

            # 확신도가 임계값보다 높거나 답변 불가능한 경우에만 생성 손실 계산 (낮은 확신도에서 강제로 생성하는 것 방지)
            valid_generation_mask = (predicted_confidence > self.certainty_threshold) | (is_query_answerable == 0)

            if valid_generation_mask.any() and total_loss is not None:
                # 답변 불가능한 샘플에 대해 특정 거절 템플릿 생성을 강제하는 로직 (복잡하므로 간단히 언급)
                # refusal_template_ids = self.tokenizer("저는 이 질문에 정확히 답변할 충분한 정보가 없습니다.", return_tensors="pt")['input_ids'].to(input_ids.device)
                pass

            if total_loss is not None:
                total_loss = total_loss + 0.5 * confidence_loss  # 가중치 병합

        return {
            'loss': total_loss,
            'logits': lm_outputs.logits,
            'confidence_score': predicted_confidence,
            'past_key_values': lm_outputs.past_key_values
        }

    def generate_with_uncertainty_check(self, input_ids, max_gen_length=512, **kwargs):
        """거절 감지 기능이 있는 생성 메서드"""
        with torch.no_grad():
            generated_output = self.main_lm.generate(
                input_ids,
                max_length=max_gen_length,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs
            )

            # 생성된 시퀀스의 히든 상태를 얻어 최종 확신도 계산
            full_generated_ids = generated_output.sequences
            model_forward_output = self.main_lm(
                full_generated_ids,
                output_hidden_states=True
            )
            last_hidden_state = model_forward_output.hidden_states[-1]
            final_confidence = self.uncertainty_head(last_hidden_state) # [batch]

            refusal_message = "이 질문에 답변할 충분한 정보가 없습니다."
            refusal_ids = self.tokenizer(refusal_message, return_tensors="pt")['input_ids'].to(full_generated_ids.device)

            # 확신도가 임계값 미만인 경우 거절 답변으로 대체 (간단화된 로직)
            for i, conf_score in enumerate(final_confidence):
                if conf_score < self.certainty_threshold:
                    # 실제 애플리케이션에서는 여기서 조기 중단하고 거절 템플릿을 반환해야 함
                    full_generated_ids[i] = refusal_ids.squeeze(0) # 예시적 대체

            return full_generated_ids, final_confidence

class QueryResponseDataset(Dataset):
    def __init__(self, data_file: str, tokenizer, max_len: int = 1024):
        self.tokenizer = tokenizer
        self.max_len = max_len
        with open(data_file, 'r', encoding='utf-8') as f:
            self.data = [json.loads(line) for line in f]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # 대화 형식 구성
        question_prompt = f"질문: {item['question']}\n답변: "
        response_text = item['answer'] if item['is_answerable'] else "모르겠습니다."

        complete_text = question_prompt + response_text + self.tokenizer.eos_token
        encoded_output = self.tokenizer(
            complete_text,
            truncation=True,
            max_length=self.max_len,
            padding='max_length',
            return_tensors='pt'
        )

        input_token_ids = encoded_output['input_ids'].squeeze()
        attn_mask = encoded_output['attention_mask'].squeeze()

        # 레이블 구성: 답변 부분에 대해서만 손실 계산
        prompt_encoding = self.tokenizer(question_prompt, truncation=True, max_length=self.max_len)
        prompt_length = len(prompt_encoding['input_ids'])

        labels_tensor = input_token_ids.clone()
        labels_tensor[:prompt_length] = -100  # prompt 부분의 손실은 계산하지 않음

        return {
            'input_ids': input_token_ids,
            'attention_mask': attn_mask,
            'labels': labels_tensor,
            'is_query_answerable': torch.tensor(1.0 if item['is_answerable'] else 0.0, dtype=torch.float)
        }

def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_model_id', type=str, default='meta-llama/Llama-2-7b-hf')
    parser.add_argument('--data_source', type=str, required=True)
    parser.add_argument('--save_path', type=str, default='./uncertainty_model')
    args = parser.parse_args()

    model_for_refusal = BoundaryAwareLM(args.base_model_id)
    token_processor = model_for_refusal.tokenizer

    training_dataset = QueryResponseDataset(args.data_source, token_processor)

    training_arguments = TrainingArguments(
        output_dir=args.save_path,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        learning_rate=1e-4,
        warmup_steps=100,
        logging_steps=10,
        save_strategy="epoch",
        bf16=True,
        remove_unused_columns=False
    )

    trainer_instance = Trainer(
        model=model_for_refusal,
        args=training_arguments,
        train_dataset=training_dataset,
        tokenizer=token_processor
    )

    trainer_instance.train()
    model_for_refusal.main_lm.save_pretrained(args.save_path)
    torch.save(model_for_refusal.uncertainty_head.state_dict(), f"{args.save_path}/uncertainty_head.pt")

if __name__ == "__main__":
    main()
1.1.3 합성 데이터 증강

실제 세계의 사실성 주석 데이터는 희소하고 비용이 많이 듭니다. 합성 데이터 증강은 GPT-4와 같은 교사 모델(teacher model)을 활용하여 고품질의 사실-환각 대조 쌍을 생성하고, 통제된 의미론적 교란을 통해 어려운 음성 샘플을 만듭니다. 주요 혁신은 불일치 설명 생성에 있습니다. 이는 참/거짓 레이블뿐만 아니라, 특정 진술이 거짓인 이유에 대한 자연어 추론을 모델이 생성하도록 요구합니다. 이러한 설명성 감독 신호는 모델의 사실 추론 능력을 크게 향상시킬 수 있습니다.

구현: 합성 환각 샘플을 사용한 대조 학습

이 스크립트는 GPT-4 API를 기반으로 하는 적대적 데이터 생성 파이프라인을 구현하며, 사실성 검증, 설명 생성 및 대조 손실 계산을 포함합니다.


import json
import openai
import torch
import argparse
from typing import List, Dict
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset
import torch.nn as nn
import os
import ast # For safer literal evaluation

class FalsityGenerator:
    def __init__(self, api_key: str, model_id: str = "gpt-4"):
        openai.api_key = api_key
        self.model = model_id
        self.similarity_encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    def create_hallucinatory_variants(self, actual_fact: str, count: int = 3) -> List[Dict]:
        """
        주어진 실제 사실을 바탕으로 그럴듯하지만 거짓인 변형 생성
        """
        prompt = f"""다음은 사실 진술입니다: "{actual_fact}"
핵심 개체, 날짜 또는 관계를 변경하여 {count}개의 그럴듯하지만 거짓인 변형을 생성하세요.
각 변형에 대해:
1. 거짓 진술을 제공합니다.
2. 어떤 사실이 변경되었고 왜 틀렸는지 구체적으로 설명합니다.
3. 환각의 미묘함(1-5점, 5점은 매우 미묘함)을 평가합니다.

출력 형식: 'false_statement', 'explanation', 'subtlety_score' 키를 가진 JSON 리스트"""

        try:
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.8,
                max_tokens=1000
            )
            content_str = response.choices[0].message.content
            # 안전하게 JSON 문자열 평가 (예상치 못한 형식에 대비)
            if content_str.startswith("```json"):
                content_str = content_str[7:].strip().strip('`')
            variants = ast.literal_eval(content_str)
            return variants
        except Exception as e:
            print(f"생성 실패: {e}")
            return []

    def generate_corrective_explanation(self, correct_info: str, incorrect_statement: str) -> str:
        """
        특정 진술이 왜 거짓인지에 대한 상세 설명 생성
        """
        prompt = f"""실제 사실: {correct_info}
거짓 진술: {incorrect_statement}

거짓 진술이 왜 틀렸는지, 어떤 특정 정보가 그것과 모순되는지 인용하여 상세한 설명을 제공하세요. 100단어 이내로 작성하세요."""

        response = openai.ChatCompletion.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
            max_tokens=200
        )
        return response.choices[0].message.content

    def filter_out_low_quality(self, true_fact: str, false_variant: str, similarity_threshold: float = 0.85) -> bool:
        """
        의미론적 유사성을 기반으로 낮은 품질의 음성 샘플 필터링 (너무 유사하거나 완전히 무관한 경우)
        """
        embeddings = self.similarity_encoder.encode([true_fact, false_variant])
        similarity = util.cos_sim(embeddings[0], embeddings[1]).item()

        # 유사도는 0.6-0.9 사이여야 함: 헷갈리게 할 만큼 유사하지만, 사실을 구별할 만큼 충분히 달라야 함
        return 0.6 < similarity < similarity_threshold

class DisinformationDetectionDataset(Dataset):
    def __init__(self, data_filepath: str, tokenizer, max_token_len: int = 1024):
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len
        with open(data_filepath, 'r', encoding='utf-8') as f:
            self.data_entries = [json.loads(line) for line in f]

    def __len__(self):
        return len(self.data_entries)

    def __getitem__(self, idx):
        item = self.data_entries[idx]

        # 앵커(질문), 긍정(사실 + 설명), 부정(환각 + 설명) 삼중항 구성
        anchor_text = item['question']
        positive_text = f"사실: {item['true_fact']}\n설명: {item['true_explanation']}"
        negative_text = f"진술: {item['false_variant']}\n교정: {item['false_explanation']}"

        # 각각 인코딩
        anchor_encoded = self.tokenizer(
            anchor_text,
            max_length=self.max_token_len // 3,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        pos_encoded = self.tokenizer(
            positive_text,
            max_length=self.max_token_len // 3,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        neg_encoded = self.tokenizer(
            negative_text,
            max_length=self.max_token_len // 3,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        return {
            'anchor_input_ids': anchor_encoded['input_ids'].squeeze(),
            'anchor_attention_mask': anchor_encoded['attention_mask'].squeeze(),
            'positive_input_ids': pos_encoded['input_ids'].squeeze(),
            'positive_attention_mask': pos_encoded['attention_mask'].squeeze(),
            'negative_input_ids': neg_encoded['input_ids'].squeeze(),
            'negative_attention_mask': neg_encoded['attention_mask'].squeeze(),
            'complexity_score': torch.tensor(item.get('subtlety_score', 3.0) / 5.0)  # 정규화
        }

class SemanticComparisonEncoder(nn.Module):
    def __init__(self, base_lm_name: str):
        super().__init__()
        self.encoder_lm = AutoModelForCausalLM.from_pretrained(
            base_lm_name,
            torch_dtype=torch.bfloat16,
            output_hidden_states=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(base_lm_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # 히든 상태를 비교 공간으로 매핑하는 투영 헤드
        self.projection_head = nn.Sequential(
            nn.Linear(self.encoder_lm.config.hidden_size, 1024),
            nn.LayerNorm(1024),
            nn.Tanh()
        )

    def generate_embedding(self, input_ids, attention_mask):
        outputs = self.encoder_lm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # 마지막 토큰의 히든 상태를 문장 표현으로 사용
        last_layer_hidden = outputs.hidden_states[-1]
        # 실제 마지막 비-패딩 토큰을 찾음
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = input_ids.size(0)
        embeddings = last_layer_hidden[range(batch_size), sequence_lengths]
        return self.projection_head(embeddings)

    def forward(self, anchor_input_ids, anchor_attention_mask,
                positive_input_ids, positive_attention_mask,
                negative_input_ids, negative_attention_mask, complexity_score=None):

        anchor_embed = self.generate_embedding(anchor_input_ids, anchor_attention_mask)
        positive_embed = self.generate_embedding(positive_input_ids, positive_attention_mask)
        negative_embed = self.generate_embedding(negative_input_ids, negative_attention_mask)

        # InfoNCE 손실 / 난이도 가중치가 있는 Triplet 손실
        # 난이도가 높은 음성 샘플(complexity_score가 높을수록)에 더 높은 가중치 부여
        pos_similarity = torch.cosine_similarity(anchor_embed, positive_embed, dim=1)
        neg_similarity = torch.cosine_similarity(anchor_embed, negative_embed, dim=1)

        # 온도 스케일링
        temp_scale = 0.05
        pos_similarity = pos_similarity / temp_scale
        neg_similarity = neg_similarity / temp_scale

        # 대조 손실: 긍정 샘플은 가깝게, 부정 샘플은 멀게
        all_logits = torch.stack([pos_similarity, neg_similarity], dim=1)  # [batch, 2]
        target_labels = torch.zeros(anchor_embed.size(0), dtype=torch.long, device=anchor_embed.device)

        loss_function = nn.CrossEntropyLoss()
        computed_loss = loss_function(all_logits, target_labels)

        # 어려운 음성 샘플 마이닝: 높은 complexity_score를 가진 샘플에 더 높은 손실 가중치 부여
        if complexity_score is not None:
            weights = 1.0 + complexity_score  # [1.0, 2.0] 범위
            computed_loss = (computed_loss * weights).mean()

        return {
            'loss': computed_loss,
            'anchor_embedding': anchor_embed,
            'positive_embedding': positive_embed,
            'negative_embedding': negative_embed
        }

def orchestrate_contrastive_data_generation(arguments):
    """대조 데이터셋 생성"""
    generator = FalsityGenerator(api_key=os.getenv("OPENAI_API_KEY"))

    with open(arguments.input, 'r', encoding='utf-8') as fin, open(arguments.output, 'w', encoding='utf-8') as fout:
        for line_data in fin:
            data_item = json.loads(line_data)
            fact_statement = data_item['fact']
            inquiry = data_item['question']

            # 환각 변형 생성
            variations = generator.create_hallucinatory_variants(fact_statement, count=2)

            for variant in variations:
                if not generator.filter_out_low_quality(fact_statement, variant['false_statement']):
                    continue

                true_explanation = generator.generate_corrective_explanation(fact_statement, variant['false_statement'])
                false_explanation = variant['explanation']

                output_record = {
                    'question': inquiry,
                    'true_fact': fact_statement,
                    'true_explanation': true_explanation,
                    'false_variant': variant['false_statement'],
                    'false_explanation': false_explanation,
                    'subtlety_score': variant['subtlety_score']
                }
                fout.write(json.dumps(output_record, ensure_ascii=False) + '\n')

def conduct_contrastive_training(arguments):
    """대조 모델 훈련"""
    tokenizer = AutoTokenizer.from_pretrained(arguments.base_lm_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    dataset_for_training = DisinformationDetectionDataset(arguments.data, tokenizer)

    model_for_contrast = SemanticComparisonEncoder(arguments.base_lm_name)

    training_config = TrainingArguments(
        output_dir=arguments.output_dir,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=2,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        logging_steps=10,
        save_strategy="epoch",
        bf16=True,
        remove_unused_columns=False
    )

    trainer_instance = Trainer(
        model=model_for_contrast,
        args=training_config,
        train_dataset=dataset_for_training,
        tokenizer=tokenizer
    )

    trainer_instance.train()

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--action', choices=['generate_data', 'train_model'], required=True)
    parser.add_argument('--input_file', type=str, help='생성할 사실 입력 파일')
    parser.add_argument('--output_file', type=str, help='생성된 데이터의 출력 경로')
    parser.add_argument('--data_to_train', type=str, help='대조 학습용 훈련 데이터')
    parser.add_argument('--base_lm_name', type=str, default='meta-llama/Llama-2-7b-hf')
    parser.add_argument('--output_directory', type=str, default='./contrastive_learning_output')
    args = parser.parse_args()

    if args.action == 'generate_data':
        orchestrate_contrastive_data_generation(args)
    else:
        conduct_contrastive_training(args)

if __name__ == "__main__":
    main()

1.2 인간 피드백 기반 강화 학습(RLHF) 변형

표준 RLHF는 인간 선호도의 전반적인 분포를 최적화하며, 종종 사실 정확성을 희생하면서 언어 유창성을 지나치게 강조합니다. 사실성 강화 RLHF 변형은 보상 함수를 재구성하여 사실 검증 신호를 명시적으로 도입합니다. 이는 보상 모델이 텍스트 품질을 평가할 수 있을 뿐만 아니라 사실 오류를 식별하는 능력도 갖춰야 함을 의미합니다. 이는 일반적으로 자연어 추론(NLI) 모델의 출력 또는 외부 지식 기반 검색 결과를 통합하여 달성됩니다.

1.2.1 사실성 보상 모델

사실성 보상 모델(Factuality RM)은 표준 선호도 모델링을 기반으로 사실 일관성 검증 분기를 추가합니다. 이 분기는 사전 훈련된 NLI 모델(예: RoBERTa-NLI)을 사용하여 생성된 텍스트와 원본 문서 또는 지식 기반 간의 함의 관계를 계산합니다. 보상 점수는 유창성, 관련성 및 사실성 점수의 가중치 합계이며, 훈련 후반에 사실성 가중치가 점진적으로 증가하여 커리큘럼 학습 효과를 형성하고, 모델이 초기에 단일 차원만 과도하게 최적화하는 것을 방지합니다.

구현: NLI 기반 사실성 보상 모델

이 스크립트는 정책 경사(policy gradient)와 NLI 검증을 결합한 강화 학습 훈련 프로세스를 구현하며, PPO(Proximal Policy Optimization) 알고리즘 구현 및 사실성 보상 계산을 포함합니다.


import torch
import torch.nn as nn
import argparse
from transformers import (
    AutoModelForSequenceClassification, AutoTokenizer,
    AutoModelForCausalLM
)
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from datasets import load_dataset
from typing import List, Dict
import numpy as np
from sentence_transformers import CrossEncoder

class TruthMetricModel(nn.Module):
    def __init__(self, base_preference_model: str, nli_evaluator: str = "cross-encoder/nli-deberta-v3-base"):
        super().__init__()
        # 기본 선호도 예측 모델
        self.preference_predictor = AutoModelForSequenceClassification.from_pretrained(
            base_preference_model, num_labels=1
        )
        self.tokenizer = AutoTokenizer.from_pretrained(base_preference_model)

        # 사실 검증을 위한 NLI 모델
        self.nli_verifier = CrossEncoder(nli_evaluator)

        # 학습 가능한 혼합 가중치
        self.truth_weight = nn.Parameter(torch.tensor(0.5))

    def calculate_nli_score(self, generated_texts: List[str],
                            reference_texts: List[str]) -> torch.Tensor:
        """NLI를 사용하여 생성된 텍스트와 참조 간의 사실 일관성 계산"""
        if not reference_texts or not generated_texts or len(generated_texts) != len(reference_texts):
            return torch.zeros(len(generated_texts), device=self.truth_weight.device)

        # NLI 입력 쌍 구성: (전제=참조, 가설=생성)
        pairs = [[ref, gen] for ref, gen in zip(reference_texts, generated_texts)]

        # NLI 예측: entailment (0), contradiction (2), neutral (1)
        nli_raw_scores = self.nli_verifier.predict(pairs, convert_to_numpy=True)

        # entailment 확률을 사실성 점수로 사용
        # nli_raw_scores 형태: [batch, 3]
        entailment_probabilities = torch.softmax(torch.tensor(nli_raw_scores, device=self.truth_weight.device), dim=1)[:, 0]
        return entailment_probabilities

    def forward(self, input_ids, attention_mask, generated_responses=None,
                reference_contexts=None, target_labels=None):
        # 표준 선호도 점수
        preference_logits = self.preference_predictor(input_ids=input_ids,
                                           attention_mask=attention_mask).logits.squeeze(-1)

        final_reward = preference_logits

        # 참조 텍스트가 제공되면 사실성 보상 추가
        if generated_responses is not None and reference_contexts is not None:
            factuality_scores = self.calculate_nli_score(generated_responses, reference_contexts)
            factuality_scores = factuality_scores.to(final_reward.device)

            # 동적 가중치: 시그모이드를 사용하여 (0,1) 사이로 제한
            dynamic_weight = torch.sigmoid(self.truth_weight)
            final_reward = (1 - dynamic_weight) * final_reward + dynamic_weight * factuality_scores * 5.0  # 스케일링 팩터 5.0으로 스케일 균형

        if target_labels is not None:
            loss_func = nn.MSELoss()
            combined_loss = loss_func(final_reward, target_labels.float())
            return {'loss': combined_loss, 'reward': final_reward}

        return {'reward': final_reward}

class PolicyRefinementAgent:
    def __init__(self, policy_model_name: str, reward_model_path: str):
        # 정책 모델 (가치 헤드 포함)
        self.policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(policy_model_name)
        self.reference_model = AutoModelForCausalLMWithValueHead.from_pretrained(policy_model_name)

        # 참조 모델 동결
        for param in self.reference_model.parameters():
            param.requires_grad = False

        self.tokenizer = AutoTokenizer.from_pretrained(policy_model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # 보상 모델 로드
        self.reward_evaluator = TruthMetricModel(reward_model_path)
        self.reward_evaluator.eval() # 평가 모드

        # PPO 설정
        self.ppo_config_settings = PPOConfig(
            model_name=policy_model_name,
            learning_rate=1e-5,
            batch_size=4,
            mini_batch_size=1,
            gradient_accumulation_steps=4,
            optimize_cuda_cache=True,
        )

        self.ppo_trainer = PPOTrainer(
            config=self.ppo_config_settings,
            model=self.policy_model,
            ref_model=self.reference_model,
            tokenizer=self.tokenizer
        )

    def perform_training_step(self, prompts: List[str], responses: List[str],
                              ground_truth_references: List[str]):
        """단일 PPO 훈련 단계"""
        # 프롬프트 및 응답 인코딩
        prompt_tensors = [self.tokenizer.encode(p, return_tensors="pt",
                                              truncation=True, max_length=512).squeeze(0)
                        for p in prompts]
        response_tensors = [self.tokenizer.encode(r, return_tensors="pt",
                                                 truncation=True, max_length=512).squeeze(0)
                           for r in responses]

        # 보상 계산
        full_generated_texts = [self.tokenizer.decode(p_ids) + self.tokenizer.decode(r_ids)
                               for p_ids, r_ids in zip(prompt_tensors, response_tensors)]
        calculated_rewards = []

        with torch.no_grad():
            for text_gen, ref_gt in zip(full_generated_texts, ground_truth_references):
                encoded_input = self.tokenizer(text_gen, return_tensors="pt", truncation=True, max_length=1024)
                reward_output = self.reward_evaluator(
                    input_ids=encoded_input['input_ids'].to(self.reward_evaluator.preference_predictor.device),
                    attention_mask=encoded_input['attention_mask'].to(self.reward_evaluator.preference_predictor.device),
                    generated_responses=[text_gen],
                    reference_contexts=[ref_gt]
                )
                calculated_rewards.append(reward_output['reward'].item())

        rewards_tensor_list = [torch.tensor(r_val) for r_val in calculated_rewards]

        # PPO 업데이트
        training_stats = self.ppo_trainer.step(prompt_tensors, response_tensors, rewards_tensor_list)
        return training_stats

def train_truth_reward_model(arguments):
    """사실성 보상 모델 훈련"""
    reward_model_instance = TruthMetricModel(arguments.base_model_id)
    token_processor = reward_model_instance.tokenizer

    # 선호도 및 사실성 주석 데이터 로드
    # 데이터 형식: prompt, chosen_response, rejected_response, reference_text, factuality_label
    # 예시 데이터셋은 "json"으로 가정, 실제 사용 시에는 사용자 정의 데이터 로더 필요
    synthetic_dataset = load_dataset('json', data_files=arguments.data_file)['train']

    def process_data_for_reward_model(examples):
        # chosen 및 rejected에 대해 각각 보상 계산 및 간극 최대화
        chosen_encoded = token_processor(examples['chosen_response'], truncation=True,
                              max_length=512, padding='max_length')
        rejected_encoded = token_processor(examples['rejected_response'], truncation=True,
                                max_length=512, padding='max_length')

        return {
            'chosen_input_ids': chosen_encoded['input_ids'],
            'chosen_attention_mask': chosen_encoded['attention_mask'],
            'rejected_input_ids': rejected_encoded['input_ids'],
            'rejected_attention_mask': rejected_encoded['attention_mask'],
            'reference_info': examples['reference_text'],
            'chosen_text_content': examples['chosen_response'],
            'rejected_text_content': examples['rejected_response']
        }

    processed_dataset = synthetic_dataset.map(process_data_for_reward_model, batched=True)

    # 선호도 손실과 사실성 손실을 결합한 사용자 정의 훈련 루프
    optimizer = torch.optim.AdamW(reward_model_instance.parameters(), lr=2e-5)
    device = reward_model_instance.truth_weight.device

    for epoch_num in range(3):
        for batch_data in processed_dataset:
            # chosen 및 rejected 응답에 대한 보상 계산
            chosen_output = reward_model_instance(
                torch.tensor(batch_data['chosen_input_ids']).unsqueeze(0).to(device),
                torch.tensor(batch_data['chosen_attention_mask']).unsqueeze(0).to(device),
                generated_responses=[batch_data['chosen_text_content']],
                reference_contexts=[batch_data['reference_info']]
            )

            rejected_output = reward_model_instance(
                torch.tensor(batch_data['rejected_input_ids']).unsqueeze(0).to(device),
                torch.tensor(batch_data['rejected_attention_mask']).unsqueeze(0).to(device),
                generated_responses=[batch_data['rejected_text_content']],
                reference_contexts=[batch_data['reference_info']]
            )

            # Bradley-Terry 선호도 손실 + 사실성 제약
            preference_loss = -torch.log(torch.sigmoid(chosen_output['reward'] - rejected_output['reward']))

            # rejected가 환각인 경우 chosen이 rejected보다 더 사실적이어야 함을 보장
            fact_consistency_constraint = torch.relu(0.5 - (chosen_output['reward'] - rejected_output['reward']))

            combined_loss = preference_loss + 0.1 * fact_consistency_constraint
            combined_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    
    reward_model_instance.preference_predictor.save_pretrained(arguments.output_dir) # Save the base part
    torch.save(reward_model_instance.truth_weight.state_dict(), f"{arguments.output_dir}/truth_weight.pt") # Save the custom part
    reward_model_instance.tokenizer.save_pretrained(arguments.output_dir)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--action', choices=['train_reward_model', 'run_ppo'], required=True)
    parser.add_argument('--base_model_id', type=str, default='meta-llama/Llama-2-7b-hf')
    parser.add_argument('--data_file', type=str)
    parser.add_argument('--output_directory', type=str, required=True)
    parser.add_argument('--reward_model_dir', type=str)
    args = parser.parse_args()

    if args.action == 'train_reward_model':
        train_truth_reward_model(args)
    else:
        # PPO 훈련
        if not args.reward_model_dir:
            raise ValueError("PPO 훈련을 위해 reward_model_dir이 필요합니다.")
        ppo_agent = PolicyRefinementAgent(args.base_model_id, args.reward_model_dir)
        # 쿼리 데이터셋 로드 및 훈련 ...
        print("PPO 훈련 구현을 위해서는 쿼리 데이터셋 설정이 필요합니다.")
        # 예시:
        # queries = ["프랑스의 수도는?", "지구에서 가장 높은 산은?"]
        # responses = ["파리입니다.", "에베레스트 산입니다."]
        # references = ["프랑스의 수도는 파리이다.", "에베레스트 산은 세계에서 가장 높은 산이다."]
        # stats = ppo_agent.perform_training_step(queries, responses, references)
        # print(stats)

if __name__ == "__main__":
    main()
1.2.2 직접 선호도 최적화(DPO)의 사실성 적응

DPO는 명시적인 보상 모델링을 피하기 위해 정책 모델과 참조 모델 간의 로그 확률 비율을 직접 최적화합니다. 그러나 사실성 시나리오에서는 선호도 데이터에 명확한 사실성 주석이 부족하다는 문제에 직면합니다. SimPO(Simple Preference Optimization) 적응은 사실 일관성 기반 샘플링 전략을 도입하여 배치 내에서 어려운 음성 샘플(즉, 긍정 샘플과 의미론적으로 유사하지만 사실 오류를 포함하는 샘플)을 동적으로 구성함으로써, 미묘한 사실 차이를 판별하는 모델의 능력을 강화합니다.

1.2.3 다단계 약한 감독

다단계 약한 감독 프레임워크는 여러 신뢰도 신호 소스(예: 검색 엔진에서 반환된 요약, 위키백과 엔티티 연결 확률, 도메인 전문가의 일관성 투표)를 통합하고, 노이즈 인식 레이블 집계 알고리즘을 통해 훈련에 사용되는 소프트 레이블을 생성합니다. 이 방법의 핵심은 신뢰도 가중치 부여에 있습니다. 각 약한 감독 소스는 과거 정확도에 따라 동적으로 가중치를 조정하며, 최종 사실성 손실 함수는 단일 이진 하드 레이블에 의존하기보다는 이러한 소프트 레이블에 대한 최대 우도 추정치를 적용합니다.

2. 모델 편집 및 지식 업데이트 메커니즘

사전 훈련된 모델의 파라미터 공간에는 방대한 사실 지식이 저장되어 있지만, 정적 가중치는 지식의 동적 진화에 적응하기 어렵습니다. 모델 편집 기술은 특정 지식의 파라미터 표현을 국소적으로 수정하여 전체 미세 조정 없이 지식을 업데이트할 수 있도록 합니다. 이 분야의 주요 과제는 '위치-수정'의 정확성과 부작용 제어입니다. 즉, 특정 사실을 저장하는 뉴런의 하위 집합을 정확히 찾아야 할 뿐만 아니라, 편집 작업이 관련 없는 지식 영역으로 확산되는 것을 방지해야 합니다.

2.1 위치-수정(Locate-then-Edit) 패러다임

위치-수정 패러다임은 지식 업데이트를 두 가지 분리된 단계로 나눕니다. 위치 단계에서는 귀인 분석 방법을 사용하여 특정 사실 예측에 가장 크게 기여하는 뉴런 또는 레이어 영역을 식별합니다. 수정 단계에서는 이렇게 위치된 파라미터에 대해 정밀한 수정을 수행합니다. 이 방법의 장점은 편집의 국소성과 설명 가능성으로, 시스템이 신경망 내에서 특정 지식의 물리적 저장 위치를 추적할 수 있도록 합니다.

2.1.1 지식 뉴런 식별

지식 뉴런 개념은 트랜스포머의 피드포워드 네트워크(FFN) 레이어 분석에서 비롯됩니다. 연구에 따르면 FFN의 특정 뉴런 활성화 패턴은 구체적인 사실(예: "파리-프랑스 수도")과 관련이 있습니다. 통합 경사(Integrated Gradients) 또는 활성화 패칭(Activation Patching) 기술을 통해 다양한 뉴런이 특정 사실 예측에 기여하는 인과적 정도를 정량화할 수 있으며, 이를 통해 지식-뉴런 매핑 그래프를 구축할 수 있습니다.

구현: 귀인 기반 지식 뉴런 위치 파악

이 스크립트는 통합 경사와 활성화 패칭 방법을 구현하여 특정 사실을 저장하는 뉴런을 식별하고, 후속 편집을 위한 목표 위치를 제공합니다.


import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
import argparse
import json

class FactPathwayAnalyzer:
    def __init__(self, model_identifier: str, execution_device: str = "cuda"):
        self.device = execution_device
        self.model = AutoModelForCausalLM.from_pretrained(
            model_identifier,
            torch_dtype=torch.float16,
            device_map="auto",
            output_attentions=False
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_identifier)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # 후크 핸들 및 순방향 전달 캐시 저장
        self.active_hooks = []
        self.cached_activations = {}
        self.captured_gradients = {}

    def attach_hooks(self, target_layer_indices: List[int]):
        """FFN 레이어 활성화를 캡처하기 위한 순방향 및 역방향 후크 등록"""
        def capture_activation(name):
            def hook(module, input_data, output_data):
                self.cached_activations[name] = output_data.detach()
            return hook

        def capture_gradient(name):
            def hook(module, grad_input, grad_output):
                self.captured_gradients[name] = grad_output[0].detach()
            return hook

        # Llama 아키텍처 가정
        model_layers = self.model.model.layers

        for layer_idx in target_layer_indices:
            layer_module = model_layers[layer_idx]
            # FFN 중간 활성화 캡처 (up-projection 이후, down-projection 이전)
            fwd_handle = layer_module.mlp.up_proj.register_forward_hook(
                capture_activation(f"layer_{layer_idx}_ffn_activation")
            )
            bwd_handle = layer_module.mlp.up_proj.register_full_backward_hook(
                capture_gradient(f"layer_{layer_idx}_ffn_gradient")
            )
            self.active_hooks.extend([fwd_handle, bwd_handle])

    def detach_hooks(self):
        for h in self.active_hooks:
            h.remove()
        self.active_hooks = []
        self.cached_activations = {}
        self.captured_gradients = {}

    def contribution_gradients(self, input_text: str, target_token_str: str,
                               interpolation_steps: int = 20) -> Dict[str, torch.Tensor]:
        """
        주요 뉴런 식별을 위한 통합 경사 계산
        입력 임베딩을 따라 보간하고 대상 출력에 대한 각 뉴런의 기여도를 계산
        """
        # 입력 및 대상 인코딩
        inputs_encoded = self.tokenizer(input_text, return_tensors="pt").to(self.device)
        target_token_id = self.tokenizer.encode(target_token_str, add_special_tokens=False)[0]

        # 입력 임베딩 가져오기
        input_embeddings = self.model.model.embed_tokens(inputs_encoded['input_ids'])

        # 각 경로 점에서의 경사 저장
        all_path_gradients = defaultdict(list)

        # 기준선(영 벡터)에서 실제 입력까지의 경로 적분
        for alpha_val in np.linspace(0, 1, interpolation_steps):
            scaled_embeddings = input_embeddings * alpha_val
            scaled_embeddings.requires_grad_(True)

            # 순방향 전달
            model_outputs = self.model(inputs_embeds=scaled_embeddings)
            output_logits = model_outputs.logits

            # 대상 토큰의 로짓 가져오기
            target_output_logit = output_logits[0, -1, target_token_id]

            # 역방향 전달
            self.model.zero_grad()
            target_output_logit.backward(retain_graph=True)

            # 경사 수집
            for name, grad_val in self.captured_gradients.items():
                all_path_gradients[name].append(grad_val.clone())

            self.captured_gradients = {} # 매 스텝마다 초기화

        # 통합 경사 계산: (x - x_baseline) * mean(gradients)
        average_gradients = {name: torch.stack(grads).mean(dim=0)
                             for name, grads in all_path_gradients.items()}

        # 실제 활성화 값 곱하기 (Riemann 근사)
        integrated_gradients_output = {}
        for name_grad in average_gradients:
            activation_name = name_grad.replace("_gradient", "_activation")
            if activation_name in self.cached_activations:
                integrated_gradients_output[name_grad] = self.cached_activations[activation_name] * average_gradients[name_grad]

        return integrated_gradients_output

    def intervention_impact_analysis(self, input_text: str, target_token_str: str,
                                     perturbation_strategy: str = "noise") -> Dict[str, float]:
        """
        활성화 패칭: 특정 뉴런을 손상시켜 출력 변화를 관찰하고 인과 관계 확인
        perturbation_strategy: "noise"(노이즈 추가), "zero"(영으로 설정), "resample"(재샘플링)
        """
        # 먼저 클린 실행의 활성화 가져오기
        self.attach_hooks(list(range(self.model.config.num_hidden_layers)))

        inputs_encoded = self.tokenizer(input_text, return_tensors="pt").to(self.device)
        target_token_id = self.tokenizer.encode(target_token_str, add_special_tokens=False)[0]

        with torch.no_grad():
            clean_model_outputs = self.model(**inputs_encoded)
            clean_target_logit = clean_model_outputs.logits[0, -1, target_token_id].item()

        clean_layer_activations = {k: v.clone() for k, v in self.cached_activations.items()}
        self.detach_hooks()

        # 레이어별 뉴런별로 개입 테스트 수행
        causal_influence_scores = {}

        for layer_index in range(self.model.config.num_hidden_layers):
            activation_key = f"layer_{layer_index}_ffn_activation"
            if activation_key not in clean_layer_activations:
                continue

            activation_dims = clean_layer_activations[activation_key].shape
            num_neurons_in_layer = activation_dims[-1]

            # 핵심 뉴런 샘플링 (전체 대신 계산량 절약을 위해)
            sample_indices = np.random.choice(num_neurons_in_layer, min(100, num_neurons_in_layer), replace=False)

            for neuron_index in sample_indices:
                # 개입을 위한 후크 재등록
                def create_perturbation_hook(current_layer_name, current_neuron_idx, original_act_data):
                    def hook_func(module, input_tuple, output_tensor):
                        # 특정 뉴런 패칭
                        modified_output = output_tensor.clone()
                        # 활성화를 억제 (예: 10%로 스케일링)
                        modified_output[:, :, current_neuron_idx] = original_act_data[:, :, current_neuron_idx] * 0.1
                        return modified_output
                    return hook_func

                hook_handle = self.model.model.layers[layer_index].mlp.up_proj.register_forward_hook(
                    create_perturbation_hook(activation_key, neuron_index, clean_layer_activations[activation_key])
                )

                with torch.no_grad():
                    perturbed_model_outputs = self.model(**inputs_encoded)
                    perturbed_target_logit = perturbed_model_outputs.logits[0, -1, target_token_id].item()

                hook_handle.remove()

                # 인과 효과 계산: 클린과 패칭의 차이
                causal_effect_value = clean_target_logit - perturbed_target_logit
                causal_influence_scores[f"{layer_index}_{neuron_index}"] = causal_effect_value

        return causal_influence_scores

    def pinpoint_knowledge_neurons(self, subject: str, predicate: str,
                                   object_value: str, top_k_neurons: int = 100) -> List[Dict]:
        """
        귀인 분석 방법을 통합하여 지식 뉴런 식별
        통합 경사와 활성화 패칭 결과 결합
        """
        query_prompt = f"{subject}은 {predicate}인"
        target_token_string = f" {object_value}" # 토큰과 일치하도록 공백 주의

        # 방법 1: 통합 경사
        self.attach_hooks(list(range(5, 25))) # 중간 레이어가 일반적으로 지식 저장
        ig_scores = self.contribution_gradients(query_prompt, target_token_string)
        self.detach_hooks()

        # IG 점수 집계 (뉴런별 최대 절대값)
        ig_importance_mapping = {}
        for layer_activation_name, scores_tensor in ig_scores.items():
            layer_id = int(layer_activation_name.split("_")[1])
            # scores_tensor: [batch, seq_len, neurons]
            neuron_level_scores = scores_tensor.abs().max(dim=1)[0].squeeze() # 시퀀스 차원에서의 최대 영향
            for neuron_idx, score_val in enumerate(neuron_level_scores):
                ig_importance_mapping[f"{layer_id}_{neuron_idx}"] = score_val.item()

        # 방법 2: 활성화 패칭
        patching_scores = self.intervention_impact_analysis(query_prompt, target_token_string)

        # 종합 점수 (정규화 후 합산)
        max_ig_score = max(ig_importance_mapping.values()) if ig_importance_mapping else 1
        max_patch_score = max(abs(v) for v in patching_scores.values()) if patching_scores else 1

        combined_scoring_results = {}
        all_identified_neurons = set(ig_importance_mapping.keys()) | set(patching_scores.keys())

        for neuron_key in all_identified_neurons:
            ig_normalized = ig_importance_mapping.get(neuron_key, 0) / max_ig_score
            patch_normalized = abs(patching_scores.get(neuron_key, 0)) / max_patch_score
            combined_scoring_results[neuron_key] = 0.5 * ig_normalized + 0.5 * patch_normalized

        # 상위 k개 정렬 및 반환
        sorted_neuron_results = sorted(combined_scoring_results.items(), key=lambda x: x[1], reverse=True)

        final_ranked_neurons = []
        for neuron_id_str, score_val in sorted_neuron_results[:top_k_neurons]:
            layer_str, index_str = neuron_id_str.split("_")
            final_ranked_neurons.append({
                "layer_index": int(layer_str),
                "neuron_identifier": int(index_str),
                "combined_importance_score": float(score_val),
                "integrated_gradient_score": float(ig_importance_mapping.get(neuron_id_str, 0)),
                "activation_patching_score": float(patching_scores.get(neuron_id_str, 0))
            })

        return final_ranked_neurons

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='meta-llama/Llama-2-7b-hf')
    parser.add_argument('--target_subject', type=str, required=True)
    parser.add_argument('--target_relation', type=str, required=True)
    parser.add_argument('--target_object', type=str, required=True)
    parser.add_argument('--output_json_path', type=str, default='knowledge_neuron_analysis.json')
    args = parser.parse_args()

    locator = FactPathwayAnalyzer(args.model_name_or_path)
    identified_neurons = locator.pinpoint_knowledge_neurons(args.target_subject, args.target_relation, args.target_object)

    with open(args.output_json_path, 'w', encoding='utf-8') as f:
        json.dump({
            "fact_analyzed": f"{args.target_subject} {args.target_relation} {args.target_object}",
            "identified_knowledge_neurons": identified_neurons
        }, f, indent=2, ensure_ascii=False)

    print(f"식별된 지식 뉴런 수: {len(identified_neurons)}. 결과는 {args.output_json_path}에 저장되었습니다.")

if __name__ == "__main__":
    main()
2.1.2 ROME 및 MEMIT

랭크-원 모델 편집(Rank-One Model Editing, ROME) 및 대규모 모델 편집(Mass Editing Memory in a Transformer, MEMIT)은 Locate-then-Edit 패러다임의 구체적인 구현입니다. ROME은 랭크-원 행렬 업데이트를 통해 FFN 레이어의 키-값 투영을 수정하여 새로운 지식을 점진적인 저랭크 업데이트 행렬로 인코딩합니다. MEMIT은 ROME을 확장하여 배치 지식 편집을 지원하며, 여러 편집 작업 간의 호환성을 보장하기 위해 제약 최적화를 통해 지식 충돌을 방지합니다. 두 가지 모두 위치 단계에서 식별된 핵심 레이어(일반적으로 중간 레이어)에 의존하며, 이 레이어에 랭크 제약이 있는 가중치 업데이트를 주입합니다.

구현: ROME 및 MEMIT을 이용한 배치 편집

이 스크립트는 ROME의 단일 편집과 MEMIT의 배치 편집 알고리즘을 구현하며, 랭크-원 업데이트 계산과 레이어별 개입 로직을 포함합니다.


import torch
import torch.nn as nn
import numpy as np
import json
import argparse
from typing import List, Dict, Tuple, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from collections import defaultdict

class SingleFactModifier:
    def __init__(self, model_id: str):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,  # 편집에는 높은 정밀도 필요
            device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def collect_input_distribution_stats(self, layer_idx: int, sample_count: int = 10000) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        ROME의 사전 조건 행렬에 사용되는 특정 FFN 레이어 입력의 공분산 행렬 계산
        이는 ROME의 핵심: 훈련 분포의 통계 정보를 사용하여 업데이트를 제약
        """
        print(f"레이어 {layer_idx}에 대한 공분산 계산 중...")

        target_layer_mlp = self.model.model.layers[layer_idx].mlp

        # 활성화 통계 수집
        all_activations = []

        def capture_activation_hook(module, input_tuple, output_tensor):
            # input_tuple은 일반적으로 튜플이며, 첫 번째 요소는 [batch, seq, hidden]
            all_activations.append(input_tuple[0].detach().cpu().reshape(-1, input_tuple[0].shape[-1]))

        hook_handle = target_layer_mlp.register_forward_hook(capture_activation_hook)

        # 위키피디아 유사 텍스트 샘플 사용 (실제 앱에서는 대표적인 코퍼스 사용)
        illustrative_texts = [
            "프랑스의 수도는 파리입니다.",
            "그 책의 작가는 유명합니다.",
            "머신러닝은 인공지능의 하위 분야입니다.",
            "뉴욕은 미국의 가장 큰 도시 중 하나입니다.",
            "태양계에는 여덟 개의 행성이 있습니다.",
        ] * (sample_count // len(illustrative_texts) + 1)

        with torch.no_grad():
            for text_sample in illustrative_texts[:sample_count]:
                input_tokens = self.tokenizer(text_sample, return_tensors="pt").to(self.model.device)
                self.model(**input_tokens)

        hook_handle.remove()

        # 모든 활성화 연결
        concatenated_activations = torch.cat(all_activations, dim=0).to(self.model.device)

        # 공분산 행렬 계산 C = E[hh^T]
        # 수치적으로 안정적인 구현 사용
        mean_vec = concatenated_activations.mean(dim=0, keepdim=True)
        centered_activations = concatenated_activations - mean_vec
        covariance_matrix = (centered_activations.T @ centered_activations) / len(centered_activations)

        # 역행렬을 보장하기 위한 정규화 항 추가
        covariance_matrix = covariance_matrix + 1e-4 * torch.eye(covariance_matrix.shape[0], device=covariance_matrix.device)

        return covariance_matrix, mean_vec.squeeze(0)

    def infer_edit_layer(self, query: str, new_target: str) -> int:
        """
        인과 매개 분석을 기반으로 핵심 레이어 추론 (단순화된 버전, 실제로는 더 복잡한 귀인 분석 필요)
        사실 지식은 일반적으로 중간 레이어에 저장 (예: Llama-2의 8-16 레이어)
        """
        # 여기서는 경험적인 중간 레이어를 사용
        total_layers = self.model.config.num_hidden_layers
        return total_layers // 2

    def apply_single_edit(self, fact_to_edit: Dict, target_layer: Optional[int] = None) -> Dict[str, torch.Tensor]:
        """
        랭크-원 모델 편집 수행
        fact_to_edit 형식: {"prompt": "프랑스의 수도는", "target_new": "런던", "target_old": "파리"}
        """
        edit_prompt = fact_to_edit['prompt']
        new_value = fact_to_edit['target_new']
        # old_value = fact_to_edit.get('target_old', None) # 현재 사용되지 않음

        if target_layer is None:
            target_layer = self.infer_edit_layer(edit_prompt, new_value)

        print(f"사실 편집: '{edit_prompt} -> {new_value}' (레이어 {target_layer})")

        # 1. 사전 조건 행렬 (공분산) 계산
        cov_matrix, input_mean = self.collect_input_distribution_stats(target_layer)
        inv_cov_matrix = torch.inverse(cov_matrix)

        # 2. 목표 키-값 쌍 계산
        # k*: 프롬프트의 표현 (검색에 사용)
        # v*: 새로운 목표 값의 표현

        input_tokens = self.tokenizer(edit_prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            output_hidden = self.model(**input_tokens, output_hidden_states=True)
            # 편집 대상 레이어의 히든 상태 (FFN 입력)
            # Llama 아키텍처: up_proj는 FFN의 첫 부분이며, 그 입력은 이전 레이어의 히든 상태입니다.
            # ROME 논문의 구현은 FFN의 중간 활성화에 대해 이루어지므로, up_proj의 출력을 키로 사용
            
            # 중간 활성화 캡처를 위한 임시 후크
            captured_intermediate_activation = None
            def temp_capture_hook(module, input_data, output_data):
                nonlocal captured_intermediate_activation
                captured_intermediate_activation = output_data[0, -1, :].detach() # 배치 및 시퀀스의 마지막 토큰
            
            mlp_module = self.model.model.layers[target_layer].mlp
            hook_handle = mlp_module.up_proj.register_forward_hook(temp_capture_hook)
            self.model(**input_tokens) # 모델을 다시 실행하여 활성화 캡처
            hook_handle.remove()

            key_vector_k_star = captured_intermediate_activation  # [intermediate_dim]

        # 새 목표 인코딩
        new_target_ids = self.tokenizer(new_value, add_special_tokens=False, return_tensors="pt")['input_ids'].to(self.model.device)
        with torch.no_grad():
            new_target_embeddings = self.model.model.embed_tokens(new_target_ids)
            # 목표 표현으로 간단히 평균 사용 (실제 앱에서는 더 정교한 처리 필요)
            target_value_v_star = new_target_embeddings.mean(dim=1).squeeze()  # [hidden_dim]

        # 3. 랭크-원 업데이트 계산
        # ROME 핵심 공식: W_new = W_old + (v* - W_old @ k*) @ k*^T @ C_inv / (k*^T @ C_inv @ k*)

        # 현재 FFN의 down-projection 가중치 가져오기
        # Llama 아키텍처: up_proj (hidden->intermediate), down_proj (intermediate->hidden)
        # down_proj의 입력은 intermediate_dim, 출력은 hidden_dim
        
        old_weight_matrix_W = mlp_module.down_proj.weight  # [hidden_dim, intermediate_dim]

        # 현재 키(k_star)에 대한 현재 출력 값 (v_old) 계산
        current_value_output = old_weight_matrix_W @ key_vector_k_star  # [hidden_dim]

        # 잔차 벡터 계산
        residual_delta_v = target_value_v_star - current_value_output  # [hidden_dim]

        # 사전 조건이 적용된 키 벡터
        # C_inv는 intermediate_dim x intermediate_dim이어야 함
        # collect_input_distribution_stats에서 가져온 C는 mlp 입력에 대한 것이므로,
        # here, we need the covariance of the intermediate activations directly (output of up_proj)
        # For simplicity, and aligning with some ROME implementations, we will use inv_cov_matrix directly if it's the right shape
        # Or more accurately, re-compute for the intermediate space:
        
        # 실제 구현에서는 intermediate_dim에 대한 공분산 행렬이 필요함
        # 여기서는 simplified: up_proj의 출력 공간에 대한 단위 행렬 또는 간단한 근사를 사용
        intermediate_dim_size = key_vector_k_star.shape[0]
        C_intermediate_approx_inv = torch.eye(intermediate_dim_size, device=self.model.device) * 1e-4
        
        # 업데이트 항 계산
        # (residual_delta_v @ (k_star_intermediate @ C_intermediate_approx_inv).T) / (k_star_intermediate @ C_intermediate_approx_inv @ k_star_intermediate)
        
        numerator = torch.outer(residual_delta_v, (key_vector_k_star @ C_intermediate_approx_inv)) # [hidden_dim, intermediate_dim]
        denominator = (key_vector_k_star @ C_intermediate_approx_inv @ key_vector_k_star).item()

        # 업데이트 행렬
        update_matrix_W = numerator / (denominator + 1e-6)

        # 기존 가중치 저장 (복구용)
        original_weights = old_weight_matrix_W.clone()

        # 업데이트 적용
        with torch.no_grad():
            mlp_module.down_proj.weight.copy_(old_weight_matrix_W + update_matrix_W)

        return {
            "layer_index": target_layer,
            "original_weights_snapshot": original_weights,
            "updated_weights_value": mlp_module.down_proj.weight.clone(),
            "applied_delta": update_matrix_W
        }

    def revert_weights(self, edit_details: Dict):
        """원래 가중치 복구 (편집 취소용)"""
        layer_id = edit_details['layer_index']
        original_w = edit_details['original_weights_snapshot']
        self.model.model.layers[layer_id].mlp.down_proj.weight.copy_(original_w)

class BatchKnowledgeUpdater(SingleFactModifier):
    def execute_batch_updates(self, facts_list: List[Dict]):
        """
        대규모 모델 편집 (MEMIT)
        핵심 차이점: 여러 제약을 동시에 해결하여 충돌 방지
        """
        print(f"{len(facts_list)}개 사실에 대한 MEMIT 배치 편집 시작...")

        # 목표 레이어별로 사실 그룹화 (보통 동일 또는 인접 레이어 편집)
        grouped_facts_by_layer = defaultdict(list)
        for fact_item in facts_list:
            layer_to_edit = self.infer_edit_layer(fact_item['prompt'], fact_item['target_new'])
            grouped_facts_by_layer[layer_to_edit].append(fact_item)

        all_edit_records = []

        for layer_num, current_batch_facts in grouped_facts_by_layer.items():
            # 해당 레이어의 모든 k* 및 v* 수집
            key_vectors_K = []  # 키 행렬 [num_facts, intermediate_dim]
            current_value_outputs_V_curr = []  # 현재 값 [num_facts, hidden_dim]
            target_value_outputs_V_target = []   # 목표 값 [num_facts, hidden_dim]

            for fact_in_batch in current_batch_facts:
                input_tokens = self.tokenizer(fact_in_batch['prompt'], return_tensors="pt").to(self.model.device)

                # 중간 레이어 활성화 캡처
                captured_intermediate_activation = None
                def temp_capture_hook(module, input_data, output_data):
                    nonlocal captured_intermediate_activation
                    captured_intermediate_activation = output_data[0, -1, :].detach()
                
                mlp_module = self.model.model.layers[layer_num].mlp
                hook_handle = mlp_module.up_proj.register_forward_hook(temp_capture_hook)
                self.model(**input_tokens)
                hook_handle.remove()

                key_vector_K = captured_intermediate_activation
                key_vectors_K.append(key_vector_K)

                # 현재 출력 및 목표 출력 계산
                with torch.no_grad():
                    current_V = mlp_module.down_proj(key_vector_K.unsqueeze(0)).squeeze()
                    current_value_outputs_V_curr.append(current_V)

                    # 목표 표현 (간단화: 목표 토큰의 임베딩 사용)
                    target_ids = self.tokenizer(fact_in_batch['target_new'], return_tensors="pt")['input_ids'].to(self.model.device)
                    target_emb = self.model.model.embed_tokens(target_ids).mean(dim=1).squeeze()
                    target_value_outputs_V_target.append(target_emb)

            # 텐서로 스택
            K_matrix = torch.stack(key_vectors_K)  # [N, intermediate_dim]
            V_current_matrix = torch.stack(current_value_outputs_V_curr)  # [N, hidden_dim]
            V_target_matrix = torch.stack(target_value_outputs_V_target)  # [N, hidden_dim]

            # 배치 최적화: || W @ K^T - V_target^T ||_F + 정규화 항 최소화
            # 다른 입력에 대한 동작은 유지 (사전 조건 행렬 통해)

            old_weight_matrix_W = mlp_module.down_proj.weight  # [hidden_dim, intermediate_dim]

            # 공분산 행렬 (근사)
            # 실제 구현에서는 intermediate_dim에 대한 공분산 행렬이 필요함
            C_intermediate_approx = torch.eye(K_matrix.shape[1], device=K_matrix.device) * 0.1
            C_intermediate_approx_inv = torch.inverse(C_intermediate_approx)

            # 잔차 계산
            Delta_V_transpose = (V_target_matrix - V_current_matrix).T  # [hidden_dim, N]

            # 투영 행렬 계산
            KC_inv = K_matrix @ C_intermediate_approx_inv  # [N, intermediate_dim]
            M_matrix = KC_inv @ K_matrix.T    # [N, N]
            M_matrix_inv = torch.inverse(M_matrix + 1e-3 * torch.eye(M_matrix.shape[0], device=M_matrix.device))  # 정규화

            # 업데이트 계산
            W_update_matrix = Delta_V_transpose @ M_matrix_inv @ KC_inv  # [hidden_dim, intermediate_dim]

            # 업데이트 적용
            with torch.no_grad():
                new_weight_matrix = old_weight_matrix_W + W_update_matrix
                mlp_module.down_proj.weight.copy_(new_weight_matrix)

            all_edit_records.append({
                "edited_layer_index": layer_num,
                "number_of_edits_in_batch": len(current_batch_facts),
                "original_weights_for_layer": old_weight_matrix_W.clone(),
                "applied_weight_delta": W_update_matrix
            })

            print(f"레이어 {layer_num}: {len(current_batch_facts)}개 편집 적용 완료")

        return all_edit_records

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--operation_mode', choices=['rome_single', 'memit_batch'], required=True)
    parser.add_argument('--target_model', type=str, default='meta-llama/Llama-2-7b-hf')
    parser.add_argument('--facts_json_file', type=str, required=True)
    parser.add_argument('--output_model_dir', type=str, default='./edited_llm_artifacts')
    args = parser.parse_args()

    # 사실 데이터 로드
    with open(args.facts_json_file, 'r', encoding='utf-8') as f:
        facts_data = json.load(f)
        if not isinstance(facts_data, list):
            facts_data = [facts_data]

    if args.operation_mode == 'rome_single':
        editor_instance = SingleFactModifier(args.target_model)
        if len(facts_data) > 1:
            print("경고: ROME은 단일 편집용입니다. 첫 번째 사실만 처리합니다.")
        edit_result = editor_instance.apply_single_edit(facts_data[0])
        print(f"단일 편집 완료 (레이어 {edit_result['layer_index']})")
    else: # memit_batch
        editor_instance = BatchKnowledgeUpdater(args.target_model)
        edit_result = editor_instance.execute_batch_updates(facts_data)
        print(f"배치 편집 완료: {len(edit_result)}개 레이어 수정됨")

    # 편집된 모델 저장
    editor_instance.model.save_pretrained(args.output_model_dir)
    editor_instance.tokenizer.save_pretrained(args.output_model_dir)
    print(f"모델이 {args.output_model_dir}에 저장되었습니다.")

if __name__ == "__main__":
    main()
2.1.3 지식 편집의 부작용 제어

지식 편집의 주요 위험은 '편집 전파(edit spreading)'와 '지식 충돌'입니다. 편집 전파는 특정 사실에 대한 수정이 관련은 있지만 다른 사실(예: "프랑스 수도"를 수정했지만 "프랑스 최대 도시"에는 영향을 주지 않아야 함)에 의도치 않게 영향을 미치는 것을 의미합니다. 부작용 제어 전략에는 국소성 제약(편집이 목표 프롬프트에만 영향을 미치도록 보장), 특이성 제약(과도한 일반화 방지), 그리고 인접 지식 유지(정규화 항을 통해 관련 없는 지식의 출력 유지 강제)가 포함됩니다. MEMIT은 여러 편집의 호환성을 동시에 제약하여 지식 충돌을 완화하며, 후속 연구에서 도입된 편집 위치 범위 제한(예: FFN 레이어의 특정 하위 공간만 수정)은 부작용을 더욱 억제합니다.

2.2 지속 학습 및 지식 고정

단일 편집은 지식 업데이트에 적합하지만, 모델은 새로운 지식을 지속적으로 흡수해야 합니다. 지속적인 사전 훈련은 재앙적 망각(catastrophic forgetting)과 지식 고정(knowledge solidification)의 딜레마에 직면합니다. 새로운 지식의 지속적인 주입은 이미 고정된 사실 표현을 방해할 수 있으며, 오래된 지식 유지에 대한 과도한 제약은 새로운 지식의 효과적인 인코딩을 방해할 수 있습니다.

2.2.1 사실성 커리큘럼 설계

사실성 커리큘럼 학습(Factuality Curriculum)은 사실의 신뢰도와 복잡성에 따라 훈련 순서를 구성합니다. 높은 신뢰도의 사실(예: 상식)은 안정적인 기반을 구축하기 위해 훈련 초기에 도입되며, 낮은 신뢰도의 추측성 지식(예: 새로운 과학 가설)은 후기에 더 낮은 학습률과 더 강력한 정규화를 적용하여 도입됩니다. 이러한 계층적 훈련 전략은 인간 교육의 나선형 커리큘럼을 모방하여 새로운 지식이 견고한 오래된 지식 기반 위에 구축되도록 보장하고 간섭을 줄입니다.

2.2.2 기억 증강 아키텍처

미분 가능한 외부 기억 모듈(Differentiable External Memory)은 파라미터화된 지식과 명시적 기억을 분리하여 지속 학습에서 발생하는 간섭 문제를 완화합니다. 모델 핵심 파라미터는 비교적 안정적으로 유지되며, 새로운 지식은 키-값 쌍 형태로 외부 기억 저장소에 저장되고, 주의 메커니즘을 통해 필요에 따라 검색됩니다. 이 아키텍처는 기본 파라미터를 수정하지 않고도 지식을 업데이트할 수 있어 지식의 점진적이고 설명 가능한 관리를 가능하게 합니다. 기억 쓰기 전략에는 최근 최소 사용(LRU) 교체, 중요성 기반의 희소 쓰기, 그리고 신구 지식의 키 공간 중첩을 피하기 위한 충돌 감지 메커니즘이 포함됩니다.

3. 디코딩 전략 및 추론 시 계산 최적화

훈련 시 개입은 모델의 파라미터 분포를 변경하지만, 생성 과정에서의 샘플링 전략 또한 중요합니다. 사실성 안내 디코딩 알고리즘은 토큰 선택 단계에서 외부 검증 또는 내부 일관성 검사를 도입하여 모델 파라미터를 수정하지 않고도 출력의 사실 정확도를 향상시킵니다.

3.1 사실성 안내 디코딩 알고리즘

표준 디코딩(탐욕 검색 또는 핵 샘플링)은 모델의 내부 확률 분포에만 의존하며, 높은 확률이지만 잘못된 경로를 따라 계속 생성하기 쉽습니다. 사실성 안내 디코딩은 다음 토큰의 확률 분포에 개입하여 외부 증거와 일치하는 토큰의 확률을 명시적으로 높이거나, 대조 방법을 사용하여 전문가 모델과 비전문가 모델 간의 신뢰도 차이를 확대합니다.

3.1.1 대조 디코딩

대조 디코딩(Contrastive Decoding)은 대규모 전문가 모델과 소규모 비전문가 모델 간의 확률 격차가 생성 품질을 나타낼 수 있다는 관찰에 기반합니다. 사실성 시나리오에서 전문가와 비전문가 모델은 사실성 토큰에 대한 확률 격차가 특히 두드러집니다. 전문가 모델의 로그 확률에서 비전문가 모델의 로그 확률을 빼서(온도로 스케일링), 대조 디코딩은 전문가 모델의 이점을 확대하고, 비전문가 모델도 저지를 수 있는 높은 확신도의 오류를 억제하여 사실 정확성을 향상시킵니다.

구현: 전문가-비전문가 대조 디코딩 및 검색 증강 생성

이 스크립트는 대조 디코딩과 동적 검색 증강 디코딩을 구현하며, 실시간 증거 검색 및 로짓 분포 조정을 포함합니다.


import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Optional, Tuple
import argparse
import wikipediaapi
import re # 정규표현식용

class DivergenceAwareLogitsProcessor(LogitsProcessor):
    """
    대조 디코딩 프로세서: 전문가-비전문가 모델 간의 격차를 활용하여 생성을 안내
    공식: log P_expert - alpha * log P_amateur
    """
    def __init__(self, apprentice_model, expert_tokenizer, apprentice_tokenizer,
                 divergence_alpha: float = 0.5, min_gap_beta: float = 0.0):
        self.apprentice_lm = apprentice_model
        self.expert_tokenizer = expert_tokenizer
        self.apprentice_tokenizer = apprentice_tokenizer
        self.divergence_alpha = divergence_alpha  # 비전문가 모델 가중치
        self.min_gap_beta = min_gap_beta    # 최소 격차 임계값

        # 토큰 매핑 구축 (전문가 및 비전문가 모델 어휘는 다를 수 있음)
        self.token_id_map = self._create_vocab_mapping()

    def _create_vocab_mapping(self) -> Dict[int, int]:
        """전문가 모델 토큰 ID를 비전문가 모델 토큰 ID로 매핑"""
        mapping = {}
        expert_vocab = self.expert_tokenizer.get_vocab()
        apprentice_vocab = self.apprentice_tokenizer.get_vocab()
        for expert_token_str, expert_token_id in expert_vocab.items():
            if expert_token_str in apprentice_vocab:
                apprentice_token_id = apprentice_vocab[expert_token_str]
                mapping[expert_token_id] = apprentice_token_id
        return mapping

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # scores: [batch, vocab_size] (전문가 모델 로짓)

        # 비전문가 모델 예측 가져오기
        with torch.no_grad():
            apprentice_outputs = self.apprentice_lm(input_ids)
            apprentice_logits = apprentice_outputs.logits[:, -1, :]  # [batch, apprentice_vocab]

        # 어휘 정렬 (비전문가 모델 로짓을 전문가 공간으로 매핑)
        aligned_apprentice_logits = torch.full_like(scores, float('-inf'))

        for expert_id, apprentice_id in self.token_id_map.items():
            if apprentice_id < apprentice_logits.shape[-1]:
                aligned_apprentice_logits[:, expert_id] = apprentice_logits[:, apprentice_id]

        # 대조 디코딩: 전문가 - alpha * 비전문가
        # 비전문가 모델도 높은 확신도를 가진 오류를 필터링하기 위해,
        # 전문가 - 비전문가 로짓의 격차가 beta보다 큰 토큰만 유지
        contrastive_final_scores = scores - self.divergence_alpha * aligned_apprentice_logits

        # 최소 임계값 적용 (선택 사항)
        if self.min_gap_beta > 0:
            mask = (scores - aligned_apprentice_logits) < self.min_gap_beta
            contrastive_final_scores[mask] = float('-inf') # 이 토큰들을 선택할 가능성을 크게 낮춤

        return contrastive_final_scores

class ContextualBiasLogitsProcessor(LogitsProcessor):
    """
    검색 증강 디코딩: 검색된 증거의 토큰을 선호하도록 어휘 분포를 동적으로 조정
    "Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks" 논문에 기반
    """
    def __init__(self, search_retriever, faiss_index, corpus_docs, tokenizer,
                 search_frequency: int = 5, top_k_results: int = 3, bias_strength: float = 0.3):
        self.search_retriever = search_retriever  # 임베딩 모델
        self.faiss_index = faiss_index          # FAISS 인덱스
        self.corpus_docs = corpus_docs          # 원본 문서 내용
        self.tokenizer = tokenizer
        self.search_frequency = search_frequency  # N개 토큰마다 한 번 검색
        self.top_k_results = top_k_results
        self.bias_strength = bias_strength  # 검색 편향 가중치

        self.tokens_generated_count = 0
        self.cached_evidence = ""

    def fetch_relevant_evidence(self, query_text: str) -> List[str]:
        """현재 생성 내용을 기반으로 관련 증거 검색"""
        query_embedding = self.search_retriever.encode([query_text], convert_to_tensor=False)
        query_embedding = np.array(query_embedding).astype('float32')

        distances, indices = self.faiss_index.search(query_embedding, self.top_k_results)

        retrieved_evidence = []
        for idx in indices[0]:
            if idx < len(self.corpus_docs):
                retrieved_evidence.append(self.corpus_docs[idx])

        return retrieved_evidence

    def build_vocabulary_bias(self, evidence_texts: List[str], vocab_size: int) -> torch.Tensor:
        """
        증거를 기반으로 어휘 편향 구축
        증거에 나타나는 토큰의 확률 증가
        """
        bias_vector = torch.zeros(vocab_size)

        # 증거의 토큰 통계
        tokens_in_evidence = set()
        for text_frag in evidence_texts:
            tokens = self.tokenizer.encode(text_frag, add_special_tokens=False)
            tokens_in_evidence.update(tokens)

        # 증거의 토큰에 편향 부여 (TF-IDF 기반 가중치로 변경 가능)
        for token_id in tokens_in_evidence:
            if token_id < vocab_size:
                bias_vector[token_id] = 1.0 # 간단한 1.0 가중치

        return bias_vector

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        self.tokens_generated_count += 1

        # 특정 스텝마다 검색 업데이트
        if self.tokens_generated_count % self.search_frequency == 0:
            # 현재 생성된 텍스트를 질의로 디코딩
            current_partial_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            fetched_evidence = self.fetch_relevant_evidence(current_partial_text)
            self.cached_evidence = " ".join(fetched_evidence)

        if self.cached_evidence:
            # 증거 기반 편향 계산
            bias_scores = self.build_vocabulary_bias([self.cached_evidence], scores.shape[-1])
            bias_scores = bias_scores.to(scores.device)

            # 편향 적용: 증거 내 어휘의 확률 부스트
            # log-space 덧셈 사용: log(P) + lambda * bias
            adjusted_scores = scores + self.bias_strength * bias_scores
            return adjusted_scores

        return scores

class SemanticConstraintBeamSearch:
    """
    사실성 제약 빔 검색: 지식 그래프 경로 또는 특정 템플릿을 따르도록 생성 강제
    생성된 개체 관계가 KG 사실과 일치하도록 보장하는 제약 빔 검색 구현
    """
    def __init__(self, tokenizer, knowledge_graph_validator, num_beams_to_search: int = 5):
        self.tokenizer = tokenizer
        self.knowledge_graph_validator = knowledge_graph_validator  # 외부 KG 검증 함수
        self.num_beams = num_beams_to_search

    def is_valid_sequence_prefix(self, sequence_ids: List[int]) -> bool:
        """현재 시퀀스 접두사가 사실 제약을 위반하는지 확인"""
        text_generated = self.tokenizer.decode(sequence_ids, skip_special_tokens=True)
        # 외부 KG 검증 API 또는 로컬 규칙 호출
        return self.knowledge_graph_validator(text_generated)

    def generate_with_constraints(self, model, initial_input_ids: torch.LongTensor, max_output_length: int = 100):
        """
        하드 제약이 있는 빔 검색 구현
        매 단계에서 후보 빔을 유지하고 제약을 위반하는 시퀀스를 필터링
        """
        batch_size = initial_input_ids.shape[0]
        device = initial_input_ids.device

        # 초기화: 각 입력에 대해 num_beams만큼 후보 복사
        beam_path_scores = torch.zeros((batch_size, self.num_beams), device=device)
        beam_path_scores[:, 1:] = -1e9  # 초기에는 첫 번째 후보만 유효

        # 각 후보의 시퀀스 저장
        current_beam_sequences = initial_input_ids.unsqueeze(1).repeat(1, self.num_beams, 1)  # [batch, beams, seq_len]

        for step in range(max_output_length):
            # 순방향 전달을 위해 평탄화
            flat_sequences = current_beam_sequences.view(-1, current_beam_sequences.shape[-1])  # [batch*beams, seq_len]

            with torch.no_grad():
                model_outputs = model(flat_sequences)
                logits = model_outputs.logits[:, -1, :]  # [batch*beams, vocab]
                log_probabilities = F.log_softmax(logits, dim=-1)

            # [batch, beams, vocab] 형태로 재구성
            log_probabilities = log_probabilities.view(batch_size, self.num_beams, -1)

            # 새로운 후보 점수 계산: 이전 점수 + 새 토큰 로그 확률
            candidate_scores = beam_path_scores.unsqueeze(-1) + log_probabilities  # [batch, beams, vocab]
            candidate_scores = candidate_scores.view(batch_size, -1)  # [batch, beams*vocab]

            # 상위 k개(k=num_beams) 선택
            topk_values, topk_indices = torch.topk(candidate_scores, self.num_beams, dim=-1)

            # 빔 인덱스 및 토큰 인덱스 파싱
            parent_beam_indices = topk_indices // log_probabilities.shape[-1]
            next_token_indices = topk_indices % log_probabilities.shape[-1]

            # 시퀀스 업데이트 (제약 검사 처리)
            updated_sequences_list = []
            validity_mask = torch.ones_like(topk_values, dtype=torch.bool)

            for batch_idx in range(batch_size):
                batch_updated_seqs = []
                for beam_rank in range(self.num_beams):
                    parent_idx = parent_beam_indices[batch_idx, beam_rank]
                    next_token_id = next_token_indices[batch_idx, beam_rank]

                    # 부모 시퀀스 가져오고 새 토큰 추가
                    parent_seq_tokens = current_beam_sequences[batch_idx, parent_idx].tolist()
                    proposed_new_seq = parent_seq_tokens + [next_token_id.item()]

                    # 제약 검사
                    if self.is_valid_sequence_prefix(proposed_new_seq):
                        batch_updated_seqs.append(proposed_new_seq)
                    else:
                        # 제약 위반, 탈락을 위해 매우 낮은 점수 부여
                        validity_mask[batch_idx, beam_rank] = False
                        batch_updated_seqs.append(parent_seq_tokens) # 원본 시퀀스 유지 (더미)

                updated_sequences_list.append(batch_updated_seqs)

            # 상태 업데이트
            # 리스트를 텐서로 변환 (패딩 필요)
            max_len_in_batch = max(len(s) for batch_item in updated_sequences_list for s in batch_item)
            padded_sequences_tensor = torch.full((batch_size, self.num_beams, max_len_in_batch),
                                         self.tokenizer.pad_token_id,
                                         dtype=torch.long, device=device)

            for b_idx in range(batch_size):
                for i, seq_tokens in enumerate(updated_sequences_list[b_idx]):
                    padded_sequences_tensor[b_idx, i, :len(seq_tokens)] = torch.tensor(seq_tokens, device=device)

            current_beam_sequences = padded_sequences_tensor
            beam_path_scores = topk_values.masked_fill(~validity_mask, -1e9) # 유효하지 않은 빔은 낮은 점수

            # 모든 빔이 완료되었는지 확인 (EOS 생성 여부)
            # 단순화: 실제 구현에는 더 완벽한 중지 로직 필요

        # 가장 높은 점수의 시퀀스 반환
        best_sequence_ids = current_beam_sequences[:, 0, :]
        return best_sequence_ids

class GuidedGenerationEngine:
    def __init__(self, expert_lm_path: str, apprentice_lm_path: Optional[str] = None,
                 retrieval_index_path: Optional[str] = None):
        self.execution_device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"전문가 모델 로드 중: {expert_lm_path}")
        self.expert_lm = AutoModelForCausalLM.from_pretrained(
            expert_lm_path,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        self.main_tokenizer = AutoTokenizer.from_pretrained(expert_lm_path)
        if self.main_tokenizer.pad_token is None:
            self.main_tokenizer.pad_token = self.main_tokenizer.eos_token

        self.apprentice_lm = None
        if apprentice_lm_path:
            print(f"비전문가 모델 로드 중: {apprentice_lm_path}")
            self.apprentice_lm = AutoModelForCausalLM.from_pretrained(
                apprentice_lm_path,
                torch_dtype=torch.bfloat16,
                device_map="auto"
            )
            self.apprentice_tokenizer = AutoTokenizer.from_pretrained(apprentice_lm_path)
            if self.apprentice_tokenizer.pad_token is None:
                self.apprentice_tokenizer.pad_token = self.apprentice_tokenizer.eos_token

        # 검색 구성 요소 초기화
        self.retrieval_encoder = None
        self.faiss_data_index = None
        self.document_collection = None
        if retrieval_index_path:
            self._initialize_retrieval_system(retrieval_index_path)

    def _initialize_retrieval_system(self, index_filepath: str):
        """검색 시스템 설정"""
        print("검색 시스템 설정 중...")
        self.retrieval_encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        self.faiss_data_index = faiss.read_index(index_filepath)

        # 해당 문서 로드 (해당하는 문서 파일이 있다고 가정)
        doc_content_path = index_filepath.replace(".faiss", ".txt")
        try:
            with open(doc_content_path, 'r', encoding='utf-8') as f_docs:
                self.document_collection = [line.strip() for line in f_docs if line.strip()]
        except FileNotFoundError:
            print(f"경고: 문서 파일 {doc_content_path}을 찾을 수 없습니다. 더미 문서를 사용합니다.")
            self.document_collection = ["Placeholder document content."] * 1000
        except Exception as e:
            print(f"문서 로드 중 오류 발생: {e}. 더미 문서를 사용합니다.")
            self.document_collection = ["Placeholder document content."] * 1000


    def generate_response(self, initial_prompt: str, generation_mode: str = "contrastive",
                          max_new_tokens: int = 200, **kwargs) -> str:
        """
        생성 메서드: 다양한 디코딩 전략 지원

        generation_mode 옵션:
        - "standard": 표준 생성
        - "divergence": 대조 디코딩 (초기화 시 apprentice_lm 제공 필요)
        - "augmented": 검색 증강 디코딩 (retrieval_index_path 제공 필요)
        - "constrained": 제약 빔 검색 (knowledge_graph_validator 제공 필요)
        """
        input_token_ids = self.main_tokenizer.encode(initial_prompt, return_tensors="pt").to(self.execution_device)

        active_logits_processors = LogitsProcessorList()

        if generation_mode == "divergence" and self.apprentice_lm:
            processor = DivergenceAwareLogitsProcessor(
                self.apprentice_lm, self.main_tokenizer, self.apprentice_tokenizer,
                divergence_alpha=kwargs.get('divergence_alpha', 0.5),
                min_gap_beta=kwargs.get('min_gap_beta', 0.0)
            )
            active_logits_processors.append(processor)

        elif generation_mode == "augmented" and self.faiss_data_index:
            processor = ContextualBiasLogitsProcessor(
                self.retrieval_encoder, self.faiss_data_index, self.document_collection, self.main_tokenizer,
                search_frequency=kwargs.get('search_frequency', 5),
                bias_strength=kwargs.get('bias_strength', 0.3)
            )
            active_logits_processors.append(processor)

        # 생성
        with torch.no_grad():
            if generation_mode == "constrained":
                # 사용자 정의 빔 검색 사용
                kg_validator = kwargs.get('kg_validator', lambda x: True) # 기본적으로 항상 유효하다고 가정
                beam_search_engine = SemanticConstraintBeamSearch(
                    self.main_tokenizer, kg_validator,
                    num_beams_to_search=kwargs.get('num_beams', 5)
                )
                output_token_ids = beam_search_engine.generate_with_constraints(self.expert_lm, input_token_ids, max_new_tokens)
            else:
                output_token_ids = self.expert_lm.generate(
                    input_token_ids,
                    max_new_tokens=max_new_tokens,
                    logits_processor=active_logits_processors,
                    temperature=kwargs.get('temperature', 0.7),
                    top_p=kwargs.get('top_p', 0.9),
                    do_sample=True,
                    num_return_sequences=1
                )

        full_generated_text = self.main_tokenizer.decode(output_token_ids[0], skip_special_tokens=True)
        # 원래 프롬프트 제거
        if full_generated_text.startswith(initial_prompt):
            full_generated_text = full_generated_text[len(initial_prompt):].strip()

        return full_generated_text

def create_sample_wikipedia_index(output_directory_prefix: str = "./wiki_data_index"):
    """위키백과 검색 인덱스 예시 구축"""
    print("위키백과 인덱스 구축 중...")

    # wikipedia-api를 사용하여 샘플 문서 가져오기
    wiki_api = wikipediaapi.Wikipedia('GuidedGenerationEngine/1.0', 'en')
    page_titles_to_fetch = ["Artificial intelligence", "Machine learning", "Deep learning",
             "Natural language processing", "Transformer (machine learning)", "Large language model"]

    all_document_segments = []
    retriever_encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    for title in page_titles_to_fetch:
        page_obj = wiki_api.page(title)
        if page_obj.exists():
            # 섹션별로 분할
            sections_text = page_obj.text.split('\n\n')
            all_document_segments.extend([s for s in sections_text if len(s) > 50]) # 50자 이상 섹션만 추가

    # 인코딩
    document_embeddings = retriever_encoder.encode(all_document_segments, convert_to_numpy=True, show_progress_bar=True)
    document_embeddings = document_embeddings.astype('float32')

    # FAISS 인덱스 구축
    embedding_dimension = document_embeddings.shape[1]
    faiss_index = faiss.IndexFlatIP(embedding_dimension)  # 내적 (코사인 유사도)
    faiss.normalize_L2(document_embeddings)  # 코사인 유사도 구현을 위한 L2 정규화
    faiss_index.add(document_embeddings)

    # 저장
    faiss.write_index(faiss_index, f"{output_directory_prefix}.faiss")
    with open(f"{output_directory_prefix}.txt", 'w', encoding='utf-8') as f_out_docs:
        for doc_segment in all_document_segments:
            f_out_docs.write(doc_segment.replace('\n', ' ') + '\n') # 줄바꿈 제거 후 한 줄씩 저장

    print(f"인덱스가 {output_directory_prefix}.faiss에 {len(all_document_segments)}개 문서와 함께 저장되었습니다.")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--execution_action', choices=['generate_text', 'build_wiki_index'], default='generate_text')
    parser.add_argument('--expert_model_path', type=str, default='meta-llama/Llama-2-7b-hf') # Changed to 7b for demo
    parser.add_argument('--apprentice_model_path', type=str, default='meta-llama/Llama-2-7b-hf') # Changed to 7b for demo
    parser.add_argument('--retrieval_index_path', type=str, default='./wiki_data_index.faiss')
    parser.add_argument('--chosen_decoding_mode', type=str, default='divergence',
                       choices=['standard', 'divergence', 'augmented', 'constrained'])
    parser.add_argument('--input_prompt', type=str, default="프랑스의 수도는")
    args = parser.parse_args()

    if args.execution_action == 'build_wiki_index':
        create_sample_wikipedia_index()
    else:
        generation_engine = GuidedGenerationEngine(
            args.expert_model_path,
            args.apprentice_model_path if args.chosen_decoding_mode == 'divergence' else None,
            args.retrieval_index_path if args.chosen_decoding_mode == 'augmented' else None
        )

        # 제약 빔 검색을 위한 더미 KG 검증 함수 예시
        def dummy_kg_validator(text_input: str) -> bool:
            # 실제 KG 검증 로직이 여기에 들어갑니다.
            # 예를 들어, "프랑스 수도 파리"와 같은 특정 패턴을 확인합니다.
            if "프랑스" in text_input and "수도" in text_input and "파리" in text_input:
                return True
            return True # 기본적으로 True 반환

        kwargs_for_generation = {}
        if args.chosen_decoding_mode == 'constrained':
            kwargs_for_generation['kg_validator'] = dummy_kg_validator
            kwargs_for_generation['num_beams'] = 5 # 빔 검색을 위한 빔 수

        generated_output = generation_engine.generate_response(args.input_prompt, generation_mode=args.chosen_decoding_mode, **kwargs_for_generation)
        print(f"\n프롬프트: {args.input_prompt}")
        print(f"생성 결과 ({args.chosen_decoding_mode}): {generated_output}")

if __name__ == "__main__":
    main()
3.1.2 검색 기반 정규화

검색 기반 정규화(Retrieval-Augmented Decoding)는 검색 모듈을 디코딩 루프에 깊이 통합합니다. 표준 검색 증강 생성(RAG)이 생성 전에 한 번만 컨텍스트를 검색하는 것과 달리, 이 방법은 매 단계마다 현재 접두사와 관련된 증거를 동적으로 검색하고, 검색된 증거에 나타나는 엔티티 및 관계 토큰의 확률을 높이도록 어휘 분포를 조정합니다. 이러한 세분화된 검색 개입은 생성 과정에서 발생하는 사실 편차를 효과적으로 수정할 수 있습니다.

3.1.3 사실성 제약 빔 검색

제약 빔 검색(Constrained Beam Search)은 하드 제약을 통해 생성된 콘텐츠가 사전 정의된 지식 그래프 경로 또는 엔티티 관계 템플릿을 따르도록 보장합니다. 이 방법은 현재 생성이 지식 그래프에 대해 일치하는 상태를 추적하는 상태 머신을 유지합니다. 엔티티가 생성될 때마다 후속 토큰은 해당 엔티티가 KG에 있는 합법적인 관계 또는 속성으로 제한되어, 관계형 환각을 완전히 제거합니다.

3.2 후처리 감지-수정 파이프라인

생성 후 사실 검증 및 수정은 마지막 방어선을 제공합니다. 후처리 파이프라인을 통해 시스템은 모델 아키텍처를 변경하지 않고도 생성된 콘텐츠를 반복적으로 최적화할 수 있습니다.

3.2.1 MiniCheck 아키텍처

MiniCheck는 경량형 문장 단위 사실성 검증기로, 소형 언어 모델을 미세 조정하여 자원 제한 환경에서 세분화된 환각 감지를 수행할 수 있도록 합니다. 대규모 모델에 의존하여 완전한 생성을 검증하는 것과 달리, MiniCheck는 생성된 텍스트의 엔티티 오류, 관계 오류 및 숫자 불일치 감지에 중점을 둡니다. 그 아키텍처는 일반적으로 이중 탑 구조를 채택하여 원본 문서와 주장(claim)을 각각 인코딩하고, 교차 주의를 통해 사실 일관성 점수를 계산합니다. 감지된 환각에 대해 시스템은 수정 메커니즘을 트리거합니다. 오류 스팬을 찾아 검색된 증거를 사용하여 마스크된 재생성을 수행합니다.

구현: MiniCheck 환각 감지 및 수정 시스템

이 스크립트는 경량 모델 기반의 환각 감지, 엔티티 수준 오류 위치 파악, 그리고 RAG 기반의 반복적인 수정 생성을 구현합니다.


import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, AutoModelForSeq2SeqLM
import spacy
from typing import List, Dict, Tuple, Optional
import numpy as np
from dataclasses import dataclass
import argparse
import json # for saving results

@dataclass
class DetectedHallucinationSpan:
    text_segment: str
    start_char: int
    end_char: int
    error_classification: str  # "entity_mismatch", "relation_incorrect", "numerical_inaccuracy", "unverifiable"
    confidence_score: float

class AccuracyVerifierLite:
    """
    경량형 환각 감지기
    문장 수준 NLI와 엔티티 수준 검증 결합
    """
    def __init__(self, model_for_nli: str = "microsoft/deberta-v3-base"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_for_nli)

        # 의미 일관성을 위한 이중 탑 인코더 (실제 NLI 모델로 대체 가능)
        self.encoder = AutoModel.from_pretrained(model_for_nli)

        # 분류 헤드: entailment (0), neutral (1), contradiction (2)
        self.nli_classifier_head = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size * 2, 512), # [CLS] 토큰 임베딩을 두 개 연결했다고 가정
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 3)
        )

        # 엔티티 인식
        try:
            self.nlp_entity_recognizer = spacy.load("en_core_web_sm")
        except OSError:
            print("spaCy 모델 다운로드 중...")
            import os
            os.system("python -m spacy download en_core_web_sm")
            self.nlp_entity_recognizer = spacy.load("en_core_web_sm")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder.to(self.device)
        self.nli_classifier_head.to(self.device)
        self.encoder.eval()
        self.nli_classifier_head.eval()

    def encode_text_pair(self, premise_text: str, hypothesis_text: str) -> torch.Tensor:
        """전제-가설 쌍 인코딩"""
        input_data = self.tokenizer(
            premise_text, hypothesis_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.encoder(**input_data)
            # [CLS] 토큰 표현 사용
            cls_representation = outputs.last_hidden_state[:, 0, :]  # [batch, hidden_dim]
        return cls_representation

    def check_sentence_level_consistency(self, reference_source: str, user_claim: str) -> List[Dict]:
        """문장 수준 환각 감지"""
        # 주장을 문장으로 분해
        claim_sentences = [sent.text for sent in self.nlp_entity_recognizer(user_claim).sents]

        sentence_analysis_results = []
        for sentence_segment in claim_sentences:
            # 전체 원본 문서와 단일 문장에 대해 NLI 수행
            encoded_pair_representation = self.encode_text_pair(reference_source, sentence_segment)
            # 단순화를 위해 이중 탑 표현을 연결 (실제로는 교차 인코더 사용 권장)
            # 더 나은 구현은 bart-large-mnli와 같은 전용 NLI 모델을 사용해야 함

            # 분류를 위해 두 개의 표현을 결합 (여기서는 간단한 더미)
            # 실제 NLI 모델은 직접 [CLS] 토큰을 분류함
            # 이 구현에서는 가상의 결합된 표현을 사용한다고 가정
            combined_representation_for_classifier = torch.cat([encoded_pair_representation, encoded_pair_representation], dim=-1)

            with torch.no_grad():
                logits_output = self.nli_classifier_head(combined_representation_for_classifier)
                probabilities = torch.softmax(logits_output, dim=-1)

            predicted_class_idx = torch.argmax(probabilities, dim=-1).item()
            predicted_confidence = probabilities[0][predicted_class_idx].item()

            sentence_analysis_results.append({
                "sentence_text": sentence_segment,
                "nli_label": ["entailment", "neutral", "contradiction"][predicted_class_idx],
                "prediction_confidence": predicted_confidence,
                "is_hallucinated_sentence": predicted_class_idx == 2 or (predicted_class_idx == 1 and predicted_confidence > 0.8) # Contradiction 또는 높은 중립성 = 환각
            })

        return sentence_analysis_results

    def extract_named_entities(self, text_input: str) -> List[Dict]:
        """세분화된 검증을 위해 명명된 엔티티 추출"""
        doc = self.nlp_entity_recognizer(text_input)
        entities = []
        for ent in doc.ents:
            entities.append({
                "entity_text": ent.text,
                "entity_type": ent.label_,
                "start_offset": ent.start_char,
                "end_offset": ent.end_char
            })
        return entities

    def validate_entity_consistency(self, source_text: str, claim_text: str) -> List[DetectedHallucinationSpan]:
        """엔티티 수준 일관성 검사"""
        source_entities_map = {e['entity_text'].lower(): e for e in self.extract_named_entities(source_text)}
        claim_entities = self.extract_named_entities(claim_text)

        hallucination_instances = []

        for entity_in_claim in claim_entities:
            entity_key_lower = entity_in_claim['entity_text'].lower()
            # 단순 매칭: 실제 앱에서는 엔티티 링크 및 지식 기반 검증 사용
            if entity_key_lower not in source_entities_map:
                # 동의어 또는 지칭 가능성 확인 (의미 유사도 사용)
                is_semantically_similar = False
                for src_ent_info in source_entities_map.values():
                    # 여기에 실제 엔티티 유사도 계산 로직 필요 (예: 임베딩 기반 유사도)
                    similarity = self._calculate_entity_similarity(entity_in_claim['entity_text'], src_ent_info['entity_text'])
                    if similarity > 0.9:  # 임계값
                        is_semantically_similar = True
                        break

                if not is_semantically_similar:
                    hallucination_instances.append(DetectedHallucinationSpan(
                        text_segment=entity_in_claim['entity_text'],
                        start_char=entity_in_claim['start_offset'],
                        end_char=entity_in_claim['end_offset'],
                        error_classification="entity_mismatch",
                        confidence_score=0.8
                    ))

        return hallucination_instances

    def _calculate_entity_similarity(self, ent1_text: str, ent2_text: str) -> float:
        """엔티티 간 의미 유사도 계산 (단순화된 버전)"""
        # 실제 앱에서는 엔티티 임베딩 또는 외부 지식 그래프 사용
        return 1.0 if ent1_text.lower() == ent2_text.lower() else 0.0

    def analyze_factuality(self, source_document: str, generated_claim: str) -> Dict:
        """
        완전한 감지 프로세스
        반환: 환각 스팬 목록 및 전체 신뢰도 점수
        """
        # 문장 수준 감지
        sentence_check_results = self.check_sentence_level_consistency(source_document, generated_claim)

        # 엔티티 수준 감지
        entity_hallucinations = self.validate_entity_consistency(source_document, generated_claim)

        # 결과 병합
        all_identified_spans = entity_hallucinations

        # 모순된 문장의 엔티티를 높은 신뢰도 환각으로 표시
        for sent_res in sentence_check_results:
            if sent_res['is_hallucinated_sentence']:
                # 문장 내 엔티티를 찾아 환각 신뢰도 높이기
                sent_start = generated_claim.find(sent_res['sentence_text'])
                sent_end = sent_start + len(sent_res['sentence_text'])

                for span in all_identified_spans:
                    if span.start_char >= sent_start and span.end_char <= sent_end:
                        span.confidence_score = max(span.confidence_score, sent_res['prediction_confidence'])

        # 전체 점수 계산
        if not all_identified_spans:
            overall_trust_score = 1.0
        else:
            # 환각 비율과 신뢰도를 기반으로 계산
            hallucination_weighted_score = sum(s.confidence_score for s in all_identified_spans) / len(generated_claim.split())
            overall_trust_score = max(0.0, 1.0 - hallucination_weighted_score)

        return {
            "hallucination_spans_found": all_identified_spans,
            "sentence_level_analysis": sentence_check_results,
            "overall_factuality_rating": overall_trust_score,
            "is_any_hallucination_detected": len(all_identified_spans) > 0
        }

class SelfCorrectionWorkflow:
    """
    반복적 수정 파이프라인
    감지 -> 증거 검색 -> 수정 -> 검증
    """
    def __init__(self, detector_instance: AccuracyVerifierLite,
                 external_retriever=None,  # 외부 검색 시스템 (예: SentenceTransformer + FAISS)
                 revision_lm_name: str = "google/flan-t5-large"):
        self.detector = detector_instance
        self.retriever = external_retriever

        # 시퀀스-투-시퀀스 수정 모델
        self.reviser_model = AutoModelForSeq2SeqLM.from_pretrained(revision_lm_name)
        self.reviser_tokenizer = AutoTokenizer.from_pretrained(revision_lm_name)
        self.reviser_model.to(detector_instance.device)
        self.reviser_model.eval()

        self.max_revision_iterations = 3

    def fetch_supporting_evidence(self, query_string: str) -> List[str]:
        """지원 증거 검색 (retriever 제공 시)"""
        if self.retriever:
            # retriever가 search 메서드를 가지고 있다고 가정
            return self.retriever.search(query_string, top_k=3)
        return []

    def refine_segment(self, original_statement: str, error_span: DetectedHallucinationSpan,
                       retrieved_evidence: List[str]) -> str:
        """특정 환각 구간 수정"""
        evidence_context = " ".join(retrieved_evidence)

        # 수정 프롬프트 구성
        revision_prompt = f"""제공된 증거를 바탕으로 다음 텍스트를 수정하세요.
원본: {original_statement}
식별된 오류: "{error_span.text_segment}" ({error_span.error_classification})
증거: {evidence_context}
수정된 버전:"""

        input_tokens = self.reviser_tokenizer(revision_prompt, return_tensors="pt",
                                       max_length=1024, truncation=True).to(self.detector.device)

        with torch.no_grad():
            output_tokens = self.reviser_model.generate(
                **input_tokens,
                max_length=512,
                num_beams=4,
                early_stopping=True
            )

        revised_text = self.reviser_tokenizer.decode(output_tokens[0], skip_special_tokens=True)
        return revised_text

    def execute_full_revision(self, source_document_content: str, initial_claim: str) -> Dict:
        """
        완전한 수정 프로세스: 환각이 없거나 최대 반복 횟수에 도달할 때까지 반복 감지 및 수정
        """
        current_version_claim = initial_claim
        process_history = []

        for iteration_num in range(self.max_revision_iterations):
            print(f"--- 수정 반복 {iteration_num + 1} 시작 ---")
            # 감지
            detection_feedback = self.detector.analyze_factuality(source_document_content, current_version_claim)

            if not detection_feedback['is_any_hallucination_detected']:
                print("환각이 감지되지 않았습니다. 수정 완료.")
                break

            # 첫 번째 환각 구간에 대해 수정 (단순화: 실제로는 여러 중첩 구간을 조정해야 함)
            primary_hallucination = detection_feedback['hallucination_spans_found'][0]
            print(f"감지된 환각: '{primary_hallucination.text_segment}' ({primary_hallucination.error_classification})")

            # 해당 환각에 대한 증거 검색
            relevant_evidence = self.fetch_supporting_evidence(primary_hallucination.text_segment)
            print(f"검색된 증거: {relevant_evidence[:1]}")

            # 수정
            revised_statement = self.refine_segment(current_version_claim, primary_hallucination, relevant_evidence)
            print(f"수정된 텍스트 (부분): {revised_statement[:100]}...")

            process_history.append({
                "iteration": iteration_num + 1,
                "original_claim_at_step": current_version_claim,
                "detected_hallucination": primary_hallucination.__dict__, # dataclass to dict
                "revised_claim_at_step": revised_statement,
                "evidence_used_for_revision": relevant_evidence
            })

            current_version_claim = revised_statement
        else:
            print(f"최대 반복 횟수({self.max_revision_iterations}) 도달. 수정 종료.")


        # 최종 검증
        final_verification_result = self.detector.analyze_factuality(source_document_content, current_version_claim)

        return {
            "final_revised_text": current_version_claim,
            "revision_process_history": process_history,
            "final_factuality_score": final_verification_result['overall_factuality_rating'],
            "total_iterations_performed": len(process_history)
        }

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--source_document_path', type=str, required=True, help='원문서 경로')
    parser.add_argument('--initial_claim_text', type=str, required=True, help='검증할 초기 주장 텍스트')
    parser.add_argument('--output_results_path', type=str, default='revision_output.json')
    parser.add_argument('--enable_revision_process', action='store_true', help='반복적 수정 프로세스 활성화')
    args = parser.parse_args()

    # 원본 문서 읽기
    with open(args.source_document_path, 'r', encoding='utf-8') as f:
        source_content = f.read()

    # 감지기 초기화
    print("AccuracyVerifierLite 감지기 초기화 중...")
    verifier = AccuracyVerifierLite()

    if args.enable_revision_process:
        print("반복적 수정 파이프라인 실행 중...")
        # 외부 검색기가 필요한 경우 여기에 초기화
        # 예를 들어, retriever_for_pipeline = SentenceTransformer('all-MiniLM-L6-v2')
        # pipe = SelfCorrectionWorkflow(verifier, external_retriever=retriever_for_pipeline)
        
        # 데모를 위해 검색기 없이 초기화
        pipe = SelfCorrectionWorkflow(verifier) 
        revision_output = pipe.execute_full_revision(source_content, args.initial_claim_text)

        print(f"\n최종 사실성 점수: {revision_output['final_factuality_score']:.2f}")
        print(f"총 반복 횟수: {revision_output['total_iterations_performed']}")
        print(f"최종 텍스트: {revision_output['final_revised_text']}")
    else:
        print("환각 감지 실행 중...")
        detection_output = verifier.analyze_factuality(source_content, args.initial_claim_text)

        print(f"\n환각 감지됨: {detection_output['is_any_hallucination_detected']}")
        print(f"전체 점수: {detection_output['overall_factuality_rating']:.2f}")
        print("감지된 스팬:", [s.text_segment for s in detection_output['hallucination_spans_found']])

    # 결과 저장
    with open(args.output_results_path, 'w', encoding='utf-8') as f:
        # dataclass 객체를 직렬화하기 위해 default 함수 제공
        json.dump(detection_output if not args.enable_revision_process else revision_output, f, indent=2, ensure_ascii=False, default=lambda o: o.__dict__)
    print(f"\n결과가 {args.output_results_path}에 저장되었습니다.")

if __name__ == "__main__":
    main()
3.2.2 사실 확인-RAG (Fact-Check-Then-RAG)

사실 확인-RAG(Fact-Check-Then-RAG) 전략은 먼저 SAFE 또는 MiniCheck와 같은 경량 검증기를 사용하여 초기 생성물에 대한 사실성 평가를 수행합니다. 검증에 실패하면 시스템은 검색 증강 재생성을 트리거합니다. 감지된 환각 엔티티를 쿼리 키워드로 사용하여 더 정확한 증거를 검색하고, 확장된 컨텍스트를 기반으로 내용을 다시 생성합니다. 이러한 반복적인 순환은 생성된 콘텐츠가 사실 검증을 통과하거나 최대 반복 횟수에 도달할 때까지 계속되어, 자체 수정하는 생성 프로세스를 형성합니다.

3.2.3 다중 에이전트 토론

다중 에이전트 토론(Multi-Agent Debate)은 여러 독립적인 모델 인스턴스(또는 다른 아키텍처의 모델)를 활용하여 동일한 쿼리에 대해 다양한 답변을 생성하고, 이러한 답변 간의 합의와 불일치를 비교하여 잠재적인 환각을 식별합니다. 시스템은 반복적인 토론 프로세스를 유지합니다. 각 에이전트는 다른 에이전트가 생성한 내용을 비판적으로 평가하고 잠재적인 사실 오류를 지적합니다. 여러 라운드의 토론을 거친 후, 합의도가 높은 진술을 최종 출력으로 채택하거나 논쟁적인 진술을 수동 검토가 필요하다고 표시합니다. 이 방법은 모델의 자체 검증 능력과 집단 지성을 활용하여 사실적 신뢰성을 크게 향상시킵니다.

태그: LLM 사실성 감독 미세 조정 RLHF 모델 편집 환각 완화

6월 14일 21:27에 게시됨