Open · kendiyang opened this issue 2 weeks ago
Have you found a solution to the same problem?
This hacky fix seems to work: https://github.com/lyogavin/airllm/issues/172#issuecomment-2322887273
But it's definitely a bug in AirLLMLlama2, as the same setup works for the Llama 3.1 70B model without any tweaks.
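For reference, a minimal sketch of the working Llama 3.1 70B run through airllm's generic AutoModel entry point (the model ID and prompt are placeholders, and the exact tokenizer/generate arguments may differ slightly between airllm versions):

```python
from airllm import AutoModel

# Hypothetical model ID; AutoModel picks the model class for us instead of
# going through the AirLLMLlama2-specific code path that hits the shape error.
MODEL_ID = "meta-llama/Meta-Llama-3.1-70B-Instruct"
MAX_LENGTH = 128

model = AutoModel.from_pretrained(MODEL_ID)

input_text = ["What is the capital of the United States?"]
input_tokens = model.tokenizer(
    input_text,
    return_tensors="pt",
    return_attention_mask=False,
    truncation=True,
    max_length=MAX_LENGTH,
)

generation_output = model.generate(
    input_tokens["input_ids"].cuda(),
    max_new_tokens=20,
    use_cache=True,
    return_dict_in_generate=True,
)

print(model.tokenizer.decode(generation_output.sequences[0]))
```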
new version of transfomer, no need to use BetterTransformer, try setting attn impl to sdpa...
attn imp: <class 'transformers.models.llama.modeling_llama.LlamaSdpaAttention'>
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
running layers(cuda:0):   1%|▊          | 1/129 [00:03<07:01,  3.30s/it]
Traceback (most recent call last):
  File "/root/test.py", line 18, in <module>
    generation_output = model.generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1989, in generate
    result = self._sample(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/usr/local/lib/python3.10/dist-packages/airllm/airllm_base.py", line 369, in __call__
    return self.forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/airllm/airllm_base.py", line 569, in forward
    new_seq = layer(seq, **kwargs)[0]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 677, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 565, in forward
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
RuntimeError: shape '[1, 9, 8, 128]' is invalid for input of size 18432
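For what it's worth, the failing view expects 1 × 9 × 8 × 128 = 9216 elements, but the tensor actually holds 18432 = 2 × 9216, so the key projection output is twice as wide as the configured num_key_value_heads × head_dim for a 9-token prompt. A minimal PyTorch sketch reproduces the same RuntimeError with those shapes (the "doubled width" is only my assumption about where the mismatch comes from, consistent with a layer-loading bug in AirLLMLlama2):

```python
import torch

# Shapes taken from the traceback: batch 1, 9 prompt tokens,
# 8 KV heads, head_dim 128 -> the view expects 1*9*8*128 = 9216 elements.
bsz, q_len, num_key_value_heads, head_dim = 1, 9, 8, 128

# The failing tensor carries 18432 = 2 * 9216 elements, i.e. the k_proj
# output is twice as wide as the config implies (hypothetical cause).
key_states = torch.randn(bsz, q_len, 2 * num_key_value_heads * head_dim)

# Raises: RuntimeError: shape '[1, 9, 8, 128]' is invalid for input of size 18432
key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
```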