```python
def compute_advantages(self, results, num_samples):
    old_values = results['generated_value']
    rewards = results['rewards/penalized']
    mask = results['generated_attention_mask']  # (B, KL)

    with torch.no_grad():
        if self.args['ppo']['whiten_rewards']:
            whitened_rewards = whiten(rewards, mask, shift_mean=False, accelerator=self.accelerator)
        else:
            whitened_rewards = rewards

        lastgaelam = 0
        advantages_reversed = []
        gen_length = mask.sum(dim=1).max().item()
        for t in reversed(range(gen_length)):
            nextvalues = old_values[:, t + 1] if t < gen_length - 1 else 0.0
            delta = whitened_rewards[:, t] + self.args['ppo']['gamma'] * nextvalues - old_values[:, t]
            lastgaelam = delta + self.args['ppo']['gamma'] * self.args['ppo']['lam'] * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        advantages = F.pad(advantages, (0, whitened_rewards.size(1) - gen_length), value=0.0)
        returns = advantages + old_values

        whitened_advantages = advantages.detach()
        whitened_advantages = whiten(advantages, mask, accelerator=self.accelerator).detach()

    results['whitened_advantages'] = whitened_advantages
    results['returns'] = returns
```
What does `lastgaelam` refer to here?
This follows the advantage estimation function defined in Section 2 of the paper. Concretely, `lastgaelam` is the running accumulator of the generalized advantage estimation (GAE) recursion, A_t = δ_t + γλ·A_{t+1}, where δ_t is the TD residual; the loop walks backwards over the generated tokens and carries this term forward from t+1 to t.
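A minimal standalone sketch of the same recursion, with toy tensors and hyperparameter values chosen purely for illustration (not the repo's defaults), may make the role of `lastgaelam` clearer:

```python
import torch

# Toy illustration of the GAE recursion used above:
#   delta_t = r_t + gamma * V_{t+1} - V_t
#   A_t     = delta_t + gamma * lam * A_{t+1}
gamma, lam = 1.0, 0.95                      # example hyperparameters
rewards = torch.tensor([0.1, 0.0, 0.5])     # toy per-token rewards r_t
values = torch.tensor([0.4, 0.3, 0.2])      # toy value estimates V_t

lastgaelam = 0.0                            # running GAE term; 0 beyond the last token
advantages_reversed = []
T = rewards.size(0)
for t in reversed(range(T)):
    nextvalue = values[t + 1] if t < T - 1 else 0.0
    delta = rewards[t] + gamma * nextvalue - values[t]   # TD residual delta_t
    lastgaelam = delta + gamma * lam * lastgaelam        # accumulate A_t backwards
    advantages_reversed.append(lastgaelam)

advantages = torch.stack(advantages_reversed[::-1])      # A_0, A_1, A_2 in forward order
print(advantages)
```

The in-repo version does the same thing batched over sequences, then pads the advantages back to the full reward length and whitens them.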