lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License

A bug in the implementation of top-p sampling #60

Open allblueJT opened 1 month ago

allblueJT commented 1 month ago

https://github.com/lucidrains/PaLM-rlhf-pytorch/blob/6b02ee329106baff78e293afa7d1d2e6dd4e5ca2/palm_rlhf_pytorch/utils.py#L60

Using the sorted indices to index the sorted indices does not make sense. I think it should be `return logits.scatter(1, sorted_indices, sorted_logits)`, which writes each filtered logit back to its original vocabulary position.
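For context, here is a minimal sketch of nucleus (top-p) filtering in PyTorch that restores the original vocabulary order with the `scatter` call proposed above. It is not the repository's code. The function name `top_p_filter` is hypothetical, the input is assumed to be a 2-D `(batch, vocab)` tensor (matching `dim=1` in the proposed fix), and `thres` is assumed to be the nucleus mass p. Under that assumption, the smallest set of highest-probability tokens whose cumulative probability reaches p is kept. The repository's own `thres` argument may be defined differently.

```python
import torch
import torch.nn.functional as F


def top_p_filter(logits: torch.Tensor, thres: float = 0.9) -> torch.Tensor:
    """Hypothetical nucleus (top-p) filter for a 2-D (batch, vocab) tensor.

    Assumption: `thres` is the nucleus mass p. Tokens outside the smallest
    high-probability set whose cumulative probability reaches p get -inf.
    """
    # Sort each row in descending order and remember each value's original position.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark tokens past the threshold, then shift the mask right by one so the
    # token that crosses the threshold is kept. The top token is always kept.
    to_remove = cum_probs > thres
    to_remove[:, 1:] = to_remove[:, :-1].clone()
    to_remove[:, 0] = False

    sorted_logits = sorted_logits.masked_fill(to_remove, float("-inf"))

    # Undo the sort: sorted_logits[i, j] belongs to token sorted_indices[i, j],
    # so scatter it back to that position in a tensor shaped like `logits`.
    return logits.scatter(1, sorted_indices, sorted_logits)


if __name__ == "__main__":
    logits = torch.tensor([[1.0, 3.0, 2.0, 0.0]])
    filtered = top_p_filter(logits, thres=0.9)
    print(filtered)  # the lowest-probability token (index 3) is masked to -inf

    # Sample a token id from the filtered distribution (masked tokens get probability 0).
    probs = F.softmax(filtered, dim=-1)
    token = torch.multinomial(probs, num_samples=1)
    print(token)
```

The final `scatter` is the step the issue is about. Each filtered value in `sorted_logits` must be written back to the vocabulary position recorded in `sorted_indices`. If the logits were returned still in sorted order, the sampler would treat position 0 as token id 0, even though it holds the highest-scoring token.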