yuhuixu1993 / qa-lora

Official PyTorch implementation of QA-LoRA
MIT License

RuntimeError: self and mat2 must have the same dtype #10

Open M-Elfeki opened 1 year ago

M-Elfeki commented 1 year ago

Whenever I follow the installation instructions, quantize llama2-hf with AutoGPTQ, and then run qalora.py on the checkpoints AutoGPTQ produces, I get the error below. I install AutoGPTQ with pip install auto-gptq[triton] and follow the rest of the installation steps; note that peft_integration is no longer a branch of AutoGPTQ and thus can't be installed.

loading base model checkpoints...
QuantLinear with exllama backend not support trainable mode yet, Switch to the pytorch backend.
adding LoRA modules...
trainable params: 17891328.0 || all params: 1216401408 || trainable: 1.4708407835055712
loaded model
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Using pad_token, but it is not set yet.
torch.float16 263512064 0.21663102954794175
torch.int32 809500672 0.665483626567893
torch.float32 143396864 0.11788534388416533
  0%|          | 0/10000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "qalora.py", line 793, in <module>
    train()
  File "qalora.py", line 755, in train
    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/transformers/trainer.py", line 1591, in train
    return inner_training_loop(
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/transformers/trainer.py", line 1892, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/transformers/trainer.py", line 2776, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/transformers/trainer.py", line 2801, in compute_loss
    outputs = model(**inputs)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/peft/peft_model.py", line 948, in forward
    return self.base_model(
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/peft/tuners/tuners_utils.py", line 106, in forward
    return self.model.forward(*args, **kwargs)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1038, in forward
    outputs = self.model(
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 921, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 917, in custom_forward
    return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 635, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/auto_gptq/nn_modules/fused_llama_attn.py", line 53, in forward
    qkv_states = self.qkv_proj(hidden_states)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/peft/tuners/lora/layer.py", line 251, in forward
    result = self._linear(x)
  File "/home/mohamed/miniconda3/envs/qalora/lib/python3.8/site-packages/peft/tuners/lora/layer.py", line 239, in _linear
    return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
RuntimeError: self and mat2 must have the same dtype

Can you help me set up my environment and provide detailed instructions for running QA-LoRA on llama2-hf with the latest packages, or release your conda environment yml file? Thank you.

M-Elfeki commented 1 year ago

At the failing F.linear call: input.dtype, self.weight.dtype, self.bias.dtype are (torch.float32, torch.int32, torch.float16).
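That combination explains the failure: PEFT's LoRA wrapper ends up calling F.linear directly on the quantized layer's packed int32 weight instead of going through the dequantizing QuantLinear forward. A minimal sketch of the mismatch (illustrative shapes and stand-in tensors, not code from the repo):

import torch
import torch.nn.functional as F

# Stand-ins for the tensors observed at the failing call site:
x = torch.randn(1, 64, dtype=torch.float32)        # float32 activations
qweight = torch.zeros(64, 64, dtype=torch.int32)   # packed int32 GPTQ weight
bias = torch.zeros(64, dtype=torch.float16)        # float16 bias

# Matmul requires both operands to share a dtype, so this raises a
# RuntimeError (on a CUDA build of torch 2.0 the message is exactly
# "self and mat2 must have the same dtype").
F.linear(x, qweight, bias)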

yuhuixu1993 commented 1 year ago

@M-Elfeki I checked my dependencies: I used auto-gptq==0.3.0.dev0. You can try 0.3.0.

jianyuheng commented 1 year ago

@yuhuixu1993 same problem here.


jianyuheng commented 1 year ago

The error no longer appears when I modify these two lines of code. @yuhuixu1993

https://github.com/yuhuixu1993/qa-lora/blob/8791e08929fee2bf015a9bbc0ebaaefd0c9cf2a5/qalora.py#L300: model.config.torch_dtype = torch.float16

https://github.com/yuhuixu1993/qa-lora/blob/8791e08929fee2bf015a9bbc0ebaaefd0c9cf2a5/qalora.py#L340: module = module.to(torch.float16)
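For context, a sketch of how those two changes might sit in qalora.py. The helper name and the loop structure are assumptions based on the QLoRA-style module-casting pattern the script derives from, not verified line-for-line against the linked commit:

import torch
from torch import nn

def force_float16(model: nn.Module) -> None:
    """Apply jianyuheng's two changes: pin everything to float16."""
    # Change 1 (qalora.py#L300): record float16 as the compute dtype on
    # the model config instead of float32/bfloat16.
    model.config.torch_dtype = torch.float16

    # Change 2 (qalora.py#L340): cast LoRA adapters and norm layers to
    # float16 so input, weight, and bias of F.linear share one dtype.
    # The name-based filter is an assumption for illustration.
    for name, module in model.named_modules():
        if "lora" in name or "norm" in name:
            module = module.to(torch.float16)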

shawnricecake commented 1 year ago

This error is still raised when resuming from a checkpoint.

StiphyJay commented 12 months ago

I used auto-gptq==0.3.0 and pytorch==2.0.1+cu118, which fixed the above problems.
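A quick way to confirm an environment matches the combination reported to work in this thread (a generic stdlib check, not from the repo):

# Print installed versions of the packages this thread pins.
from importlib.metadata import version, PackageNotFoundError

for pkg in ("torch", "auto-gptq", "transformers", "peft"):
    try:
        # Per this thread: expect torch 2.0.1+cu118 and auto-gptq 0.3.0.
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")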