lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License

KL divergence loss #38

Closed: taynoel84 closed this issue 1 year ago

taynoel84 commented 1 year ago

Hi, thanks for the great repo. I have a question: in the function `masked_kl_div` of `ppo.py`, shouldn't the calculation be `prob1 * (log(prob1) - log(prob2))`? The calculation in the code is a negative KL divergence, which would be maximized instead of minimized (contrary to what the code assumes).
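For reference, a minimal sketch of what the corrected computation looks like. The function name `masked_kl_div` comes from the repo, but the exact signature and masking convention here are assumptions, not the repo's actual code:

```python
import torch

def masked_kl_div(prob1, prob2, mask=None):
    # KL(P || Q) = sum_i p1_i * (log p1_i - log p2_i)  >= 0
    # prob1, prob2: (batch, seq, num_tokens) probability distributions
    # mask: optional (batch, seq) boolean mask of valid positions (assumed convention)
    kl = (prob1 * (torch.log(prob1 + 1e-10) - torch.log(prob2 + 1e-10))).sum(dim=-1)

    if mask is not None:
        # zero out masked positions and average over the valid ones
        kl = kl.masked_fill(~mask, 0.)
        return kl.sum() / mask.sum().clamp(min=1)

    return kl.mean()
```

With the signs this way around, the loss is non-negative and equals zero when the two distributions match, so minimizing it pulls the policy toward the reference model; flipping the sign (as in the bug described above) would push it away instead.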

lucidrains commented 1 year ago

@taynoel84 🤦 thank you for finding this error!