lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License

Calculating the KL loss seems to have a mistake #43

Closed: Nightbringers closed this issue 1 year ago

Nightbringers commented 1 year ago

Current code: kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.kl_div_loss_weight
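To see why the argument order matters, here is a minimal plain-Python sketch. It assumes (not confirmed from the repo source) that `masked_kl_div(p, q, mask)` averages KL(p || q) = sum_a p(a) * (log p(a) - log q(a)) over the positions where `mask` is true, so its first argument plays the role of the reference ("true") distribution. `kl` and `masked_kl_div_sketch` are hypothetical helpers, and the probabilities are made-up toy values.

```python
import math
from typing import List, Sequence


def kl(p: Sequence[float], q: Sequence[float], eps: float = 1e-20) -> float:
    # KL(p || q) = sum_a p(a) * (log p(a) - log q(a)); eps guards against log(0)
    return sum(pa * (math.log(pa + eps) - math.log(qa + eps)) for pa, qa in zip(p, q))


def masked_kl_div_sketch(prob1: List[List[float]], prob2: List[List[float]], mask: List[bool]) -> float:
    # Hypothetical stand-in for the repo's masked_kl_div: ASSUMES it averages
    # KL(prob1[t] || prob2[t]) over the positions t where mask[t] is True.
    kls = [kl(p, q) for p, q, m in zip(prob1, prob2, mask) if m]
    return sum(kls) / max(len(kls), 1)


# Toy per-position action distributions (made up for illustration)
old_action_probs = [[0.9, 0.1], [0.5, 0.5], [0.2, 0.8]]
action_probs = [[0.5, 0.5], [0.5, 0.5], [0.6, 0.4]]
action_masks = [True, True, False]  # last position is treated as padding

original = masked_kl_div_sketch(action_probs, old_action_probs, action_masks)  # KL(new || old)
proposed = masked_kl_div_sketch(old_action_probs, action_probs, action_masks)  # KL(old || new)
print(round(original, 4), round(proposed, 4))  # prints: 0.2554 0.184
```

Because KL divergence is asymmetric, swapping the two distributions changes the loss value (and its gradient), not just its sign or scale; the masked-out third position does not contribute to either result.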

I think old_action_probs should be y_true and action_probs should be y_pred, so I think the correct code is: kl_div_loss = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight
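Written out, and assuming `masked_kl_div(p, q, mask)` averages KL(p || q) over unmasked positions, the proposed order computes the loss below. Here π_old stands for `old_action_probs`, π_θ for `action_probs`, m_t for `action_masks`, and w_KL for `self.kl_div_loss_weight`; the symbols are introduced for illustration only. This matches the y_true/y_pred convention of PyTorch's `torch.nn.functional.kl_div(input, target)`, where `target` is the reference distribution and `input` holds the log-probabilities of the prediction.

```latex
\mathcal{L}_{\mathrm{KL}}
  = w_{\mathrm{KL}} \cdot
    \frac{\sum_t m_t \, D_{\mathrm{KL}}\!\left(\pi_{\mathrm{old}}(\cdot \mid s_t) \,\|\, \pi_\theta(\cdot \mid s_t)\right)}{\sum_t m_t},
\qquad
D_{\mathrm{KL}}(p \,\|\, q) = \sum_a p(a) \left( \log p(a) - \log q(a) \right)
```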

Am I right, or am I misunderstanding?

lucidrains commented 1 year ago

No, I think you may be correct. Will make the change! 🙏