기본 개념
가치(Value): (1) 액터(Actor)가 현재 상태 또는 상태-액션 쌍에 대해 예측하는 장기 누적 보상의 추정치 (2) 우위 함수(Advantage Function) 계산에 사용되며, 액터 모델의 매개변수 업데이트를 지도함 (3) 크리틱(Critic) 모델을 통해 학습됨
보상(Reward): 환경이 액터가 특정 시점에 액션을 수행한 후 제공하는 즉각적인 피드백 신호
구체적인 코드 구현
우위(advantage) 계산 코드는 다음과 같으며, 우위는 다음 가치(next-value)에서 현재 가치(value)를 뺀 것으로 계산됩니다
def calculate_advantages_and_returns(self, value_estimates, reward_signals, start_index):
# CarperAI/trlx 저장소에서 채택된 코드 (https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134)
last_gae_value = 0
advantages_reversed = []
sequence_length = reward_signals.size()[-1]
for t in reversed(range(start_index, sequence_length)):
next_values = value_estimates[:, t + 1] if t < sequence_length - 1 else 0.0
delta = reward_signals[:, t] + self.gamma * next_values - value_estimates[:, t]
last_gae_value = delta + self.gamma * self.lam * last_gae_value
advantages_reversed.append(last_gae_value)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + value_estimates[:, start_index:]
return advantages.detach(), returns
크리틱 손실(critic_loss) 계산 코드는 새로운 가치와 이전 가치 간의 차이를 판단하는 로직을 보여줍니다
def calculate_critic_loss(self, current_values, previous_values, target_returns, attention_mask):
## 가치 손실 계산
clipped_values = torch.clamp(
current_values,
previous_values - self.cliprange_value,
previous_values + self.cliprange_value,
)
if self.compute_fp32_loss:
current_values = current_values.float()
clipped_values = clipped_values.float()
vf_loss1 = (current_values - target_returns)**2
vf_loss2 = (clipped_values - target_returns)**2
vf_loss = 0.5 * torch.sum(
torch.max(vf_loss1, vf_loss2) * attention_mask) / attention_mask.sum()
return vf_loss
액터 손실(actor_loss) 계산 코드는 새로운 로짓과 이전 로짓 간의 거리를 계산한 후, 우위(advantage)를 곱하여 최종 손실을 얻는 과정을 보여줍니다
def calculate_actor_loss(self, current_logprobs, previous_logprobs, advantages, attention_mask):
## 정책 그래디언트 손실 계산
log_ratio = (current_logprobs - previous_logprobs) * attention_mask
ratio = torch.exp(log_ratio)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * attention_mask) / attention_mask.sum()
return pg_loss
보상 계산 코드는 참조 모델(ref-model)의 출력과 액터 모델의 출력 간의 거리를 계산하여 보상을 얻는 방법을 보여줍니다
def generate_rewards(self, prompts, current_log_probs, reference_log_probs, reward_score, action_mask):
# KL 발산 추정치 계산
kl_divergence_estimate = -self.kl_ctl * (current_log_probs - reference_log_probs)
rewards = kl_divergence_estimate
start = prompts.shape[1] - 1
ends = start + action_mask[:, start:].sum(1) + 1
reward_clip = torch.clamp(reward_score, -self.clip_reward_value, self.clip_reward_value)
batch_size = current_log_probs.shape[0]
for j in range(batch_size):
rewards[j, start:ends[j]][-1] += reward_clip[j]
return rewards
전체 코드:
https://github.com/deepspeedai/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py