lyogavin / airllm

AirLLM 70B inference with single 4GB GPU

unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit #180

Open kendiyang opened 2 weeks ago

kendiyang commented 2 weeks ago

new version of transfomer, no need to use BetterTransformer, try setting attn impl to sdpa...
attn imp: <class 'transformers.models.llama.modeling_llama.LlamaSdpaAttention'>
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
running layers(cuda:0): 1%|▊ | 1/129 [00:03<07:01, 3.30s/it]
Traceback (most recent call last):
  File "/root/test.py", line 18, in <module>
    generation_output = model.generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1989, in generate
    result = self._sample(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/usr/local/lib/python3.10/dist-packages/airllm/airllm_base.py", line 369, in __call__
    return self.forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/airllm/airllm_base.py", line 569, in forward
    new_seq = layer(seq, **kwargs)[0]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 677, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 565, in forward
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
RuntimeError: shape '[1, 9, 8, 128]' is invalid for input of size 18432
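For reference, the arithmetic behind the failing view can be checked directly. This is only a minimal sketch using the numbers from the traceback; the interpretation at the end is a guess, not a confirmed diagnosis:

```python
# Numbers from the RuntimeError: target shape [bsz=1, q_len=9, num_key_value_heads=8, head_dim=128],
# actual tensor size 18432. The .view() only succeeds if the element counts match.
bsz, q_len, num_key_value_heads, head_dim = 1, 9, 8, 128

expected = bsz * q_len * num_key_value_heads * head_dim   # 9216 elements the config expects
actual = 18432                                            # elements the k_proj output actually has

print(expected, actual, actual / expected)                # -> 9216 18432 2.0
print(actual // (bsz * q_len * head_dim))                 # -> 16 "heads" worth of values per token

# The projection produces twice as many values per token as 8 KV heads x 128 head_dim allow,
# i.e. the loaded k_proj weights don't match the config the attention layer is running with.
# One hedged guess: the unsloth bnb-4bit checkpoint layout isn't being handled, consistent with
# the "Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']" warning above.
```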

1272870698 commented 2 weeks ago

I'm running into the same problem. Have you found a solution?

tripathiarpan20 commented 2 weeks ago

This hacky fix seems to work: https://github.com/lyogavin/airllm/issues/172#issuecomment-2322887273

But it's definitely a bug in AirLLMLlama2, since the Llama 3.1 70B model works without any tweaks (see the comparison sketch below).
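For comparison, this is the standard AirLLM usage pattern (following the project README) that reportedly works for the 70B model. The model ID and generation parameters here are placeholders, not the exact script from the original report:

```python
from airllm import AutoModel

MAX_LENGTH = 128

# Reportedly works as-is for Llama 3.1 70B; swapping in
# "unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit" is what triggers the shape error above.
model = AutoModel.from_pretrained("meta-llama/Meta-Llama-3.1-70B-Instruct")

input_text = ["What is the capital of the United States?"]

# Tokenize without padding/attention mask, as in the AirLLM README examples.
input_tokens = model.tokenizer(
    input_text,
    return_tensors="pt",
    return_attention_mask=False,
    truncation=True,
    max_length=MAX_LENGTH,
    padding=False,
)

generation_output = model.generate(
    input_tokens["input_ids"].cuda(),
    max_new_tokens=20,
    use_cache=True,
    return_dict_in_generate=True,
)

print(model.tokenizer.decode(generation_output.sequences[0]))
```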