Open Magicsmx opened 11 months ago
modeling_chatglm.py
的960行左右,shift_logits
是模型输出的每个单词的logit,截掉最后一个词,shift_labels
是真实标签,截掉第一个词。
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
里面的运算大概是这样,手动验证 loss 是一样的:
mask = shift_labels != -100
shift_logits = shift_logits[mask]
shift_labels = shift_labels[mask]
shift_probs = F.softmax(shift_logits, -1)
shift_onehots = F.one_hot(shift_labels, shift_probs.shape[-1])
loss = (-shift_onehots * torch.log(shift_probs)).sum(-1).mean()
Is there an existing issue for this?
Current Behavior
求问 在做基于ChatGLM2-6B的Ptuning v2的微调任务时 损失函数的公式是什么样子的
Expected Behavior
无
Steps To Reproduce
RT
Environment
Anything else?
No response