GanjinZero / RRHF

[NIPS2023] RRHF & Wombat

bug: when computing the SFT loss #48

Open shyoulala opened 7 months ago

shyoulala commented 7 months ago

When computing the SFT loss, the labels and logits don't seem to be shifted. Am I misunderstanding something? It should be new_logits = logits[:,:-1,:]
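For context, here is a minimal sketch of the usual next-token shift in a causal LM loss. It is written in plain Python rather than PyTorch so it runs anywhere. It does not reproduce RRHF's code. The helper names and `IGNORE_INDEX = -100` (the common Hugging Face masking convention) are assumptions. The logit at position t predicts the token at t+1, so the last logit and the first label are dropped before the two are paired.

```python
import math

IGNORE_INDEX = -100  # assumed masking value, mirroring the common HF convention


def cross_entropy(logits_row, target):
    """Negative log-softmax probability of `target` at one position."""
    m = max(logits_row)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits_row))
    return log_z - logits_row[target]


def shifted_sft_loss(logits, labels):
    """Mean next-token loss for one sequence (hypothetical helper).

    logits: T rows, each a list of vocab scores.
    labels: T token ids; IGNORE_INDEX positions are skipped.
    """
    new_logits = logits[:-1]  # same idea as logits[:, :-1, :] on a batch
    new_labels = labels[1:]   # position t is scored against token t+1
    losses = [
        cross_entropy(row, y)
        for row, y in zip(new_logits, new_labels)
        if y != IGNORE_INDEX
    ]
    return sum(losses) / len(losses)


# Toy sequence: each logit row peaks at the *next* token id.
labels = [0, 1, 2]
logits = [[0.0, 10.0, 0.0], [0.0, 0.0, 10.0], [10.0, 0.0, 0.0]]
print(shifted_sft_loss(logits, labels))  # close to 0
```

Pairing `logits` with `labels` without any shift on the same toy data gives a loss near 10, because every row is scored against its own position's token instead of the next one.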


shyoulala commented 7 months ago

Like this: [image]

GanjinZero commented 7 months ago

The labels were already shifted in DataCollatorForSupervisedDataset.
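The following minimal sketch shows why shifting in the collator makes an unshifted loss correct. It assumes the collator sets label[t] = input_ids[t+1] and pads the final position with `IGNORE_INDEX`. This is an assumption for illustration only. It is not RRHF's actual `DataCollatorForSupervisedDataset`, and `pre_shift_labels` is a hypothetical name. The sketch also shows that shifting a second time inside the loss would misalign everything.

```python
import math

IGNORE_INDEX = -100  # assumed masking value, mirroring the common HF convention


def cross_entropy(logits_row, target):
    """Negative log-softmax probability of `target` at one position."""
    m = max(logits_row)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits_row))
    return log_z - logits_row[target]


def pre_shift_labels(input_ids):
    """Hypothetical collator step: label at t is the token at t+1."""
    return list(input_ids[1:]) + [IGNORE_INDEX]


def unshifted_loss(logits, labels):
    """Loss that pairs logits[t] with labels[t] directly (no shift)."""
    losses = [cross_entropy(r, y) for r, y in zip(logits, labels) if y != IGNORE_INDEX]
    return sum(losses) / len(losses)


def shifted_loss(logits, labels):
    """Standard loss that does the shift itself."""
    return unshifted_loss(logits[:-1], labels[1:])


input_ids = [0, 1, 2]
logits = [[0.0, 10.0, 0.0], [0.0, 0.0, 10.0], [10.0, 0.0, 0.0]]

# Shift once (collator OR loss): both give the same small loss.
print(unshifted_loss(logits, pre_shift_labels(input_ids)))
print(shifted_loss(logits, input_ids))
# Shift twice (collator AND loss): misaligned, large loss.
print(shifted_loss(logits, pre_shift_labels(input_ids)))
```

The rule is that the shift must happen exactly once. If the collator already pre-shifts the labels, trimming the logits in the loss as well would be a bug.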