yangjianxin1 / Firefly

Firefly: a large-model training toolkit supporting Qwen2.5, Qwen2, Yi1.5, Phi-3, Llama3, Gemma, MiniCPM, Yi, Deepseek, Orion, Xverse, Mixtral-8x7B, Zephyr, Mistral, Baichuan2, Llama2, Llama, Qwen, Baichuan, ChatGLM2, InternLM, Ziya2, Vicuna, Bloom, and other large language models

Qwen2-7B-Instruct training loss is 0, or inference fails with probability tensor contains either `inf`, `nan` or element < 0 #272

Closed · WinterStraw closed this issue 5 months ago

WinterStraw commented 5 months ago

During training, if the loss drops to 0, changing fp16 to bf16 fixes it. But I don't know how to fix the inference side:

```
Loading checkpoint shards: 100%|███████████████████████████████████████████████| 4/4 [00:02<00:00,  1.54it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
User: hello
Traceback (most recent call last):
  File "/home/ecs-user/LLM-train/Firefly/script/chat/chat.py", line 153, in <module>
    main()
  File "/home/ecs-user/LLM-train/Firefly/script/chat/chat.py", line 136, in main
    outputs = model.generate(
  File "/home/ecs-user/miniconda3/envs/firefly/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/ecs-user/miniconda3/envs/firefly/lib/python3.9/site-packages/transformers/generation/utils.py", line 1520, in generate
    return self.sample(
  File "/home/ecs-user/miniconda3/envs/firefly/lib/python3.9/site-packages/transformers/generation/utils.py", line 2653, in sample
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
```

Thanks a lot!
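For reference, `torch.multinomial` raises this exact error whenever the probability tensor contains `inf`, `nan`, or a negative value; here the most likely cause is fp16 overflow in the logits, which turns the softmax output into NaNs. A minimal standalone repro of the error class (hypothetical values, no model involved):

```python
import torch

# torch.multinomial rejects any distribution containing inf, nan, or
# negative entries -- the same validation that fails during sampling above.
probs = torch.tensor([[float("nan"), 1.0]])
torch.multinomial(probs, num_samples=1)
# RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
```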

WinterStraw commented 5 months ago

Qwen1.5 has no such problem at all.

WinterStraw commented 5 months ago

For inference, edit LLM-train/Firefly/component/utils.py and change `torch_dtype=torch.float16,` to `torch_dtype=torch.float32,`; that resolves it. This incompatibility may show up on A10 GPUs.
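In a typical transformers loading path, the change described above would look roughly like this (a sketch only; the model id is assumed for illustration, and the actual code in component/utils.py may differ):

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical loading call mirroring the described fix: float16 -> float32.
# float32 avoids the fp16 overflow that produces NaN logits, at the cost of
# roughly double the memory and slower inference.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",   # model id assumed for illustration
    torch_dtype=torch.float32,  # was torch.float16
    device_map="auto",          # requires accelerate; adjust as needed
)
```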

joewale commented 5 months ago

Hi, I'm running into this problem too. Is the fix to change `"fp16": true` to `"bf16": true` in the config file qwen2-7b-sft-qlora.json? @WinterStraw

WinterStraw commented 5 months ago

> Hi, I'm running into this problem too. Is the fix to change `"fp16": true` to `"bf16": true` in the config file qwen2-7b-sft-qlora.json? @WinterStraw

Yes: for training, change the fp16/bf16 setting in the config; for inference, change `torch_dtype`.
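Since bf16 needs Ampere-class hardware or newer, a defensive way to pick the inference dtype at runtime might look like this (a sketch, not Firefly's actual code; `torch.cuda.is_bf16_supported()` is available in recent PyTorch releases):

```python
import torch

# Prefer bfloat16 where the GPU supports it (Ampere and newer, e.g. the A10);
# fall back to float32 rather than float16 to avoid overflow-induced NaNs.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
else:
    dtype = torch.float32
print(f"using torch_dtype={dtype}")
```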

joewale commented 5 months ago

OK, thanks, I'll give it a try.

dreamerlvtx commented 4 months ago

Hi, thanks for your method. When I applied it, I found that changing `torch_dtype=torch.float16` to `torch_dtype=torch.float32` makes inference very slow; changing it to `torch_dtype=torch.bfloat16` fixes the problem and is much faster.
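Putting the thread's conclusion together, here is a minimal end-to-end bfloat16 inference sketch (the Hub model id, chat-template call, and generation settings are assumptions for illustration, not Firefly's actual chat.py):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch: load Qwen2-7B-Instruct in bfloat16 -- same exponent range as
# float32 (so no overflow-induced NaNs), but half the memory and much
# faster than running inference in full float32.
model_name = "Qwen/Qwen2-7B-Instruct"  # assumed; substitute your checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

messages = [{"role": "user", "content": "hello"}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
outputs = model.generate(inputs, max_new_tokens=64, do_sample=True)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```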