huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

`OnPolicyConfig`: Change `non_eos_penalty` to be more clearly documented and consistent across different trainers #2012

Closed · RylanSchaeffer closed this issue 2 weeks ago

RylanSchaeffer commented 2 months ago

Feature request

The `OnPolicyConfig` has a flag, `non_eos_penalty: bool = False`, which is described as: "whether to penalize responses that do not contain `stop_token_id`".

I interpreted this to mean that if a response doesn't contain `stop_token_id`, an additional penalty will be subtracted from its score. I read it this way because, in the context of length penalization, the penalty is subtracted from the reward model's score.

But TRL has three trainers (`PPOv2Trainer`, `RLOOTrainer`, and `OnlineDPOTrainer`) with inconsistent behavior.

Replace behavior: `PPOv2Trainer` and `RLOOTrainer` replace non-EOS outputs' scores with the penalty value (sketched below the links):

PPOv2Trainer: https://github.com/huggingface/trl/blob/main/trl/trainer/ppov2_trainer.py#L375-L377

RLOOTrainer: https://github.com/huggingface/trl/blob/main/trl/trainer/rloo_trainer.py#L332-L334
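For concreteness, here is a small self-contained sketch of what the replace behavior amounts to. The toy scores are made up, the snippet paraphrases the linked lines rather than copying them, and `penalty_reward_value` refers to the `OnPolicyConfig` field of that name:

```python
import torch

# Paraphrased sketch of the "replace" behavior (toy numbers, not the exact TRL source).
scores = torch.tensor([-7.2, -6.1, -9.4])              # reward-model scores for three completions
contain_eos_token = torch.tensor([True, False, True])  # did each completion emit stop_token_id?
penalty_reward_value = -1.0                            # default value of the config field

# non_eos_penalty=True: overwrite the score of any completion without stop_token_id
scores = torch.where(contain_eos_token, scores, torch.full_like(scores, penalty_reward_value))
print(scores)  # tensor([-7.2000, -1.0000, -9.4000]) -- the truncated completion now has the *best* score
```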

Subtract behavior: `OnlineDPOTrainer` instead subtracts the penalty from non-EOS outputs' scores (sketched below the link):

OnlineDPOTrainer: https://github.com/huggingface/trl/blob/main/trl/trainer/online_dpo_trainer.py#L340-L342
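And a matching sketch of the subtract behavior, again paraphrased with the same toy numbers, using `missing_eos_penalty` as the penalty name (the flag Online DPO uses):

```python
import torch

# Paraphrased sketch of the "subtract" behavior (toy numbers, not the exact TRL source).
scores = torch.tensor([-7.2, -6.1, -9.4])
contain_eos_token = torch.tensor([True, False, True])
missing_eos_penalty = 1.0

# Subtract a fixed amount from completions that never produced stop_token_id
scores[~contain_eos_token] -= missing_eos_penalty
print(scores)  # tensor([-7.2000, -7.1000, -9.4000]) -- the truncated completion is strictly worse off
```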

Requests:

  1. Please update the documentation to be clear about how the non-EOS penalty is applied.
  2. Please either (a) make the behavior consistent across trainers or (b) use differently named flags for the two behaviors.

Motivation

Yes, this problem just cost me a day.

Your contribution

I would be happy to fix this if the core contributors tell me which solution they prefer.

RylanSchaeffer commented 2 months ago

I realized that replacing the score might even be nonsensical. Reward models' outputs are shift-invariant, so if a reward model outputs scores in [-10, -5], a replaced score of -1 is fantastic, and the policy model is rewarded for this misbehavior.
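To put toy numbers on this (hypothetical values, not from any real run): a fixed replacement value can land above or below the reward model's range depending on an arbitrary offset, while a subtracted penalty is unaffected by that offset.

```python
import torch

# Toy illustration of the shift-invariance problem (hypothetical numbers).
scores = torch.tensor([-10.0, -5.0])   # a reward model whose scores happen to be negative
shifted = scores + 100.0               # the "same" reward model, offset by a constant

replacement = -1.0
print(replacement > scores.max().item())   # True:  replacing rewards the truncated completion
print(replacement > shifted.max().item())  # False: after an offset, the same value is a harsh penalty

# Subtracting a penalty changes the score by the same amount regardless of the offset:
print(torch.equal((shifted - 1.0) - shifted, (scores - 1.0) - scores))  # True
```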

qgallouedec commented 2 months ago

That's a very good point, and I agree with it. That's why we chose to use `missing_eos_penalty` in the recently implemented Online DPO (as you mentioned):

https://github.com/huggingface/trl/blob/1f6a1d2f9afc53697bba79ac68a72a1d0c4af666/trl/trainer/online_dpo_trainer.py#L340-L342

I would opt for a generalised use of `missing_eos_penalty`, but I'd like to make sure there's no regression. Is it possible to produce a curve comparing the two options?

Thank you for proposing your contribution. I'll be very happy to review a PR for this, @RylanSchaeffer.

RylanSchaeffer commented 2 months ago

I'd be happy to work on this!

If I can first clarify: when you say "I would opt for a generalised use of `missing_eos_penalty`", what do you mean by "generalised"? Do you want the user to be able to choose between replacing and subtracting?

RylanSchaeffer commented 1 month ago

Update: We are currently working on a PR here: https://github.com/huggingface/trl/pull/2033

qgallouedec commented 1 month ago

> If I can first clarify: when you say "I would opt for a generalised use of `missing_eos_penalty`", what do you mean by "generalised"? Do you want the user to be able to choose between replacing and subtracting?

No, I meant generalize = having `missing_eos_penalty` (subtract) instead of `non_eos_penalty` (replace) for all trainers.
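For anyone following along, a minimal sketch of that unified behavior is below. `apply_missing_eos_penalty` is a hypothetical helper for illustration, not a TRL function, and the merged #2033 is authoritative for the actual implementation.

```python
from typing import Optional

import torch

# Hypothetical helper sketching the proposed unified behavior: every on-policy trainer
# subtracts missing_eos_penalty from completions that never produced stop_token_id,
# rather than replacing their score with a fixed value.
def apply_missing_eos_penalty(
    scores: torch.Tensor,
    contain_eos_token: torch.Tensor,
    missing_eos_penalty: Optional[float],
) -> torch.Tensor:
    if missing_eos_penalty is not None:
        scores = scores.clone()
        scores[~contain_eos_token] -= missing_eos_penalty
    return scores
```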

qgallouedec commented 2 weeks ago

Solved in #2033