lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM

Is this shift right for the action logits? #31

Closed · kisseternity closed this 1 year ago

kisseternity commented 1 year ago

https://github.com/lucidrains/PaLM-rlhf-pytorch/blob/bfcffe79a5d6f80fccbf5667b263f660b41dda30/palm_rlhf_pytorch/ppo.py#L612

Hello, since the action_logits are really the prompt logits concatenated with the response logits, I wonder whether shifting along the sequence dimension by 1 is really the right thing to do. Shouldn't it shift left by the prompt length, so that only the action logits are left? The same question applies to the corresponding line in the learn function.

lucidrains commented 1 year ago

@kisseternity the sampled token id comes from the logits of the previous token, so i'm shifting all the logits to the right by one to line them up for the gather. i think my logic is right, but off-by-ones are so confusing haha
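roughly what i mean, as a toy sketch (shapes and names here are illustrative, not the repo's exact code): position t of the logits is the distribution token t + 1 was sampled from, so shifting right by one lines the logits up with the token ids for a gather

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 8, 100
logits = torch.randn(batch, seq_len, vocab)           # logits[:, t] predicts token t + 1
sequence = torch.randint(0, vocab, (batch, seq_len))  # the sampled token ids

# shift right along the sequence dim: position t now holds the logits token t was drawn from
# (position 0 becomes padding and would be masked out in practice)
shifted = F.pad(logits, (0, 0, 1, -1), value = 0.)
log_probs = shifted.log_softmax(dim = -1)

# log-prob of each token under the distribution it was sampled from
token_log_probs = log_probs.gather(-1, sequence[..., None]).squeeze(-1)  # (batch, seq_len)
```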

kisseternity commented 1 year ago

> @kisseternity the sampled token id comes from the logits of the previous token, so i'm shifting all the logits to the right by one to line them up for the gather. i think my logic is right, but off-by-ones are so confusing haha

Well, so the action_logits give the probabilities for choosing the next token. In that case, I think it's right to shift right by one during training. Another question: action_prob also includes the probabilities at the prompt positions, while action_log_prob = log_prob(action_prob, actions) here takes the probabilities at the first num_actions positions (so it includes prompt probabilities but ignores some of the action probabilities?). I'm still confused here, could you please explain it? Thanks.
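A toy example of what I mean (hypothetical shapes, not the actual code). If the probabilities span prompt + response, slicing the first num_actions positions picks up prompt positions, while the action log-probs live at the end:

```python
import torch

batch, prompt_len, num_actions, vocab = 2, 6, 4, 100
seq_len = prompt_len + num_actions

log_probs = torch.randn(batch, seq_len, vocab).log_softmax(dim = -1)
actions = torch.randint(0, vocab, (batch, num_actions))

first_slice = log_probs[:, :num_actions]   # prompt positions, not the actions
last_slice = log_probs[:, -num_actions:]   # the response ("action") positions

action_log_probs = last_slice.gather(-1, actions[..., None]).squeeze(-1)  # (batch, num_actions)
```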

lucidrains commented 1 year ago

@kisseternity turns out there was a bug 😞 thank you for opening this issue

could you do a review of the last commit and see if that matches your intuition? i've also put in an extra assert; i'm not sure why the gather still worked when dimensions other than the one i was gathering on differed
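for reference, a toy repro of why the gather was silent: torch.gather only requires index.size(d) <= input.size(d) on the non-gathered dimensions, so an index that is shorter along the sequence dim just gathers from the leading positions without erroring (hence the extra assert)

```python
import torch

t = torch.arange(12.).reshape(1, 4, 3)            # (batch, seq = 4, vocab = 3)
index = torch.zeros(1, 2, 1, dtype = torch.long)  # seq = 2 < 4, yet no error

out = t.gather(-1, index)  # silently gathers from the first 2 positions only
print(out.shape)           # torch.Size([1, 2, 1])
```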

kisseternity commented 1 year ago

> @kisseternity turns out there was a bug 😞 thank you for opening this issue
>
> could you do a review of the last commit and see if that matches your intuition? i've also put in an extra assert; i'm not sure why the gather still worked when dimensions other than the one i was gathering on differed

I think it's okay now, thanks for fixing it!