Hello @javismiles,
The lines you quoted mask the prompt tokens out of the NLL loss, which is $L_{SFT}$ in the paper. $L_{SFT}$ is computed on the chosen responses only, so only the positive labels need this masking.
Hope it helps!
Thank you.
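To make the point concrete, here is a minimal sketch of an NLL term computed on the chosen sequence only, with the prompt positions excluded. It is plain Python rather than the trainer's tensor code. `IGNORE_INDEX`, `mask_prompt`, `sft_nll`, and the toy token ids and probabilities are all hypothetical. Because the rejected sequence is never passed to this function, there is nothing to mask on that side.

```python
import math
from typing import List

IGNORE_INDEX = -100  # hypothetical marker for label positions excluded from the loss


def mask_prompt(labels: List[int], prompt_len: int) -> List[int]:
    """Replace the first `prompt_len` labels (the prompt) with IGNORE_INDEX."""
    return [IGNORE_INDEX] * prompt_len + labels[prompt_len:]


def sft_nll(token_logprobs: List[float], labels: List[int]) -> float:
    """Mean negative log-likelihood over positions not marked IGNORE_INDEX."""
    kept = [lp for lp, lab in zip(token_logprobs, labels) if lab != IGNORE_INDEX]
    if not kept:
        raise ValueError("no response tokens left after masking")
    return -sum(kept) / len(kept)


# Toy chosen sequence: 2 prompt tokens followed by 3 response tokens.
chosen_labels = [5, 8, 2, 9, 4]
# Hypothetical log-probabilities the model assigns to each label.
chosen_logprobs = [math.log(p) for p in (0.9, 0.9, 0.5, 0.25, 0.5)]

masked = mask_prompt(chosen_labels, prompt_len=2)
print(masked)                            # [-100, -100, 2, 9, 4]
print(sft_nll(chosen_logprobs, masked))  # ≈ 0.924, the mean over the 3 response tokens only
```

The two prompt probabilities (0.9) have no effect on the result, which is exactly what discarding the prompt tokens is meant to achieve.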
Yes, that makes total sense, thank you :)
In the lines below, why do we mask only the positive labels with `mask.logical_not()` to discard the prompt tokens, but not the negative labels? What's the reason for doing it only for the positive ones? Thank you very much!
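```python
### Discard the prompt tokens in NLL loss if true
if self.disable_prompt_loss:
    # Element-wise product of the prompt mask and the chosen-sequence mask
    mask = inputs['attention_mask'] * inputs['positive_attention_mask']
    # Zero the labels where that product is 1; keep them where it is 0
    pos_labels = pos_labels * mask.logical_not()
    # Overwrite every zero label with the pad value
    pos_labels[pos_labels == 0] = self.pad
##################################################
```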
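To see what that arithmetic does to a single row, here is a dependency-free re-enactment that uses Python lists instead of PyTorch tensors. It assumes, hypothetically, that the prompt's `attention_mask` is padded to the same length as the chosen sequence, so the product of the two masks is 1 exactly on the prompt positions. `PAD`, `discard_prompt_tokens`, and the token ids are made up for illustration.

```python
from typing import List


def discard_prompt_tokens(
    pos_labels: List[int],
    attention_mask: List[int],
    positive_attention_mask: List[int],
    pad: int,
) -> List[int]:
    """Plain-Python version of the snippet's tensor ops, applied to one row."""
    # attention_mask * positive_attention_mask: 1 where both are 1 (the prompt span)
    mask = [a * p for a, p in zip(attention_mask, positive_attention_mask)]
    # pos_labels * mask.logical_not(): `not 1` is False (zeroes the label),
    # `not 0` is True (keeps it), so only the prompt labels become 0
    labels = [lab * (not m) for lab, m in zip(pos_labels, mask)]
    # pos_labels[pos_labels == 0] = pad: every zero label is overwritten with `pad`
    return [pad if lab == 0 else lab for lab in labels]


PAD = 2  # hypothetical pad id for this toy example
pos_ids = [11, 12, 13, 21, 22, 23, PAD, PAD]  # 3 prompt + 3 response + 2 padding
prompt_mask = [1, 1, 1, 0, 0, 0, 0, 0]         # prompt padded to the same length
chosen_mask = [1, 1, 1, 1, 1, 1, 0, 0]

print(discard_prompt_tokens(pos_ids, prompt_mask, chosen_mask, PAD))
# [2, 2, 2, 21, 22, 23, 2, 2]
```

Because the last step tests `== 0`, any genuine response token whose id happens to be 0 is also replaced by the pad value. Keeping an explicit boolean mask, rather than reusing 0 as a sentinel, would avoid that. Whether the pad positions are then ignored depends on how the loss function treats `self.pad`, which the snippet does not show.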