yangjianxin1 / GPT2-chitchat

GPT2 for Chinese chitchat / a GPT-2 model for Chinese chitchat (implements the MMI idea from DialoGPT)

Logic problem in the training code #119

Closed wujohns closed 1 year ago

wujohns commented 1 year ago

In train.py:

  1. The caculate_loss method defined there is never actually called.
  2. collate_fn does not shift the labels relative to the input:

    import torch.nn.utils.rnn as rnn_utils

    def collate_fn(batch):
        # Pad input_ids with 0 and labels with -100 (the loss's ignore index).
        input_ids = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=0)
        labels = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=-100)
        return input_ids, labels

Taken together, these two points make the training loss be computed with the wrong logic, and the gradients derived from it would be heavily biased as a result.
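
As a toy illustration of point 2 (the token IDs here are made up): with this collate_fn, labels is position-for-position identical to input_ids except that padding is -100 instead of 0, and no shift is applied at this stage.

    import torch
    import torch.nn.utils.rnn as rnn_utils

    # Two hypothetical tokenized dialogues of different lengths.
    batch = [torch.tensor([101, 5, 6, 7, 102]),
             torch.tensor([101, 8, 9, 102])]

    input_ids = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=0)
    labels = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=-100)

    print(input_ids)  # tensor([[101, 5, 6, 7, 102], [101, 8, 9, 102,    0]])
    print(labels)     # tensor([[101, 5, 6, 7, 102], [101, 8, 9, 102, -100]])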

DarrenRuan commented 1 year ago

This should be fine: the project uses Hugging Face's GPT2LMHeadModel, which shifts the labels internally.

https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2LMHeadModel

Note that the labels are shifted inside the model, i.e. you can set labels = input_ids. All labels set to -100 are ignored (masked), the loss is only computed for labels in [0, ..., config.vocab_size]
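
For reference, the shift that GPT2LMHeadModel applies internally when labels are passed can be sketched roughly as follows (a simplified sketch based on the Hugging Face implementation, not code from this repository):

    import torch
    from torch.nn import CrossEntropyLoss

    def shifted_lm_loss(lm_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Predict token t+1 from position t: drop the last logit and the first label.
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # CrossEntropyLoss ignores targets equal to -100 by default, so the -100
        # padding produced by collate_fn is masked out of the loss automatically.
        loss_fct = CrossEntropyLoss()
        return loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1))

So if train.py passes labels=labels into the model's forward call, the returned loss is already this shifted, padding-masked cross entropy; no extra shift is needed in collate_fn.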

wujohns commented 1 year ago

This should be fine: the project uses Hugging Face's GPT2LMHeadModel, which shifts the labels internally.

https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2LMHeadModel

Note that the labels are shifted inside the model, i.e. you can set labels = input_ids. All labels set to -100 are ignored (masked), the loss is only computed for labels in [0, ..., config.vocab_size]

Thanks. I later read that same doc and realized there was no problem; I just forgot to close the issue at the time.