Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
A bug in the implementation of the top-p sampling #60
Open
allblueJT opened 1 month ago
https://github.com/lucidrains/PaLM-rlhf-pytorch/blob/6b02ee329106baff78e293afa7d1d2e6dd4e5ca2/palm_rlhf_pytorch/utils.py#L60
Using the sorted indices to index the sorted indices does not make sense. I think it should be:
return logits.scatter(1, sorted_indices, sorted_logits)
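
To make the indexing concrete, here is a minimal pure-Python sketch of what a scatter along dim 1 does with the indices returned by a descending sort. It is not the repository's code: `scatter_dim1` and `sort_desc` are hypothetical helpers written for illustration, assuming the documented out-of-place rule `out[i][index[i][j]] = src[i][j]`, and the masking step simply drops the two lowest-ranked entries rather than applying a real top-p threshold.

```python
import math

def scatter_dim1(base, index, src):
    """Hypothetical stand-in for an out-of-place scatter along dim 1:
    start from a copy of `base`, then set out[i][index[i][j]] = src[i][j]."""
    out = [row[:] for row in base]
    for i, idx_row in enumerate(index):
        for j, col in enumerate(idx_row):
            out[i][col] = src[i][j]
    return out

def sort_desc(logits):
    """Hypothetical helper: per row, return the values sorted in descending
    order and the original column index of each sorted value."""
    indices = [sorted(range(len(row)), key=lambda c: -row[c]) for row in logits]
    values = [[row[c] for c in idx] for row, idx in zip(logits, indices)]
    return values, indices

logits = [[0.5, 2.0, -1.0, 1.0]]
sorted_logits, sorted_indices = sort_desc(logits)
# sorted_logits  == [[2.0, 1.0, 0.5, -1.0]]
# sorted_indices == [[1, 3, 0, 2]]

# Illustrative masking only: drop the two lowest-ranked entries.
sorted_logits[0][2] = -math.inf
sorted_logits[0][3] = -math.inf

# Scatter the filtered values back to their original columns.
restored = scatter_dim1(logits, sorted_indices, sorted_logits)
print(restored)  # [[-inf, 2.0, -inf, 1.0]]
```

In this sketch `sorted_indices` is a full permutation of the columns, so every position of the output is overwritten and the starting contents of `base` do not affect the result. Whether that also holds for the linked line depends on how the actual tensors are built there, so it is worth checking against the real code.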