OpenLLMAI / OpenRLHF

An Easy-to-use, Scalable and High-performance RLHF Framework (70B+ PPO Full Tuning & Iterative DPO & LoRA & Mixtral)
https://openrlhf.readthedocs.io/
Apache License 2.0

maybe data bug with dpo trainer #294

Closed none0663 closed 1 month ago

none0663 commented 1 month ago

https://github.com/OpenLLMAI/OpenRLHF/blob/adf26867e44765a3963b4e8d249cf58a5162209c/openrlhf/trainer/dpo_trainer.py#L298

At this line, `loss_masks` is twice as long as `prompt_id_lens`, i.e. `len(loss_masks) == 2 * len(prompt_id_lens)`, because the chosen and rejected sequences are concatenated into a single batch. The masking loop therefore only masks the prompt ids of the chosen half; the prompt ids of `reject_ids` are left unmasked when computing the rejected log-probs, so the DPO loss is not correct. Please check.
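To make the failure mode concrete, here is a minimal, self-contained sketch of the masking step. The names `loss_masks` and `prompt_id_lens` follow the issue; the shapes and values are assumptions for illustration, not the exact OpenRLHF code:

```python
import torch

B, T = 4, 16                   # batch size and sequence length (assumed)
# Chosen and rejected sequences are stacked row-wise -> 2B rows total.
attention_mask = torch.ones(2 * B, T, dtype=torch.bool)
prompt_id_lens = [5, 3, 7, 4]  # B entries: prompt lengths of the chosen half only

loss_masks = attention_mask.clone()

# zip() stops at the shorter iterable, so only the first B rows (the
# chosen half) get their prompt tokens masked out; rows B..2B-1 (the
# rejected half) keep their prompt tokens, which then leak into the
# rejected log-prob sums.
for mask, source_len in zip(loss_masks, prompt_id_lens):
    mask[:source_len] = False

assert not loss_masks[:B, 0].any()  # chosen prompts are masked
assert loss_masks[B:, 0].all()      # rejected prompts are NOT masked (the bug)
```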

hijkzzz commented 1 month ago

Fixed in https://github.com/OpenLLMAI/OpenRLHF/commit/6106fb22f1832d46f0072516edd28edee4034b23
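For readers who don't want to open the commit: the straightforward fix is to make `prompt_id_lens` cover both halves of the concatenated batch before the masking loop. A minimal sketch (whether the commit does exactly this is an assumption on my part, see the diff for the authoritative change):

```python
# Repeat the prompt lengths for the rejected half so both halves of the
# concatenated (chosen + rejected) batch get their prompt tokens masked.
prompt_id_lens = prompt_id_lens * 2   # length B -> 2B

for mask, source_len in zip(loss_masks, prompt_id_lens):
    mask[:source_len] = False
```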