Closed: kisseternity closed this issue 1 year ago
@kisseternity the sampled token id from the logits from the previous token, so i'm shifting all the logits to the right by one to line them up to do a gather. i think my logic is right, but off-by-ones are so confusing haha
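As a minimal sketch of that alignment (a hypothetical helper, not the repo's actual function): the logits at position `t` score the token sampled at position `t + 1`, so dropping the last logit position and the first token lines the two tensors up for a gather.

```python
import torch
import torch.nn.functional as F

def log_prob_of_sampled(logits, token_ids):
    # logits: (batch, seq, vocab); logits[:, t] predicts token_ids[:, t + 1]
    log_probs = F.log_softmax(logits[:, :-1], dim=-1)  # drop last position
    target = token_ids[:, 1:].unsqueeze(-1)            # drop first token
    return log_probs.gather(-1, target).squeeze(-1)    # (batch, seq - 1)
```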
Well, the action_logits are the probabilities for choosing the next token, so in that case I think shifting right by one during training is correct. Another question: action_prob includes the probabilities over the prompt tokens, while the call action_log_prob = log_prob(action_prob, actions) here takes only the first action-length probabilities (so it includes the prompt probabilities but ignores some of the action probabilities?). I'm still confused here, could you please explain it? Thanks.
@kisseternity turns out there was a bug :disappointed: thank you for opening this issue
could you do a review of the last commit and see if that matches your intuition? i've also put in an extra assert; i'm not sure why the gather still worked when the other dimensions differed other than the dimension i was gathering on
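For what it's worth, the likely reason the gather didn't raise an error: `torch.gather` only requires `index.size(d) <= input.size(d)` for every dimension `d` other than the gather dim, so an index that is smaller along the sequence dimension runs silently and just produces a smaller output. A small repro (shapes are made up):

```python
import torch

x = torch.randn(4, 10, 32)                    # (batch, seq, vocab)
idx = torch.zeros(4, 7, 1, dtype=torch.long)  # seq is 7, not 10

# No error despite the sequence-length mismatch; output takes idx's shape,
# which is why an explicit assert on the shapes is worth having.
out = x.gather(-1, idx)
print(out.shape)  # torch.Size([4, 7, 1])
```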
I think it's okay now, thanks for fixing it!
https://github.com/lucidrains/PaLM-rlhf-pytorch/blob/bfcffe79a5d6f80fccbf5667b263f660b41dda30/palm_rlhf_pytorch/ppo.py#L612
Hello, since the action_logits originally contain the prompt logits together with the response logits, I wonder whether shifting along the sequence dimension by 1 is really the right thing to do. Shouldn't it shift left by the prompt length so that only the action logits remain? The same thing happens on the corresponding line in the learn function.
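For reference, one way to line everything up, under assumed names and shapes (an illustrative sketch, not the repo's code): after the shift-by-one alignment, keeping only the response means slicing the last action_len positions, since the prompt occupies the front of the sequence.

```python
import torch
import torch.nn.functional as F

batch, prompt_len, action_len, vocab = 2, 5, 3, 11  # hypothetical sizes
seq = prompt_len + action_len
logits = torch.randn(batch, seq, vocab)
tokens = torch.randint(0, vocab, (batch, seq))

# Shift-by-one alignment: logits[:, t] scores tokens[:, t + 1].
log_probs = F.log_softmax(logits[:, :-1], dim=-1)
per_token = log_probs.gather(-1, tokens[:, 1:, None]).squeeze(-1)  # (b, seq-1)

# Keep only the response: the LAST action_len positions, not the first.
action_log_probs = per_token[:, -action_len:]
assert action_log_probs.shape == (batch, action_len)
```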