ymcui / Chinese-LLaMA-Alpaca

Chinese LLaMA & Alpaca large language models + local CPU/GPU training and deployment (Chinese LLaMA & Alpaca LLMs)
https://github.com/ymcui/Chinese-LLaMA-Alpaca/wiki
Apache License 2.0

Fixbug: "RuntimeError: expected scalar type Half but found Float" #875

Open lilbedwin opened 7 months ago

lilbedwin commented 7 months ago

While doing inference using the LoRA model together with the base model, I got the following error:

```
Traceback (most recent call last):
  File "/root/autodl-tmp/Chinese-LLaMA-Alpaca/scripts/inference/inference_hf.py", line 124, in <module>
    generation_output = model.generate(
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/peft/peft_model.py", line 581, in generate
    outputs = self.base_model.generate(**kwargs)
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/transformers/generation/utils.py", line 1572, in generate
    return self.sample(
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/transformers/generation/utils.py", line 2619, in sample
    outputs = self(
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 688, in forward
    outputs = self.model(
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 578, in forward
    layer_outputs = decoder_layer(
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 292, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/autodl-tmp/Chinese-LLaMA-Alpaca/scripts/inference/patches.py", line 43, in xformers_forward
    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/peft/tuners/lora.py", line 358, in forward
    result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/miniconda3/envs/p310/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Half but found Float
```
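For context, here is a minimal sketch of how such an inference run is typically set up (the paths are placeholders, and this is not the verbatim code from inference_hf.py). The base model is loaded in fp16, so if the LoRA adapter weights come back as fp32, the `F.linear` call at the bottom of the traceback mixes Half inputs with Float weights:

```python
# Sketch of the setup that can trigger the dtype mismatch.
# "path/to/base_model" and "path/to/lora_model" are placeholders.
import torch
from transformers import LlamaForCausalLM
from peft import PeftModel

base_model = LlamaForCausalLM.from_pretrained(
    "path/to/base_model",
    torch_dtype=torch.float16,  # base weights loaded as Half
    device_map="auto",
)
# If the adapter's lora_A/lora_B weights are loaded as Float (fp32),
# the LoRA branch feeds fp16 activations into fp32 Linear layers.
model = PeftModel.from_pretrained(base_model, "path/to/lora_model")
```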

This seems to be a bug in peft: when the LoRA model is loaded, the dtype of some parameters may be wrong. We worked around it by forcibly converting the parameter dtype.
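A minimal sketch of such a forced conversion, assuming the loaded model is a PeftModel whose LoRA weights came back in fp32 while the base model is in fp16. This reflects the workaround described above in spirit; it is not a patch from the repository:

```python
# Workaround sketch: cast any fp32 LoRA parameters down to fp16
# so they match the Half dtype of the base model's activations.
import torch

for name, param in model.named_parameters():
    if "lora_" in name and param.dtype == torch.float32:
        param.data = param.data.to(torch.float16)
```

Running this once after `PeftModel.from_pretrained` makes every Linear in the LoRA branch operate in Half, so the `F.linear` call no longer sees mixed dtypes.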