Closed 5663015 closed 8 months ago
@5663015 你好,请问可以贴完整的代码吗,max_input_length和max_target_length参数你是如何设定的呀
How do you decide on max_input_length? Is it half of max_seq_length or a different calculation? Maybe we can set max_seq_length to higher to solve this issue, right?
I've tried with longer sequence lengths, and it solved the problem. I tried it with 1024 max_seq_length
相关代码:
在class DataTrainingArguments:
处添加两个参数,max_input_length
,max_target_length
。
https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/blob/0189e8be7706d2aceb90d238149d5fda6b6aea8d/scripts/training/run_clm_sft_with_peft.py#L199
在分词训练集和验证集处的函数build_instruction_dataset
传入这两个参数。
https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/blob/0189e8be7706d2aceb90d238149d5fda6b6aea8d/scripts/training/run_clm_sft_with_peft.py#L353
https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/blob/0189e8be7706d2aceb90d238149d5fda6b6aea8d/scripts/training/run_clm_sft_with_peft.py#L366
最后在build_instruction_dataset
内修改:https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/blob/0189e8be7706d2aceb90d238149d5fda6b6aea8d/scripts/training/build_dataset.py#L45
max_seq_length = max_input_length + max_target_length
if len(s) > max_input_length:
s = s[:max_input_length]
max_input_length
和max_target_length
具体值需要根据你数据的实际长度情况来确定,max_input_length
是你输入数据的最大长度,max_target_length
是你标签的最大长度。
@tkone2018
I've tried with longer sequence lengths, and it solved the problem. I tried it with 1024 max_seq_length
max_input_length
is max length of your input prompt, max_target_length
is max length of your labels.
Yes, if you set longer max_seq_length
to avoid labels missing, it's also a solution.
I implemented your modifications and it worked perfectly! Thanks a lot!
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your consideration.
Closing the issue, since no updates observed. Feel free to re-open if you need any further assistance.
按照大佬的方法调试了,还是nan,是loss太大的缘故嘛😖
按照大佬的方法调试了,还是nan,是loss太大的缘故嘛😖
這裡的原因是因為 target 的 token 因為超過 max length 而消失導致沒辦法算 loss,你可能要觀察一下你的資料集有沒有發生這種情況
按照大佬的方法调试了,还是nan,是loss太大的缘故嘛😖
這裡的原因是因為 target 的 token 因為超過 max length 而消失導致沒辦法算 loss,你可能要觀察一下你的資料集有沒有發生這種情況
感谢回复,我明白您的意思。在我的代码中已经按照 @5663015 的方法进行了截断👀
喔喔瞭解,那可能是別的原因。
提交前必须检查以下项目
问题类型
模型训练与精调
基础模型
None
操作系统
Linux
详细描述问题
看到很多人遇到了SFT时eval_loss为nan的情况,这几天我也遇到了,有说是训练参数问题、数值精度问题,但依然都会出现eval_loss=nan。经过debug,可能是分词部分有问题:https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/blob/0189e8be7706d2aceb90d238149d5fda6b6aea8d/scripts/training/build_dataset.py#L45 这里会将input和target分词后直接拼接起来,如果超过
max_seq_length
则截断。但是如果在target很短的情况下,target可能会丢失(我的数据中有的target就很短,只有一个词)。这样训练样本中的labels
可能全变为了IGNORE_INDEX
,导致训练loss不稳定,eval_loss出现nan。我的修改是,使用max_input_length
和max_target_length
参数而不是max_seq_length
。在input_ids = torch.LongTensor(s + t)[:max_seq_length]
前面加上两行:这样我的训练loss也平稳下来了,eval_loss也正常了。供大家参考。
依赖情况(代码类问题务必提供)
运行日志或截图