CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
MIT License

Attention mask when calculating log ratio for PPO #582

Open kmy17518 opened 1 year ago

kmy17518 commented 1 year ago

Hi, I have a question about calculating the log ratio for PPO. I'm very new to this area and would be really grateful if you could help me.

In accelerate_ppo_trainer.py, def make_experience, line 457:

    log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]

But according to the comment # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled, shouldn't the mask here be attention_mask[:, 1:]?
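
To check my understanding of that comment, here is a minimal sketch of the usual gathering convention (toy shapes I made up, not the actual trlx code):

    import torch
    import torch.nn.functional as F

    # toy shapes, made up for illustration: batch=1, seq_len=4, vocab=10
    all_tokens = torch.randint(0, 10, (1, 4))  # [t0, t1, t2, t3]
    logits = torch.randn(1, 4, 10)             # one logit row per input position

    # logprob[i] is the log-probability assigned to all_tokens[i+1] given the
    # prefix, so there are only seq_len - 1 of them
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    logprobs = torch.gather(log_probs, 2, all_tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    print(logprobs.shape)  # torch.Size([1, 3])

So position i of logprobs lines up with position i + 1 of all_tokens and of attention_mask, which is why attention_mask[:, 1:] looks like the aligned slice to me.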

In accelerate_ppo_trainer.py, def loss, line 188:

    logprobs, values_pred, mask = (
        logprobs[:, start:end],
        values_pred[:, start:end],
        attention_mask[:, start + 1 : end + 1],
    )

Here the attention mask seems to be shifted the correct way, so why is it different in make_experience?
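
As a concrete toy case of the difference (a made-up mask, assuming the padding sits at the end of the sequence):

    import torch

    # toy mask, made up for illustration: the last position is padding
    attention_mask = torch.tensor([[1, 1, 1, 0]])

    # logprobs has length seq_len - 1, and logprob[i] scores token i + 1
    print(attention_mask[:, :-1])  # tensor([[1, 1, 1]]) -> keeps the logprob that scores the pad token
    print(attention_mask[:, 1:])   # tensor([[1, 1, 0]]) -> drops it, like the slicing in def loss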

Thanks for your help in advance!