THUDM / ChatGLM2-6B

ChatGLM2-6B: An Open Bilingual Chat LLM | Open-source bilingual dialogue language model

[BUG/Help] Question: what is the loss/objective function formula for P-Tuning v2 fine-tuning based on ChatGLM2-6B? #642

Open Magicsmx opened 9 months ago

Magicsmx commented 9 months ago

Is there an existing issue for this?

Current Behavior

Question: when doing P-Tuning v2 fine-tuning based on ChatGLM2-6B, what is the formula of the loss function?

Expected Behavior

Steps To Reproduce

As in the title.

Environment

- OS:
- Python:
- Transformers:
- PyTorch:
- CUDA Support (`python -c "import torch; print(torch.cuda.is_available())"`) :

Anything else?

No response

wizardforcel commented 9 months ago

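See around line 960 of modeling_chatglm.py. shift_logits holds the logits the model outputs at each token position, with the last position dropped; shift_labels holds the ground-truth labels, with the first position dropped. After this shift, the logits at position t are scored against the token at position t+1, i.e. each position predicts the next token.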
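A minimal sketch of that alignment in plain Python, on one toy sequence with no batch dimension. The function name shift_for_next_token and the toy values are hypothetical; in PyTorch the same slicing is usually written as logits[..., :-1, :] and labels[..., 1:], which is assumed here to correspond to what the modeling code does.

```python
# Hypothetical toy example of the next-token shift (not the ChatGLM2 source).
# Assumes the usual causal-LM convention: logits at position t predict token t+1.


def shift_for_next_token(logits, labels):
    """Drop the last logit row and the first label, so that
    shift_logits[t] is scored against shift_labels[t] == labels[t + 1]."""
    shift_logits = logits[:-1]  # positions 0 .. T-2 (last one has no next token)
    shift_labels = labels[1:]   # positions 1 .. T-1 (first token is never predicted)
    return shift_logits, shift_labels


if __name__ == "__main__":
    # 4 positions, vocabulary of 3; each row stands in for one position's logits
    logits = [[0.1, 0.2, 0.3], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.5, 0.5, 0.5]]
    labels = [-100, -100, 2, 1]  # -100 marks positions excluded from the loss
    sl, lb = shift_for_next_token(logits, labels)
    for t, (row, y) in enumerate(zip(sl, lb)):
        print(f"logits at position {t} -> target {y}")
```

Position 0's logits are compared with token 1, position 1's with token 2, and so on; the final row of logits is discarded because there is no token after it.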

from torch.nn import CrossEntropyLoss
loss_fct = CrossEntropyLoss(ignore_index=-100)  # targets equal to -100 are skipped
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
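Written out as a formula, this is a sketch of what CrossEntropyLoss computes with ignore_index=-100 and its default reduction="mean" (no class weights). The symbols are introduced here for illustration: $z_{b,t,v}$ is the logit for vocabulary entry $v$ at position $t$ of sequence $b$, $y_{b,t}$ is the label at that position, $T$ is the sequence length, $V$ the vocabulary size, and $\mathcal{M}$ the set of positions whose shifted target is not $-100$:

$$
\mathcal{M} = \{(b,t) : 0 \le t < T-1,\ y_{b,t+1} \ne -100\}
$$

$$
\mathcal{L} = -\frac{1}{|\mathcal{M}|} \sum_{(b,t)\in\mathcal{M}} \log \frac{\exp\left(z_{b,t,\,y_{b,t+1}}\right)}{\sum_{v=1}^{V} \exp\left(z_{b,t,v}\right)}
$$

Because both tensors are flattened with view(-1, ...), the average runs over all kept tokens in the whole batch, not per sequence.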

Internally the computation is roughly the following; computing it by hand gives the same loss value:

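          import torch
          import torch.nn.functional as F

          mask = shift_labels != -100                # keep only positions with a real target
          shift_logits = shift_logits[mask]          # (N, vocab_size)
          shift_labels = shift_labels[mask]          # (N,)
          shift_probs = F.softmax(shift_logits, -1)
          shift_onehots = F.one_hot(shift_labels, shift_probs.shape[-1])
          # -log p(correct token) at each kept position, averaged over the N positions
          loss = (-shift_onehots * torch.log(shift_probs)).sum(-1).mean()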
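For cross-checking numbers without PyTorch, here is a dependency-free sketch of the same quantity for a single sequence. It is an illustration, not the ChatGLM2 code: the function name masked_next_token_ce is hypothetical, and it uses the log-sum-exp form of $-\log\mathrm{softmax}$ instead of softmax followed by log, which is mathematically equivalent but avoids overflow for large logits. Returning NaN when every target is ignored is assumed to mirror PyTorch's mean over zero elements.

```python
import math

IGNORE_INDEX = -100  # same default as torch.nn.CrossEntropyLoss


def masked_next_token_ce(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean next-token cross-entropy over positions whose shifted label
    is not ignore_index.

    logits: list of T rows, each a list of V floats (one sequence, no batch)
    labels: list of T ints (token ids or ignore_index)
    """
    shift_logits = logits[:-1]  # drop the last position
    shift_labels = labels[1:]   # drop the first label
    total, count = 0.0, 0
    for row, y in zip(shift_logits, shift_labels):
        if y == ignore_index:
            continue  # same effect as `mask = shift_labels != -100`
        m = max(row)  # subtract the max before exp for numerical stability
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[y]  # equals -log softmax(row)[y]
        count += 1
    if count == 0:
        return float("nan")  # nothing to average over
    return total / count  # mean over kept positions (reduction="mean")


if __name__ == "__main__":
    # Position 0 predicts token id 1 with p = 3/4, so the loss is log(4/3)
    print(masked_next_token_ce([[0.0, math.log(3)], [9.9, 9.9]], [-100, 1]))
```

For a batch, the kept positions from every sequence would be pooled into one sum and one count before dividing, which is what flattening with view(-1, ...) achieves in the PyTorch version.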