bitsandbytes-foundation / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.
https://huggingface.co/docs/bitsandbytes/main/en/index

RuntimeError: mat1 and mat2 shapes cannot be multiplied (when using peft) #1268

wizardforcel opened this issue 5 months ago (status: Open)

wizardforcel commented 5 months ago

System Info

Ubuntu 20.04
Python 3.10.14
torch                    2.3.0
transformers             4.42.3
bitsandbytes             0.42.0
CUDA Version: 12.4
GPU 3090
torch.cuda.is_available(): True

Reproduction

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
import bitsandbytes as bnb
import peft

base_path = '/data/Qwen1.5-0.5B-Chat/'
tok = AutoTokenizer.from_pretrained(base_path, trust_remote_code=True)  # tokenizer used by the chat call below
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
llm = AutoModelForCausalLM.from_pretrained(
        base_path, trust_remote_code=True, quantization_config=nf4_config)
# collect the name of every 4-bit linear layer to use as a LoRA target
tgt_mods = {
    name.split('.')[-1]
    for name, mod in llm.named_modules()
    if isinstance(mod, bnb.nn.Linear4bit)
}
lora_config = peft.LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=tgt_mods,
    lora_dropout=0.1,
    task_type=peft.TaskType.CAUSAL_LM,
    inference_mode=False,
    bias="none"
)
llm = peft.get_peft_model(llm, lora_config)
llm.cuda()  # !!! moving the already-quantized model to the GPU again; this call appears to be what triggers the error below
llm.chat(tok, '你好', max_length=4096)  # `chat` is a user-defined helper around `generate` (see qwen_chat.py in the traceback)

'''
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[63], line 1
----> 1 llm.chat(tok, '你好', max_length=4096)

File ~/gpt_test/src/qwen_chat.py:8, in chat(model, tok, ques, history, **kw)
      3 def chat(model, tok, ques, history=[], **kw):
      4         iids = tok.apply_chat_template(
      5                 history + [{'role': 'user', 'content': ques}],
      6                 add_generation_prompt=1,
      7         )
----> 8         oids = model.generate(
      9                 inputs=torch.tensor([iids]).to(model.device),
     10                 **(model.generation_config.to_dict() | kw),
     11         )
     12         oids = oids[0][len(iids):].tolist()
     13         if oids[-1] == tok.eos_token_id:

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/peft/peft_model.py:1190, in PeftModelForCausalLM.generate(self, *args, **kwargs)
   1188     with self._enable_peft_forward_hooks(*args, **kwargs):
   1189         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1190         outputs = self.base_model.generate(*args, **kwargs)
   1191 else:
   1192     outputs = self.base_model.generate(**kwargs)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/transformers/generation/utils.py:1914, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1906     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1907         input_ids=input_ids,
   1908         expand_size=generation_config.num_return_sequences,
   1909         is_encoder_decoder=self.config.is_encoder_decoder,
   1910         **model_kwargs,
   1911     )
   1913     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 1914     result = self._sample(
   1915         input_ids,
   1916         logits_processor=prepared_logits_processor,
   1917         logits_warper=prepared_logits_warper,
   1918         stopping_criteria=prepared_stopping_criteria,
   1919         generation_config=generation_config,
   1920         synced_gpus=synced_gpus,
   1921         streamer=streamer,
   1922         **model_kwargs,
   1923     )
   1925 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   1926     # 11. prepare logits warper
   1927     prepared_logits_warper = (
   1928         self._get_logits_warper(generation_config, device=input_ids.device)
   1929         if generation_config.do_sample
   1930         else None
   1931     )

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/transformers/generation/utils.py:2651, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2648 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2650 # forward pass to get next token
-> 2651 outputs = self(
   2652     **model_inputs,
   2653     return_dict=True,
   2654     output_attentions=output_attentions,
   2655     output_hidden_states=output_hidden_states,
   2656 )
   2658 if synced_gpus and this_peer_finished:
   2659     continue  # don't waste resources running the code we don't need

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:1221, in Qwen2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1218 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1220 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1221 outputs = self.model(
   1222     input_ids=input_ids,
   1223     attention_mask=attention_mask,
   1224     position_ids=position_ids,
   1225     past_key_values=past_key_values,
   1226     inputs_embeds=inputs_embeds,
   1227     use_cache=use_cache,
   1228     output_attentions=output_attentions,
   1229     output_hidden_states=output_hidden_states,
   1230     return_dict=return_dict,
   1231     cache_position=cache_position,
   1232 )
   1234 hidden_states = outputs[0]
   1235 logits = self.lm_head(hidden_states)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:1023, in Qwen2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1012     layer_outputs = self._gradient_checkpointing_func(
   1013         decoder_layer.__call__,
   1014         hidden_states,
   (...)
   1020         cache_position,
   1021     )
   1022 else:
-> 1023     layer_outputs = decoder_layer(
   1024         hidden_states,
   1025         attention_mask=causal_mask,
   1026         position_ids=position_ids,
   1027         past_key_value=past_key_values,
   1028         output_attentions=output_attentions,
   1029         use_cache=use_cache,
   1030         cache_position=cache_position,
   1031     )
   1033 hidden_states = layer_outputs[0]
   1035 if use_cache:

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:763, in Qwen2DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    760 hidden_states = self.input_layernorm(hidden_states)
    762 # Self Attention
--> 763 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    764     hidden_states=hidden_states,
    765     attention_mask=attention_mask,
    766     position_ids=position_ids,
    767     past_key_value=past_key_value,
    768     output_attentions=output_attentions,
    769     use_cache=use_cache,
    770     cache_position=cache_position,
    771 )
    772 hidden_states = residual + hidden_states
    774 # Fully Connected

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:650, in Qwen2SdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    639     return super().forward(
    640         hidden_states=hidden_states,
    641         attention_mask=attention_mask,
   (...)
    645         use_cache=use_cache,
    646     )
    648 bsz, q_len, _ = hidden_states.size()
--> 650 query_states = self.q_proj(hidden_states)
    651 key_states = self.k_proj(hidden_states)
    652 value_states = self.v_proj(hidden_states)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/peft/tuners/lora/bnb.py:452, in Linear4bit.forward(self, x, *args, **kwargs)
    450     result = self.base_layer(x, *args, **kwargs)
    451 else:
--> 452     result = self.base_layer(x, *args, **kwargs)
    453     # As per Tim Dettmers, for 4bit, we need to defensively clone here.
    454     # The reason is that in some cases, an error can occur that backprop
    455     # does not work on a manipulated view. This issue may be solved with
    456     # newer PyTorch versions but this would need extensive testing to be
    457     # sure.
    458     result = result.clone()

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:256, in Linear4bit.forward(self, x)
    253     x = x.to(self.compute_dtype)
    255 bias = None if self.bias is None else self.bias.to(self.compute_dtype)
--> 256 out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
    258 out = out.to(inp_dtype)
    260 return out

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:577, in matmul_4bit(A, B, quant_state, out, bias)
    575         return out
    576 else:
--> 577     return MatMul4Bit.apply(A, B, out, bias, quant_state)

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, **kwargs)
    595 if not torch._C._are_functorch_transforms_active():
    596     # See NOTE: [functorch vjp and autograd interaction]
    597     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598     return super().apply(*args, **kwargs)  # type: ignore[misc]
    600 if not is_setup_ctx_defined:
    601     raise RuntimeError(
    602         "In order to use an autograd.Function with functorch transforms "
    603         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    604         "staticmethod. For more details, please see "
    605         "https://pytorch.org/docs/master/notes/extending.func.html"
    606     )

File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:516, in MatMul4Bit.forward(ctx, A, B, out, bias, quant_state)
    511         return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
    514 # 1. Dequantize
    515 # 2. MatmulnN
--> 516 output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
    518 # 3. Save state
    519 ctx.state = quant_state

RuntimeError: mat1 and mat2 shapes cannot be multiplied (19x1024 and 1x524288)
'''
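
For what it's worth: q_proj in Qwen1.5-0.5B is a 1024x1024 linear, and 1024*1024/2 = 524288, so the 1x524288 operand looks like the packed 4-bit uint8 buffer rather than the dequantized weight. My guess (unverified) is that the extra llm.cuda() call re-quantizes the already-quantized Params4bit, leaving a quant_state whose recorded shape is the packed one. A quick check, assuming the quant_state layout of bitsandbytes 0.42.0:

# assumption: Linear4bit keeps its packed weight as Params4bit with a QuantState
# on `.quant_state` (bitsandbytes 0.42.0); inspect one layer after llm.cuda()
for name, mod in llm.named_modules():
    if isinstance(mod, bnb.nn.Linear4bit):
        qs = getattr(mod.weight, 'quant_state', None)
        print(name, tuple(mod.weight.shape), None if qs is None else tuple(qs.shape))
        break  # one layer is enough to see the problem

If the quant_state shape prints as the packed (524288, 1) instead of (1024, 1024), the weights were quantized a second time and the llm.cuda() call is the culprit.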

Expected behavior

Generation with the 4-bit quantized, LoRA-wrapped model should run without a shape mismatch error.

AbhayUrmaliya2004 commented 3 weeks ago

Did you find any way to overcome this?
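
Not an authoritative fix, but a workaround commonly suggested for this class of error is to place the quantized model on the GPU at load time via device_map and drop the later llm.cuda() call, so the already-quantized 4-bit weights are never moved and potentially re-quantized. A minimal sketch under that assumption, reusing the model path from the report and replacing the user-defined chat helper with a plain generate call:

import torch
import bitsandbytes as bnb
import peft
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

base_path = '/data/Qwen1.5-0.5B-Chat/'  # path from the original report
tok = AutoTokenizer.from_pretrained(base_path, trust_remote_code=True)
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
llm = AutoModelForCausalLM.from_pretrained(
    base_path,
    trust_remote_code=True,
    quantization_config=nf4_config,
    device_map={"": 0},  # quantize straight onto GPU 0 instead of calling .cuda() afterwards
)
tgt_mods = {
    name.split('.')[-1]
    for name, mod in llm.named_modules()
    if isinstance(mod, bnb.nn.Linear4bit)
}
lora_config = peft.LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=tgt_mods,
    lora_dropout=0.1,
    task_type=peft.TaskType.CAUSAL_LM,
    inference_mode=False,
    bias="none",
)
llm = peft.get_peft_model(llm, lora_config)  # peft places the LoRA adapters next to the base layers
# note: no llm.cuda() here

iids = tok.apply_chat_template(
    [{'role': 'user', 'content': '你好'}],
    add_generation_prompt=True,
    return_tensors='pt',
).to(llm.device)
out = llm.generate(inputs=iids, max_length=4096)
print(tok.decode(out[0][iids.shape[1]:], skip_special_tokens=True))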