Closed zhengyanzhao1997 closed 1 year ago
感谢,是会有这个问题,我们之后会调整
generate_and_tokenize_prompt函数中: *"attention_mask": [1] (len(full_tokens))** 这行,看起来把padding的token也设置为1了,这个是合理的吗?
构建dataset时,将instruction部分的label设置为-100可能是无效的,原因是Trainer部分设置的transformers.DataCollatorForLanguageModeling会将label进行重制。 将
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
替换为data_collator=transformers.DataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True)
能解决这个问题
请问一下,如果不替换的话,可能导致训练会出什么错误呢?我看了很多类似的repo,貌似都和作者的写法差不多,不知道这个bug的影响是啥呢?
构建dataset时,将instruction部分的label设置为-100可能是无效的,原因是Trainer部分设置的transformers.DataCollatorForLanguageModeling会将label进行重制。 将
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
替换为data_collator=transformers.DataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True)
能解决这个问题请问一下,如果不替换的话,可能导致训练会出什么错误呢?我看了很多类似的repo,貌似都和作者的写法差不多,不知道这个bug的影响是啥呢?
没有大bug,最终性能可能会有区别,以实验结果为准哈
generate_and_tokenize_prompt函数中: *"attention_mask": [1] (len(full_tokens))** 这行,看起来把padding的token也设置为1了,这个是合理的吗?
只要最后把padding id 的label 忽视掉,就没问题
构建dataset时,将instruction部分的label设置为-100可能是无效的,原因是Trainer部分设置的transformers.DataCollatorForLanguageModeling会将label进行重制。 将
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
替换为data_collator=transformers.DataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True)
能解决这个问题请问一下,如果不替换的话,可能导致训练会出什么错误呢?我看了很多类似的repo,貌似都和作者的写法差不多,不知道这个bug的影响是啥呢?
没有大bug,最终性能可能会有区别,以实验结果为准哈
我大概理解了,作者原来的写法更像是语言模型自回归的训练(prmopt + input + output),因为labels并没有生效;您这边的写法就更像是QA的训练(labels部分prmopt + input的部分被mask掉了),只训练A部分对应的loss,我这样理解对吗?
@songbaipu 是的,我这种写法相当于是把prompt和input的生成也放进loss里面的;label起效果相当于只把output的生成放进loss里面。对最终的性能可能影响不大,不过改了之后应该会稍微快一点。
构建dataset时,将instruction部分的label设置为-100可能是无效的,原因是Trainer部分设置的transformers.DataCollatorForLanguageModeling会将label进行重制。 将
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
替换为data_collator=transformers.DataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True)
能解决这个问题