lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License

Confusion about KL divergence calculation for human feedback policies #41

Closed dwyzzy closed 1 year ago

dwyzzy commented 1 year ago

Hi, thanks for the great work. I also have a question about the KL divergence loss. In papers like Learning to summarize from human feedback, the KL term for human feedback policies seems to be the KL divergence between $\pi^{RL}$ and $\pi^{SFT}$, while in this repo the code

kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.kl_div_loss_weight

seems to be the KL divergence between $\pi^{new}$ and $\pi^{old}$.

Does there exist something wrong with the code, or have I made some mistakes? Thank you.

lucidrains commented 1 year ago

my understanding is that pi rl corresponds to the new action probs while pi sft corresponds to the old. let me know through a pull request what you think the necessary changes should be, if you believe otherwise

ehion commented 1 year ago

@lucidrains i am curious why you use gumbel sampling (add gumbel noise and take the argmax index) to collect the action for a given state (prompt)? why not Categorical? e.g. dist = Categorical(logits); action = dist.sample()

ehion commented 1 year ago

pi[old] and pi are the same model with different parameters in my view; however, pi[old] updates more slowly than pi. pi[old] is just used for importance sampling in RL (e.g. PPO). In InstructGPT, pi[old] and pi are the SFT model with different parameters.

lucidrains commented 1 year ago

> @lucidrains i am curious why you use gumbel sampling (add gumbel noise and take the argmax index) to collect the action for a given state (prompt)? why not Categorical? e.g. dist = Categorical(logits); action = dist.sample()

it is equivalent
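
to illustrate, a minimal sketch of why the two give samples from the same distribution (hypothetical logits, not code from this repo):

```python
import torch
from torch.distributions import Categorical

logits = torch.randn(1, 50257)   # hypothetical logits for one decoding step

# gumbel-max trick: add Gumbel(0, 1) noise to the logits and take the argmax
gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
action_a = (logits + gumbel).argmax(dim = -1)

# Categorical: softmax the logits and sample from the resulting distribution
action_b = Categorical(logits = logits).sample()

# both are draws from softmax(logits); only the sampling mechanics differ
```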

> pi[old] and pi are the same model with different parameters in my view; however, pi[old] updates more slowly than pi. pi[old] is just used for importance sampling in RL (e.g. PPO). In InstructGPT, pi[old] and pi are the SFT model with different parameters.

yea, and the old action prob is sampled from pi[old] and the new action prob from pi[rl]. feel free to correct me if i'm mistaken. also provide (pseudo)code, as it would be clearer than english

dwyzzy commented 1 year ago

Hi. I think the KL divergence for human feedback policies (i.e. $D_{KL}(\pi^{RL} \| \pi^{SFT})$) should be added to the reward calculation in this repo. The final reward should be $reward_{total} = reward_{reward\,model} - kl_{penalty}$
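
A rough sketch of what I mean, in pseudocode as requested (all names here are hypothetical, not this repo's API):

```python
import torch

def total_rewards(rm_reward, rl_logprobs, sft_logprobs, kl_coef = 0.1):
    """
    rm_reward:    (batch,) scalar score from the reward model for each sequence
    rl_logprobs:  (batch, seq) log-probs of the sampled tokens under pi^RL
    sft_logprobs: (batch, seq) log-probs of the same tokens under the frozen pi^SFT
    """
    # per-token estimate of KL(pi^RL || pi^SFT) from the sampled tokens
    kl_penalty = kl_coef * (rl_logprobs - sft_logprobs)

    # reward_total = reward_RM - kl_penalty, summed over the response tokens
    return rm_reward - kl_penalty.sum(dim = -1)
```

Here the KL term is estimated from the log-probs of the sampled tokens, which is the usual approximation in these papers.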

lucidrains commented 1 year ago

@dwyzzy yea, i just noticed that on rereading

do you think it makes a difference whether it is subtracted from the rewards rather than just added as an auxiliary loss?

lucidrains commented 1 year ago

i am by no means knowledgeable in the RL field

lucidrains commented 1 year ago

@dwyzzy hey, i decided to make the change in 0.2.0

let me know if that makes more sense

lucidrains commented 1 year ago

> @dwyzzy yea, i just noticed that on rereading
>
> do you think it makes a difference whether it is subtracted from the rewards rather than just added as an auxiliary loss?

if there are any RL experts in the room, now is the time to shine

dwyzzy commented 1 year ago

@lucidrains Hi. I agree that these two approaches are similar, where the KL divergence is used to keep the newest RL policy from deviating too much from the original SFT model. From my point of view, 0.2.0 is closer to these RLHF papers (adding the KL divergence penalty between the SFT model and the RL policy to the reward). Thank you again for the great work!
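
To make the comparison concrete, a minimal sketch of the two options (illustrative tensors only, not this repo's API):

```python
import torch

# illustrative per-token log-probs of the sampled response tokens
rl_logprobs  = torch.randn(2, 8)   # current RL policy pi^RL
sft_logprobs = torch.randn(2, 8)   # frozen SFT policy pi^SFT
rm_scores    = torch.randn(2)      # reward model score per sequence
ppo_loss     = torch.tensor(0.)    # stand-in for the usual PPO objective
beta = 0.02

# per-token estimate of KL(pi^RL || pi^SFT), summed over the response
kl_est = (rl_logprobs - sft_logprobs).sum(dim = -1)

# option A (0.2.0): subtract the KL penalty from the rewards fed into PPO
rewards = rm_scores - beta * kl_est

# option B: leave the rewards untouched and add the KL term as an auxiliary loss
loss = ppo_loss + beta * kl_est.mean()
```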

lucidrains commented 1 year ago

@dwyzzy ok sounds good, i'm really curious what the difference is, if any

do email me if you end up trying both approaches

DarrenRuan commented 11 months ago

I think that is quite an interesting point. I believe in the original PPO RL algorithm, the KL divergence should be calculated between $\pi_k$ at iteration $k$ and $\pi_{k+1}$. In other words, sample from the current policy $\pi_k$, then update the policy to find $\pi_{k+1}$.

Reference: https://spinningup.openai.com/en/latest/algorithms/ppo.html

However, in RLHF, it seems that the KL divergence is calculated between the policies in 1 below instead of 2. Any idea why that is the case?

  1. $\pi^{RL}$ and $\pi^{SFT}$
  2. $\pi^{RL}_{k+1}$ and $\pi^{RL}_{k}$

Or does it mean that there are actually two different KL divergences: 1 is added to the reward directly, while 2 is still there for the PPO update?

dwyzzy commented 11 months ago

> I think that is quite an interesting point. I believe in the original PPO RL algorithm, the KL divergence should be calculated between $\pi_k$ at iteration $k$ and $\pi_{k+1}$. In other words, sample from the current policy $\pi_k$, then update the policy to find $\pi_{k+1}$.
>
> Reference: https://spinningup.openai.com/en/latest/algorithms/ppo.html
>
> However, in RLHF, it seems that the KL divergence is calculated between the policies in 1 below instead of 2. Any idea why that is the case?
>
>   1. $\pi^{RL}$ and $\pi^{SFT}$
>   2. $\pi^{RL}_{k+1}$ and $\pi^{RL}_{k}$
>
> Or does it mean that there are actually two different KL divergences: 1 is added to the reward directly, while 2 is still there for the PPO update?

I think it's what you said at the end: there are two different KL divergences.

(1) The KL divergence between $\pi^{RL}$ and $\pi^{SFT}$ is used to keep the PPO model from drifting too far away from the SFT model (DeepSpeed-Chat code about this). When the PPO model is updated too far away from the SFT model, it may lose some of its NLP capabilities.

(2) The KL divergence between $\pi^{RL}_{k+1}$ and $\pi^{RL}_{k}$ is for the PPO update. In PPO-Clip, it appears as the clipped ratio between the old policy $\pi^{RL}_{k}$ and the new policy $\pi^{RL}_{k+1}$ (DeepSpeed-Chat code about this).
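
A minimal sketch of how the two terms enter the update differently (hypothetical tensors, not code from this repo or DeepSpeed-Chat):

```python
import torch

# hypothetical per-token log-probs of the sampled response tokens
sft_logprobs = torch.randn(8)   # frozen SFT policy   pi^SFT
old_logprobs = torch.randn(8)   # behaviour policy    pi^RL_k (generated the rollout)
new_logprobs = torch.randn(8)   # current policy      pi^RL_{k+1} (being optimized)
rm_scores    = torch.randn(8)   # reward model scores
advantages   = torch.randn(8)
beta, eps_clip = 0.02, 0.2

# (1) KL to the SFT model, folded into the per-token rewards
rewards = rm_scores - beta * (old_logprobs - sft_logprobs)

# (2) PPO-Clip ratio between pi^RL_{k+1} and pi^RL_k, constraining each update step
ratio = torch.exp(new_logprobs - old_logprobs)
policy_loss = -torch.min(
    ratio * advantages,
    torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantages
).mean()
```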

You are welcome to point out the mistakes in my comment if I have made some~