hiyouga / LLaMA-Factory

Unify Efficient Fine-Tuning of 100+ LLMs
Apache License 2.0
25.26k stars 3.13k forks source link

【问题】为什么要把可训练参数精度强行转换为全精度? #4549

Closed LaniakeaS closed 2 days ago

LaniakeaS commented 3 days ago

我在尝试全参微调,发现显存不够用。排查后发现llama-factory会强制把精度设置在fp32。由于我使用了deepspeed,所以无法使用pure bf16参数。

想问一下这个步骤的必要性是什么?能否在使用deepspeed的情况下也支持bf16和fp16?

hiyouga commented 2 days ago

deepspeed 不支持 pure_bf16

LaniakeaS commented 2 days ago

你误会我的意思了,我不是要它支持pure-bf16。举个例子来说,adapter.py中的_setup_full_tuning的param.data.to(torch.float32)这行代码让我发生了OOM的问题,我把这行代码注释掉就可以训练了。所以我想要的是是否可以提供某个参数来让我选择是否使用这里的float32转换。

hiyouga commented 2 days ago

你用的是 deepspeed stage 多少?

LaniakeaS commented 2 days ago

zero-3

hiyouga commented 2 days ago

理论上 zero3 不会走到那个逻辑,你用的是最新代码吗

LaniakeaS commented 2 days ago

抱歉搞错了,我是在发现OOM之后,改回了zero-3。之前在zero-2的情况下,会触发fp32转换导致的OOM,然后我把cast to fp32那行代码注释掉就可以在zero-2的条件下训练了。

hiyouga commented 1 day ago

试试用 pure_bf16: truebf16: true 再跑下