Great question - I think the answer is pretty subtle (and I'm still thinking through this myself). The first thing to keep in mind is that in DPO (and in reward modeling for RLHF in general, not just with DPO), all we care about is maximizing the reward margin between the chosen and rejected responses; the absolute reward doesn't actually matter.
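To see concretely why only the margin matters, here's a minimal sketch of the per-example DPO loss (variable names are just illustrative, and I'm assuming the per-token log probs for each response have already been summed):

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor,
             ref_rejected_logps: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    """Per-example DPO loss from summed log probs of the chosen/rejected responses."""
    # Implicit rewards: beta * (log pi_theta(y|x) - log pi_ref(y|x))
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # Only the margin enters the loss, so both rewards (and both log probs)
    # can drift down together without hurting the objective, as long as the
    # rejected reward falls faster than the chosen one.
    return -F.logsigmoid(chosen_rewards - rejected_rewards)
```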
As for why the log probs of the chosen responses actually decrease over time with DPO, consider this:
After the SFT stage, which for this experiment was just training on the preferred responses in the preference dataset, we're roughly at a local maximum: we can't assign any higher probability to the chosen responses, since that's exactly what SFT was optimizing for. By definition, that means any change to our policy will lower the average log prob assigned to the chosen responses (and that the expected value of the first term of the DPO update gradient is zero). Since we're now optimizing for margin, a different objective than SFT, that is exactly what happens. But the average reward/log prob for the rejected responses goes down faster.
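To make "the first term" concrete, here is the DPO gradient written out (this is just the gradient of the logistic loss sketched above, with $\hat r_\theta(x,y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_\text{ref}(y \mid x)}$ the implicit reward):

$$
\nabla_\theta \mathcal{L}_\text{DPO}
= -\beta \, \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \Big[
  \sigma\big(\hat r_\theta(x, y_l) - \hat r_\theta(x, y_w)\big)
  \big( \underbrace{\nabla_\theta \log \pi_\theta(y_w \mid x)}_{\text{first term}}
      - \nabla_\theta \log \pi_\theta(y_l \mid x) \big)
\Big]
$$

At the SFT optimum, the expectation of that first term is roughly zero, so the update is dominated by pushing down $\log \pi_\theta(y_l \mid x)$.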
Remember, in RLHF, we just want to find a policy that achieves high reward under a reward function that discriminates between good and bad trajectories accurately. The absolute rewards don't matter.
Let me know if this makes sense, or if you think I've missed something.
I understand your point. From the wandb curves, the margin is indeed increasing.
My understanding is that DPO's optimization goal is to maximize the margin, while RLHF's optimization goal is to find a policy that achieves high reward (assuming the reward model is accurate enough). The optimization goals of the two are different.
Essentially, DPO's optimization goal is not aligned with human preferences, but it might be an easier-to-optimize proxy loss that is relatively more stable.
Actually, that's not quite right.
We can compare conventional RLHF pipelines with DPO by analyzing what each does in the two stages: reward modeling from preferences, and policy optimization.
During reward modeling, both DPO and conventional RLHF pipelines minimize the Bradley-Terry (classification) loss on preferences to learn a reward function that discriminates well (high margin) between the chosen and rejected completions. The only difference is that conventional RLHF uses an arbitrary LM to predict a scalar reward, while DPO uses a specific parameterization of the reward function (specifically, parameterizing the reward in terms of the policy, so there is actually no separate explicit reward model).
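Spelled out, the shared objective is the Bradley-Terry negative log likelihood over preference pairs:

$$
\mathcal{L}_R(r) = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \big[ \log \sigma\big( r(x, y_w) - r(x, y_l) \big) \big]
$$

Conventional RLHF minimizes this over an unrestricted reward model $r_\phi$; DPO minimizes the same loss after substituting $r_\theta(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_\text{ref}(y \mid x)}$, which is how the loss sketched earlier falls out.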
During policy optimization, "conventional" RLHF runs PPO on the learned reward model to approximate the optimal policy for that learned reward function. In contrast, for DPO, our special parameterization of the reward model means we can extract the optimal policy directly from our reward parameterization, without doing any further training or compute.
So both methods are optimizing the same objective (maximize expected reward under the learned reward subject to KL constraint); the difference is that DPO optimizes this objective exactly, while it's actually PPO that only finds an approximation.
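For reference, the objective both methods target, with reward $r$ and KL coefficient $\beta$, is

$$
\max_{\pi}\; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi(\cdot \mid x)}\big[ r(x, y) \big] \;-\; \beta\, \mathbb{D}_\text{KL}\big[ \pi(\cdot \mid x) \,\|\, \pi_\text{ref}(\cdot \mid x) \big],
$$

whose optimum has the closed form

$$
\pi^*(y \mid x) = \frac{1}{Z(x)}\, \pi_\text{ref}(y \mid x)\, \exp\big( r(x, y) / \beta \big).
$$

PPO approximates $\pi^*$ by sampling and taking gradient steps; DPO inverts this relationship, $r(x, y) = \beta \log \frac{\pi^*(y \mid x)}{\pi_\text{ref}(y \mid x)} + \beta \log Z(x)$, and the intractable $\log Z(x)$ cancels in the Bradley-Terry margin, so the policy recovered from the learned implicit reward is already optimal for this objective.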
Hope this helps!
Great explanation: "During reward modeling, both DPO and conventional RLHF pipelines minimize the Bradley-Terry (classification) loss on preferences to learn a reward function that discriminates well (high margin) between the chosen and rejected completions."
You are right. I misunderstood. The optimization goals of the two are indeed the same.
BTW, do you think the Bradley-Terry model is strictly aligned with human preferences?
Thanks
> BTW, do you think the Bradley-Terry model is strictly aligned with human preferences?
I'm not an expert in human preference modeling, but the fact that the BT model assumes each object/response gets a single score (rather than a complex multimodal distribution of scores, since human preferences are diverse) might be one way it is restrictive.
I have reviewed the wandb training curves provided, and I have a question: why do prob_eval(train)/chosen and rewards_eval(train)/chosen gradually decrease? I originally thought these two metrics would gradually increase. Did I misunderstand something?