paulcx closed this issue 4 months ago.
The official DPO code does not seem to do so.
The _get_batch_logps function computes the log-probabilities of the non-masked tokens and returns either their sum or their average for each sequence in the batch. The mask lets certain tokens (those labeled -100) be excluded from the calculation, but the function itself does not distinguish between parts of the input sequence such as the prompt and the answer. To do that, we could build the labels in the dataset and set the ignore index -100 on the prompt tokens of both the chosen and the rejected sequences.
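The masking behavior described above can be sketched in plain Python. This sketch only mirrors what is described for _get_batch_logps (causal shift, labels equal to -100 skipped, sum or average per sequence); it is not the official code, and the name `batch_logps` and the nested-list inputs are illustrative assumptions standing in for framework tensors.

```python
import math
from typing import List

IGNORE_INDEX = -100  # label value excluded from the per-sequence log-prob


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def batch_logps(logits: List[List[List[float]]],
                labels: List[List[int]],
                average_log_prob: bool = False) -> List[float]:
    """Per-sequence log-probability of the non-masked label tokens.

    logits: [batch][seq_len][vocab], labels: [batch][seq_len].
    The logits at position t score the label at position t + 1 (causal
    shift); labels equal to IGNORE_INDEX are skipped entirely.
    """
    results = []
    for seq_logits, seq_labels in zip(logits, labels):
        total, count = 0.0, 0
        for t in range(len(seq_labels) - 1):
            target = seq_labels[t + 1]
            if target == IGNORE_INDEX:
                continue  # masked token, e.g. prompt or padding
            total += log_softmax(seq_logits[t])[target]
            count += 1
        # Average over scored tokens only; fall back to the raw sum if none.
        results.append(total / count if (average_log_prob and count) else total)
    return results
```

With uniform logits over a two-token vocabulary, each scored token contributes log 0.5, so `batch_logps([[[0.0, 0.0]] * 3], [[-100, 1, 0]])` sums two such terms. The label at position 0 is never scored because of the shift.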
@atebbifakhr @paulcx I observed that, in the DPO loss, masking the prompt or not has no effect on the final gradient.
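One way to see why, assuming sequence log-probabilities are summed (not averaged) and the prompt $x$ is tokenized identically in the chosen ($y_w$) and rejected ($y_l$) sequences: with a causal model, leaving the prompt unmasked adds the same term $P_\pi(x)$ (the summed log-probabilities of the scored prompt tokens) to both sequence scores $s_\pi$, and that term cancels inside the DPO margin $h$. The notation $s_\pi$, $P_\pi$, $h$ is introduced here for illustration.

```math
\begin{aligned}
s_\pi(y) &= P_\pi(x) + \log \pi(y \mid x), \qquad P_\pi(x) = \sum_{t \in \text{prompt}} \log \pi(x_t \mid x_{<t}) \\
h &= \beta\big[(s_{\pi_\theta}(y_w) - s_{\pi_{\text{ref}}}(y_w)) - (s_{\pi_\theta}(y_l) - s_{\pi_{\text{ref}}}(y_l))\big] \\
  &= \beta\Big[\log\frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \log\frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\Big] + \beta\big[(P_{\pi_\theta} - P_{\pi_{\text{ref}}}) - (P_{\pi_\theta} - P_{\pi_{\text{ref}}})\big] \\
\mathcal{L}_{\text{DPO}} &= -\log \sigma(h)
\end{aligned}
```

The last bracket is identically zero, so $h$, and therefore the loss and its gradient with respect to $\theta$, are unchanged by the prompt mask. If each sequence is instead averaged over its own length, the chosen and rejected scores are divided by different lengths and the prompt terms no longer cancel exactly.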
Hi @hijkzzz, would you mind taking a look at the official code here? I guess @atebbifakhr was correct.
I see your point, but it seems that the prompt's contribution cancels out: because the loss works with log-probabilities, the shared prompt adds the same term to the chosen and rejected sequences, and those terms enter the loss with opposite signs. Still, I can add this mask so the implementation is aligned with the official one.
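A quick numeric check of that cancellation argument, using made-up per-token log-probabilities rather than a real model. All values and the helper names `dpo_loss`, `seq_logp`, and `compare` are hypothetical. With summed log-probs the loss is the same whether or not the prompt is masked; with per-sequence averaging it is not.

```python
import math
from typing import List


def dpo_loss(pi_w: float, ref_w: float, pi_l: float, ref_l: float,
             beta: float = 0.1) -> float:
    """DPO loss -log(sigmoid(beta * margin)) from sequence log-probs."""
    margin = beta * ((pi_w - ref_w) - (pi_l - ref_l))
    # -log(sigmoid(z)) == log(1 + exp(-z))
    return math.log1p(math.exp(-margin))


def seq_logp(token_logps: List[float], mask: List[bool], average: bool) -> float:
    """Sum (or mean) of the per-token log-probs whose mask entry is True."""
    kept = [lp for lp, keep in zip(token_logps, mask) if keep]
    total = sum(kept)
    return total / len(kept) if average else total


def compare(mask_prompt: bool, average: bool, beta: float = 0.1) -> float:
    # Hypothetical per-token log-probs. The prompt part is identical for the
    # chosen and rejected sequences because a causal model sees the same prefix.
    prompt = {"policy": [-1.2, -0.7], "ref": [-1.0, -0.9]}
    chosen = {"policy": [-0.5, -0.4, -0.3], "ref": [-0.6, -0.6, -0.6]}
    rejected = {"policy": [-1.5], "ref": [-1.1]}

    def score(model: str, response: dict) -> float:
        tokens = prompt[model] + response[model]
        mask = [not mask_prompt] * len(prompt[model]) + [True] * len(response[model])
        return seq_logp(tokens, mask, average)

    return dpo_loss(score("policy", chosen), score("ref", chosen),
                    score("policy", rejected), score("ref", rejected), beta)
```

Here `compare(True, False)` and `compare(False, False)` agree to floating-point precision (both equal `log1p(exp(-0.1))` for these numbers), while `compare(True, True)` and `compare(False, True)` differ because the chosen and rejected sequences are normalized by different lengths.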
That's great; I'm willing to test it once it's done.
Very limited pair testing; the green curve is the run without this PR applied.
@paulcx What does the curve represent? Is there any experimental difference between w/ and w/o masks on prompts?
@hijkzzz Has this alignment with the official codebase been made? I can submit a pull request for the code.
@paulcx What does the curve represent? Is there any experimental difference between w/ and w/o masks on prompts?
yes
Does "yes" mean there is an effect on the final performance, or no effect?
Thanks for sharing this tweet.
I have a question: when computing the DPO loss, should the labels for the prompt portion of chosen_token and reject_token be set to -100, so that those tokens are excluded from the loss? This kind of masking is commonly used in SFT training.
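A minimal sketch of the masking asked about here, assuming the prompt and both responses are already tokenized into integer id lists. The function names and token ids are hypothetical. The prompt positions in both the chosen and rejected labels are set to -100 (also the default `ignore_index` of PyTorch's cross-entropy loss), so a log-prob routine that skips -100 labels scores only the response tokens.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value that the log-prob / loss computation skips


def build_labels(prompt_ids: List[int], response_ids: List[int]) -> Dict[str, List[int]]:
    """Concatenate prompt and response; mask the prompt positions in the labels."""
    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return {"input_ids": input_ids, "labels": labels}


def build_pair(prompt_ids: List[int], chosen_ids: List[int],
               rejected_ids: List[int]) -> Dict[str, Dict[str, List[int]]]:
    """Apply the same prompt mask to the chosen and the rejected sequence."""
    return {
        "chosen": build_labels(prompt_ids, chosen_ids),
        "rejected": build_labels(prompt_ids, rejected_ids),
    }
```

For example, `build_pair([1, 2, 3], [7, 8], [9])` keeps the full `input_ids` for each sequence but yields labels `[-100, -100, -100, 7, 8]` for the chosen side and `[-100, -100, -100, 9]` for the rejected side. Padding, if any, would typically also be labeled -100.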