DeepSpeed에서 PPO 알고리즘 구현 및 코드 분석

기본 개념

가치(Value): (1) 액터(Actor)가 현재 상태 또는 상태-액션 쌍에 대해 예측하는 장기 누적 보상의 추정치 (2) 우위 함수(Advantage Function) 계산에 사용되며, 액터 모델의 매개변수 업데이트를 지도함 (3) 크리틱(Critic) 모델을 통해 학습됨

보상(Reward): 환경이 액터가 특정 시점에 액션을 수행한 후 제공하는 즉각적인 피드백 신호

구체적인 코드 구현

우위(advantage) 계산 코드는 다음과 같으며, 우위는 다음 가치(next-value)에서 현재 가치(value)를 뺀 것으로 계산됩니다

def calculate_advantages_and_returns(self, value_estimates, reward_signals, start_index):
    # CarperAI/trlx 저장소에서 채택된 코드 (https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134)
    last_gae_value = 0
    advantages_reversed = []
    sequence_length = reward_signals.size()[-1]
    
    for t in reversed(range(start_index, sequence_length)):
        next_values = value_estimates[:, t + 1] if t < sequence_length - 1 else 0.0
        delta = reward_signals[:, t] + self.gamma * next_values - value_estimates[:, t]
        last_gae_value = delta + self.gamma * self.lam * last_gae_value
        advantages_reversed.append(last_gae_value)
    
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + value_estimates[:, start_index:]
    return advantages.detach(), returns

크리틱 손실(critic_loss) 계산 코드는 새로운 가치와 이전 가치 간의 차이를 판단하는 로직을 보여줍니다

    def calculate_critic_loss(self, current_values, previous_values, target_returns, attention_mask):
        ## 가치 손실 계산
        clipped_values = torch.clamp(
            current_values,
            previous_values - self.cliprange_value,
            previous_values + self.cliprange_value,
        )
        
        if self.compute_fp32_loss:
            current_values = current_values.float()
            clipped_values = clipped_values.float()
            
        vf_loss1 = (current_values - target_returns)**2
        vf_loss2 = (clipped_values - target_returns)**2
        vf_loss = 0.5 * torch.sum(
            torch.max(vf_loss1, vf_loss2) * attention_mask) / attention_mask.sum()
        return vf_loss

액터 손실(actor_loss) 계산 코드는 새로운 로짓과 이전 로짓 간의 거리를 계산한 후, 우위(advantage)를 곱하여 최종 손실을 얻는 과정을 보여줍니다

def calculate_actor_loss(self, current_logprobs, previous_logprobs, advantages, attention_mask):
    ## 정책 그래디언트 손실 계산
    log_ratio = (current_logprobs - previous_logprobs) * attention_mask
    ratio = torch.exp(log_ratio)
    
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)
    pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * attention_mask) / attention_mask.sum()
    return pg_loss

보상 계산 코드는 참조 모델(ref-model)의 출력과 액터 모델의 출력 간의 거리를 계산하여 보상을 얻는 방법을 보여줍니다

    def generate_rewards(self, prompts, current_log_probs, reference_log_probs, reward_score, action_mask):
        # KL 발산 추정치 계산
        kl_divergence_estimate = -self.kl_ctl * (current_log_probs - reference_log_probs)
        rewards = kl_divergence_estimate
        
        start = prompts.shape[1] - 1
        ends = start + action_mask[:, start:].sum(1) + 1
        reward_clip = torch.clamp(reward_score, -self.clip_reward_value, self.clip_reward_value)
        batch_size = current_log_probs.shape[0]
        
        for j in range(batch_size):
            rewards[j, start:ends[j]][-1] += reward_clip[j]

        return rewards

전체 코드: https://github.com/deepspeedai/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py

태그: DeepSpeed PPO 강화학습 RLHF python

6월 23일 17:59에 게시됨