none0663 closed this issue 1 month ago.
At https://github.com/OpenLLMAI/OpenRLHF/blob/adf26867e44765a3963b4e8d249cf58a5162209c/openrlhf/trainer/dpo_trainer.py#L298, loss_masks is twice as long as prompt_id_lens:
len(loss_masks) == 2 * len(prompt_id_lens)
As a result, only the prompt ids of the chosen sequences are masked! The prompt ids of the rejected sequences are not masked when computing the log-probability of reject_ids, so the DPO loss is incorrect. Please check.
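A minimal sketch of the failure mode, assuming (as the issue implies) that the trainer stacks chosen and rejected sequences into one batch of 2N rows, chosen rows first, and masks prompts by pairing loss_masks with prompt_id_lens via zip. The function names and toy log-probabilities are hypothetical, not the trainer's actual code. Because Python's zip stops at the shorter input, only the first N (chosen) rows get their prompt positions masked.

```python
from typing import List


def mask_prompts_zip(loss_masks: List[List[bool]], prompt_id_lens: List[int]) -> List[List[bool]]:
    """Hypothetical reconstruction of the masking step: zip() silently stops
    after len(prompt_id_lens) rows, leaving the rejected rows unmasked."""
    masks = [list(m) for m in loss_masks]  # copy so the input is not mutated
    for mask, prompt_len in zip(masks, prompt_id_lens):
        for i in range(min(prompt_len, len(mask))):
            mask[i] = False  # exclude prompt tokens from the logp sum
    return masks


def masked_logp(token_logps: List[float], mask: List[bool]) -> float:
    """Sum per-token log-probs over positions where the mask is True."""
    return sum(lp for lp, keep in zip(token_logps, mask) if keep)


# One preference pair: a 2-token prompt plus 3 response tokens per row.
# Assumed batch layout: [chosen_0, rejected_0].
prompt_id_lens = [2]
loss_masks = [[True] * 5, [True] * 5]
assert len(loss_masks) == 2 * len(prompt_id_lens)
masks = mask_prompts_zip(loss_masks, prompt_id_lens)

token_logps = [-1.0, -1.0, -0.5, -0.5, -0.5]  # same toy values for both rows
chosen_logp = masked_logp(token_logps, masks[0])
rejected_logp = masked_logp(token_logps, masks[1])
print(masks[1])                    # → [True, True, True, True, True]
print(chosen_logp, rejected_logp)  # → -1.5 -3.5 (rejected includes the prompt's -2.0)
```

The two responses have identical token log-probabilities, yet the rejected sum is lower by the prompt's log-probability. In the DPO objective that extra prompt term appears only on the rejected side of the policy/reference log-ratio, so it generally does not cancel once the policy moves away from the reference model.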
Fixed in https://github.com/OpenLLMAI/OpenRLHF/commit/6106fb22f1832d46f0072516edd28edee4034b23
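One way to restore the intended masking, sketched under the same assumptions (chosen rows first, then rejected rows in the same pair order): repeat prompt_id_lens so there is one length per row, and check the lengths so a mismatch fails loudly instead of being silently truncated by zip. The function name mask_pair_prompts is hypothetical, and the linked commit may implement the fix differently.

```python
from typing import List


def mask_pair_prompts(loss_masks: List[List[bool]], prompt_id_lens: List[int]) -> List[List[bool]]:
    """Mask prompt tokens in every row of a [chosen..., rejected...] batch."""
    n_pairs = len(prompt_id_lens)
    if len(loss_masks) != 2 * n_pairs:
        raise ValueError(f"expected {2 * n_pairs} masks, got {len(loss_masks)}")
    # Row i (chosen) and row i + n_pairs (rejected) share the same prompt.
    per_row_lens = prompt_id_lens + prompt_id_lens
    masks = [list(m) for m in loss_masks]  # copy so the input is not mutated
    for mask, prompt_len in zip(masks, per_row_lens):
        for i in range(min(prompt_len, len(mask))):
            mask[i] = False  # exclude prompt tokens from the logp sum
    return masks


masks = mask_pair_prompts([[True] * 5 for _ in range(4)], [2, 3])
for row in masks:
    print(row)
```

On Python 3.10+, zip(..., strict=True) gives the same fail-loudly behaviour for mismatched lengths without a separate check.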