allenai / FineGrainedRLHF

Apache License 2.0

What does lastgaelam refer to when computing advantages? #8

Closed Congcong-Song closed 7 months ago

Congcong-Song commented 1 year ago
def compute_advantages(self, results, num_samples):

    old_values = results['generated_value']
    rewards = results['rewards/penalized']
    mask = results['generated_attention_mask'] # (B, KL)

    with torch.no_grad():
        if self.args['ppo']['whiten_rewards']:
            whitened_rewards = whiten(rewards, mask, shift_mean=False, accelerator=self.accelerator)
        else:
            whitened_rewards = rewards

        lastgaelam = 0
        advantages_reversed = []
        gen_length = mask.sum(dim=1).max().item()
        for t in reversed(range(gen_length)):
            nextvalues = old_values[:, t + 1] if t < gen_length - 1 else 0.0
            delta = whitened_rewards[:, t] + self.args['ppo']['gamma'] * nextvalues - old_values[:, t]
            lastgaelam = delta + self.args['ppo']['gamma'] * self.args['ppo']['lam'] * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        advantages = F.pad(advantages, (0, whitened_rewards.size(1) - gen_length), value=0.0)
        returns = advantages + old_values

        whitened_advantages = advantages.detach()
        whitened_advantages = whiten(advantages, mask, accelerator=self.accelerator).detach()

    results['whitened_advantages'] = whitened_advantages
    results['returns'] = returns

What does lastgaelam refer to here?

ellenmellon commented 11 months ago

This follows the advantage estimation function defined in section 2 of the paper.
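
To spell out the connection with the loop quoted above: the loop has the shape of the backward recursion used in generalized advantage estimation (GAE). At each step, delta is the one-step TD residual, r_t + gamma * V(t+1) - V(t). lastgaelam carries the advantage already computed for step t+1, starting from 0 past the final token, so each iteration yields A_t = delta_t + gamma * lam * A_(t+1). The name presumably abbreviates "last GAE-lambda value", though the thread does not confirm that. Below is a minimal single-sequence sketch in plain Python, not the repository's code: the function names gae_advantages and gae_direct and the default values of gamma and lam are assumptions for illustration only.

```python
def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    """Backward GAE recursion for one sequence of per-token rewards and values.

    Hypothetical helper; gamma/lam defaults are illustrative assumptions.
    """
    lastgaelam = 0.0  # advantage of step t+1; zero beyond the last step
    advantages_reversed = []
    num_steps = len(rewards)
    for t in reversed(range(num_steps)):
        # No bootstrap value after the final step
        next_value = values[t + 1] if t < num_steps - 1 else 0.0
        # One-step TD residual
        delta = rewards[t] + gamma * next_value - values[t]
        # A_t = delta_t + gamma * lam * A_{t+1}
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    return advantages_reversed[::-1]


def gae_direct(rewards, values, gamma=1.0, lam=0.95):
    """The same quantity as an explicit sum: A_t = sum_l (gamma*lam)^l * delta_{t+l}."""
    num_steps = len(rewards)
    deltas = [
        rewards[t]
        + gamma * (values[t + 1] if t < num_steps - 1 else 0.0)
        - values[t]
        for t in range(num_steps)
    ]
    return [
        sum((gamma * lam) ** l * deltas[t + l] for l in range(num_steps - t))
        for t in range(num_steps)
    ]


if __name__ == "__main__":
    rewards = [0.0, 0.0, 1.0]
    values = [0.5, 0.5, 0.5]
    print(gae_advantages(rewards, values, gamma=1.0, lam=0.5))
    print(gae_direct(rewards, values, gamma=1.0, lam=0.5))
```

Both functions should agree on any input. The recursive form does the work in a single backward pass, which is likely why the loop iterates over reversed(range(gen_length)) and reverses the list afterwards.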