lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License

KL divergence loss #38

Closed taynoel closed 1 year ago

taynoel commented 1 year ago

Hi, thanks for the great repo. I have a question: in the function masked_kl_div of ppo.py, shouldn't the calculation be prob1 * (log(prob1) - log(prob2))? The calculation currently in the code is the negative of the KL divergence, so the loss would be maximized rather than minimized (contrary to what the code assumes).
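The corrected direction can be sketched as follows. This is a minimal illustration of the formula being discussed, not the repo's exact code; the function signature and mask handling here are assumptions:

```python
import torch

def masked_kl_div(prob1, prob2, mask=None):
    """
    KL(prob1 || prob2) = sum over classes of prob1 * (log(prob1) - log(prob2)),
    averaged over unmasked positions. Sketch only; ppo.py may differ in details.
    """
    # per-position KL divergence, summed over the class dimension
    kl = (prob1 * (torch.log(prob1) - torch.log(prob2))).sum(dim=-1)

    if mask is None:
        return kl.mean()

    # average only over positions where mask == 1
    return (kl * mask).sum() / mask.sum()
```

With this sign convention the result is nonnegative and zero when the two distributions coincide, so it is a valid loss to minimize.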

lucidrains commented 1 year ago

@taynoel84 🤦 thank you for finding this error!