OpenLLMAI / OpenRLHF

An Easy-to-use, Scalable and High-performance RLHF Framework (70B+ PPO Full Tuning & Iterative DPO & LoRA & Mixtral)
https://openrlhf.readthedocs.io/
Apache License 2.0

DPO Loss #235

Closed · paulcx closed this issue 4 months ago

paulcx commented 4 months ago

I have a question: in the calculation of DPO's loss, should the labels for the prompt portion of chosen_token and reject_token be set to -100, so that the prompt is not involved in the loss calculation? This kind of masking is usually used in SFT training.
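For illustration, a minimal sketch (not OpenRLHF code; `IGNORE_INDEX`, `build_labels`, and the 1-D per-example layout are assumptions) of the SFT-style label masking I mean, where the prompt tokens get label -100 so they are ignored by the loss:

```python
import torch

IGNORE_INDEX = -100  # tokens with this label are excluded from the loss

def build_labels(input_ids: torch.Tensor, prompt_len: int) -> torch.Tensor:
    """For a single example (1-D input_ids): copy the ids and mask the
    prompt tokens so only the response contributes to the loss."""
    labels = input_ids.clone()
    labels[:prompt_len] = IGNORE_INDEX
    return labels
```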

hijkzzz commented 4 months ago

The official DPO code does not seem to do so.

atebbifakhr commented 4 months ago

Actually it does: https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90-L115

paulcx commented 4 months ago

The _get_batch_logps function calculates the log probabilities of the non-masked tokens and optionally returns either the sum or the average of these values for each sequence in the batch. The masking mechanism excludes tokens labeled -100 from the probability calculation, but the function itself does not distinguish between different parts of the input sequence, such as the prompt and the answer. To do so, we would need to build the labels in the dataset and set the ignore token (-100) for the prompt portion of both the chosen tokens and the rejected tokens.
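For reference, a rough sketch of the masked per-sequence log-prob computation described above (a paraphrase consistent with the linked `_get_batch_logps`, not a copy of it; the function name and signature here are assumptions):

```python
import torch
import torch.nn.functional as F

def batch_logps(logits: torch.Tensor,   # (B, T, V) model logits
                labels: torch.Tensor,   # (B, T) token ids, prompt/pad set to -100
                average: bool = False) -> torch.Tensor:
    """Sum (or average) the log-probs of the non-masked tokens per sequence."""
    # Shift so that position t predicts token t+1.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    mask = labels != -100
    labels[~mask] = 0  # dummy index for gather; zeroed out by the mask below
    per_token_logps = torch.gather(
        F.log_softmax(logits, dim=-1), 2, labels.unsqueeze(2)
    ).squeeze(2)
    summed = (per_token_logps * mask).sum(-1)
    return summed / mask.sum(-1) if average else summed
```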

hijkzzz commented 4 months ago

@atebbifakhr @paulcx I looked at the DPO loss function: whether or not the prompt is masked has no effect on the final gradient.

(image: derivation of the DPO loss gradient)

paulcx commented 4 months ago

Hi @hijkzzz, would you mind taking a look at the official code here? I think @atebbifakhr was correct.

hijkzzz commented 4 months ago

> Hi @hijkzzz, would you mind taking a look at the official code here?

I see your point, but it seems that the log-probabilities over the prompt section cancel out between the chosen and rejected terms because they enter the loss with opposite signs. I can still add this mask to align the implementation with the official one.
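For what it's worth, a sketch of that cancellation argument (it assumes the per-sequence log-probs are summed; with token-averaged log-probs the prompt terms would not cancel exactly). Since the chosen and rejected sequences share the same prompt $x$,

$$
\log\pi(x, y_w) - \log\pi(x, y_l)
= \big[\log\pi(x) + \log\pi(y_w \mid x)\big] - \big[\log\pi(x) + \log\pi(y_l \mid x)\big]
= \log\pi(y_w \mid x) - \log\pi(y_l \mid x),
$$

so the DPO logit $\beta\big[(\log\pi_\theta - \log\pi_{\mathrm{ref}})_{y_w} - (\log\pi_\theta - \log\pi_{\mathrm{ref}})_{y_l}\big]$, and hence the gradient, is the same with or without the prompt tokens.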

paulcx commented 4 months ago

That's great, and I'm willing to test it afterwards.

paulcx commented 4 months ago

Very limited pairwise testing; the green curve is the run without this PR applied.

(image: training curves from the pairwise test)

louieworth commented 1 week ago

@paulcx What does the curve represent? Is there any experimental difference between w/ and w/o masks on prompts?

louieworth commented 1 week ago

> > Hi @hijkzzz, would you mind taking a look at the official code here?
>
> I see your point, but it seems that the log-probabilities over the prompt section cancel out between the chosen and rejected terms because they enter the loss with opposite signs. I can still add this mask to align the implementation with the official one.

@hijkzzz has this alignment with the official codebase been made? I can submit a PR for the code.

paulcx commented 1 week ago

> @paulcx What does the curve represent? Is there any experimental difference between w/ and w/o masks on prompts?

yes

louieworth commented 1 week ago

> > @paulcx What does the curve represent? Is there any experimental difference between w/ and w/o masks on prompts?
>
> yes

"Yes" mean "no effect" or "effect" on the final performance?

paulcx commented 1 week ago

> > > @paulcx What does the curve represent? Is there any experimental difference between w/ and w/o masks on prompts?
> >
> > yes
>
> Does "yes" mean "no effect" or "an effect" on the final performance?

https://x.com/corbtt/status/1806336011804484017

louieworth commented 1 week ago

Thanks for sharing this tweet.