xfactlab / orpo

Official repository for ORPO
Apache License 2.0

Discarding the prompt tokens only with the positive labels and not with the negative ones #32

Closed javismiles closed 4 months ago

javismiles commented 5 months ago

In the lines below, why do we mask only the positive labels with `mask.logical_not()` to discard the prompt tokens, and not the negative labels as well? What is the reason for doing it only with the positive ones? Thank you very much.

```python
### Discard the prompt tokens in NLL loss if true
if self.disable_prompt_loss:
    mask = inputs['attention_mask'] * inputs['positive_attention_mask']
    pos_labels = pos_labels * mask.logical_not()
    pos_labels[pos_labels == 0] = self.pad
##################################################

neg_labels[neg_labels == self.pad] = -100
pos_labels[pos_labels == self.pad] = -100
```

jiwooya1000 commented 5 months ago

Hello @javismiles,

The attached lines mask the prompt tokens when calculating the NLL loss, which is $\mathcal{L}_{SFT}$ in the paper. Since $\mathcal{L}_{SFT}$ is calculated for the chosen responses only, we only handle the positive labels there.
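
To make the effect concrete, here is a minimal sketch of the same masking logic on a toy example. It uses plain Python lists instead of tensors so it runs without PyTorch. The token ids, the pad id `PAD_ID`, and the helper names `mask_prompt_labels` and `mask_padding_only` are hypothetical and not part of the repository. It assumes `attention_mask` covers the prompt only and `positive_attention_mask` covers the prompt plus the chosen response, which is what the multiplication in the snippet suggests.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy skips by default
PAD_ID = 2           # hypothetical pad token id (stands in for self.pad)


def mask_prompt_labels(labels, prompt_mask, full_mask, pad_id=PAD_ID):
    """Mirror the positive-label path: hide prompt and padding positions."""
    masked = []
    for label, in_prompt, in_sequence in zip(labels, prompt_mask, full_mask):
        is_prompt = in_prompt * in_sequence   # 1 only on prompt tokens
        label = label * (not is_prompt)       # zero out prompt positions
        if label == 0:                        # zeroed positions become pad ...
            label = pad_id
        if label == pad_id:                   # ... and every pad is ignored
            label = IGNORE_INDEX
        masked.append(label)
    return masked


def mask_padding_only(labels, pad_id=PAD_ID):
    """Mirror the negative-label path: only padding is replaced."""
    return [IGNORE_INDEX if label == pad_id else label for label in labels]


# Prompt = [11, 12]; chosen response = [21, 22, 23]; rejected response = [31, 32].
prompt_mask = [1, 1, 0, 0, 0, 0]
pos_mask = [1, 1, 1, 1, 1, 0]
pos_labels = [11, 12, 21, 22, 23, PAD_ID]
neg_labels = [11, 12, 31, 32, PAD_ID, PAD_ID]

print(mask_prompt_labels(pos_labels, prompt_mask, pos_mask))
# [-100, -100, 21, 22, 23, -100]
print(mask_padding_only(neg_labels))
# [11, 12, 31, 32, -100, -100]
```

In this sketch, only the chosen-response tokens keep real label values on the positive side, so an NLL loss computed from those labels ignores the prompt. The negative labels keep their prompt tokens because the snippet replaces only padding there.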

Hope it helps!

Thank you.

javismiles commented 4 months ago

Yes, that makes total sense, thank you :)