ymcui / Chinese-LLaMA-Alpaca-2

Chinese LLaMA-2 & Alpaca-2 LLMs, with 64K long-context models (phase 2 of the Chinese-LLaMA-Alpaca project)
Apache License 2.0

A fix for 'eval_loss' being nan during training #424

Closed · 5663015 closed this 8 months ago

5663015 commented 9 months ago

Pre-submission checklist

Issue type

Model training and fine-tuning

Base model

None

Operating system

Linux

Describe the problem in detail

Many people have run into eval_loss being nan during SFT, and I hit it myself these past few days. Some blamed training hyperparameters or numerical precision, but eval_loss=nan still occurred. After debugging, the problem appears to be in the tokenization step: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/blob/0189e8be7706d2aceb90d238149d5fda6b6aea8d/scripts/training/build_dataset.py#L45 Here the tokenized input and target are concatenated directly and then truncated if they exceed max_seq_length. If the target is very short (some targets in my data are a single word), the target can be cut off entirely. The labels for such training samples then become all IGNORE_INDEX, which makes the training loss unstable and produces nan in eval_loss. My fix is to use separate max_input_length and max_target_length parameters instead of a single max_seq_length, adding two lines before input_ids = torch.LongTensor(s + t)[:max_seq_length]:

if len(s) > max_input_length:
    s = s[:max_input_length]  # truncate the prompt so the target tokens survive

With this change, my training loss stabilized and eval_loss returned to normal. Posted here for reference.
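A minimal self-contained sketch of the pattern described above (the helper name build_example is hypothetical; the real logic lives in the linked build_dataset.py):

import torch

IGNORE_INDEX = -100  # positions labeled with this value are skipped by the loss

def build_example(s, t, max_seq_length, max_input_length):
    """s: tokenized prompt ids, t: tokenized target ids (plain Python lists)."""
    # The fix: truncate the prompt first so the target is never squeezed out.
    if len(s) > max_input_length:
        s = s[:max_input_length]
    input_ids = torch.LongTensor(s + t)[:max_seq_length]
    # Prompt tokens are masked; only target tokens contribute to the loss.
    # Without the truncation above, a prompt with len(s) >= max_seq_length
    # would leave labels that are entirely IGNORE_INDEX.
    labels = torch.LongTensor([IGNORE_INDEX] * len(s) + t)[:max_seq_length]
    return input_ids, labels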

Dependencies (required for code-related issues)

# Paste your dependency list here (inside this code block)

Logs or screenshots

# Paste the run log here (inside this code block)

tkone2018 commented 9 months ago

@5663015 Hi, could you paste the complete code? How did you set the max_input_length and max_target_length parameters?

yusufcakmakk commented 9 months ago

How do you decide on max_input_length? Is it half of max_seq_length, or a different calculation? Maybe we could just set max_seq_length higher to solve this issue, right?

yusufcakmakk commented 9 months ago

I've tried longer sequence lengths, and that solved the problem. It worked with max_seq_length = 1024.

5663015 commented 9 months ago

Relevant code:

max_seq_length = max_input_length + max_target_length  # total budget split between prompt and target
if len(s) > max_input_length:
    s = s[:max_input_length]

The exact values of max_input_length and max_target_length need to be determined from the actual lengths in your data: max_input_length is the maximum length of your inputs, and max_target_length is the maximum length of your labels. @tkone2018
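One hedged way to pick these values is to scan the dataset once. The helper below (suggest_lengths is a made-up name; it assumes (prompt, target) text pairs and a Hugging Face tokenizer) reports the longest tokenized prompt and target:

def suggest_lengths(pairs, tokenizer):
    """Return the longest tokenized prompt and target in the dataset,
    as candidate values for max_input_length / max_target_length."""
    max_input_length = max(len(tokenizer.encode(p)) for p, _ in pairs)
    max_target_length = max(len(tokenizer.encode(t)) for _, t in pairs)
    return max_input_length, max_target_length

max_seq_length then follows as the sum of the two, as in the snippet above.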

5663015 commented 9 months ago

> I've tried longer sequence lengths, and that solved the problem. It worked with max_seq_length = 1024.

max_input_length is the max length of your input prompt, and max_target_length is the max length of your labels. Yes, setting a longer max_seq_length so the labels are not dropped is also a solution.

DandinPower commented 9 months ago

I implemented your modifications and it worked perfectly! Thanks a lot!

github-actions[bot] commented 8 months ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your consideration.

github-actions[bot] commented 8 months ago

Closing the issue, since no updates observed. Feel free to re-open if you need any further assistance.

liumazeze commented 4 months ago

I tried debugging with your method, but eval_loss is still nan. Could it be because the loss is too large? 😖

DandinPower commented 4 months ago

> I tried debugging with your method, but eval_loss is still nan. Could it be because the loss is too large? 😖

The cause here is that the target tokens are dropped once the sequence exceeds the max length, leaving nothing to compute the loss on. You may want to check whether this is happening in your dataset.
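To see concretely why fully masked labels yield nan: with mean reduction, PyTorch's cross entropy divides the summed loss (0) by the number of unmasked tokens (0). A tiny standalone check with illustrative shapes:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 32000)                     # 4 positions, 32k vocab
labels = torch.full((4,), -100, dtype=torch.long)  # every position masked
loss = F.cross_entropy(logits, labels, ignore_index=-100)
print(loss)  # tensor(nan): 0 summed loss / 0 unmasked tokens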

liumazeze commented 4 months ago

> > I tried debugging with your method, but eval_loss is still nan. Could it be because the loss is too large? 😖
>
> The cause here is that the target tokens are dropped once the sequence exceeds the max length, leaving nothing to compute the loss on. You may want to check whether this is happening in your dataset.

Thanks for the reply, I understand what you mean. My code already applies the truncation from @5663015's method 👀

DandinPower commented 4 months ago

Ah, got it. Then it's probably a different cause.