unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

Issue With Mistral Small #1044

Open · DaddyCodesAlot opened this issue 1 month ago

DaddyCodesAlot commented 1 month ago

Attempting to generate with Mistral Small causes this error:


```
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 77
      5 inputs = tokenizer(
      6 [
      7     alpaca_prompt.format(
   (...)
     73 
     74 ], return_tensors = "pt").to("cuda")
     76 start = time.time()
---> 77 outputs = model.generate(**inputs, max_new_tokens = 1024, use_cache = False)
     78 #tokenizer.batch_decode(outputs)
     79 end = time.time()

File /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py:1407, in _wrap_fast_inference.<locals>._fast_generate(*args, **kwargs)
   1400 # Set pad token
   1401 # old_pad_token_id = getattr(model.config, "pad_token_id", None)
   1402 # old_eos_token_id = getattr(model.config, "eos_token_id", None)
   1403 # model.config.pad_token_id = old_eos_token_id
   1404 
   1405 # Autocasted
   1406 with torch.autocast(device_type = device_type, dtype = dtype):
-> 1407     output = generate(*args, **kwargs)
   1408 pass
   1410 # Revert
   1411 # model.config.pad_token_id = old_pad_token_id
   1412 
   1413 # Unset a flag for generation!

File /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2016     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2017         input_ids=input_ids,
   2018         expand_size=generation_config.num_return_sequences,
   2019         is_encoder_decoder=self.config.is_encoder_decoder,
   2020         **model_kwargs,
   2021     )
   2023     # 13. run sample (it degenerates to greedy search when generation_config.do_sample=False)
-> 2024     result = self._sample(
   2025         input_ids,
   2026         logits_processor=prepared_logits_processor,
   2027         logits_warper=prepared_logits_warper,
   2028         stopping_criteria=prepared_stopping_criteria,
   2029         generation_config=generation_config,
   2030         synced_gpus=synced_gpus,
   2031         streamer=streamer,
   2032         **model_kwargs,
   2033     )
   2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2036     # 11. prepare logits warper
   2037     prepared_logits_warper = (
   2038         self._get_logits_warper(generation_config, device=input_ids.device)
   2039         if generation_config.do_sample
   2040         else None
   2041     )

File /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2982, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2979     model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   2981     # forward pass to get next token
-> 2982     outputs = self(**model_inputs, return_dict=True)
   2984     if synced_gpus and this_peer_finished:
   2985         continue  # don't waste resources running the code we don't need

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168     output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/unsloth/models/mistral.py:220, in MistralForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep, *args, **kwargs)
    212     outputs = LlamaModel_fast_forward_inference(
    213         self,
    214         input_ids,
    (...)
    217         attention_mask = attention_mask,
    218     )
    219 else:
--> 220     outputs = self.model(
    221         input_ids=input_ids,
    222         causal_mask=causal_mask,
    223         attention_mask=attention_mask,
    224         position_ids=position_ids,
    225         past_key_values=past_key_values,
    226         inputs_embeds=inputs_embeds,
    227         use_cache=use_cache,
    228         output_attentions=output_attentions,
    229         output_hidden_states=output_hidden_states,
    230         return_dict=return_dict,
    231     )
    232 pass
    234 hidden_states = outputs[0]

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168     output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py:801, in LlamaModel_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    798     hidden_states = layer_outputs[0]
    800 else:
--> 801     layer_outputs = decoder_layer(
    802         hidden_states,
    803         causal_mask=mask,
    804         attention_mask=attention_mask,
    805         position_ids=position_ids,
    806         past_key_value=past_key_value,
    807         output_attentions=output_attentions,
    808         use_cache=use_cache,
    809         padding_mask=padding_mask,
    810     )
    811     hidden_states = layer_outputs[0]
    812 pass

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168     output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py:486, in LlamaDecoderLayer_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, *args, **kwargs)
    484 residual = hidden_states
    485 hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
--> 486 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    487     hidden_states=hidden_states,
    488     causal_mask=causal_mask,
    489     attention_mask=attention_mask,
    490     position_ids=position_ids,
    491     past_key_value=past_key_value,
    492     output_attentions=output_attentions,
    493     use_cache=use_cache,
    494     padding_mask=padding_mask,
    495 )
    496 hidden_states = residual + hidden_states
    498 # Fully Connected

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168     output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/unsloth/models/mistral.py:90, in MistralAttention_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, *args, **kwargs)
     88 else:
     89     cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
---> 90     Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
     91 pass
     93 if past_key_value is not None:

File /usr/local/lib/python3.10/dist-packages/unsloth/kernels/rope_embedding.py:178, in inplace_rope_embedding(Q, K, cos, sin, position_ids)
    177 def inplace_rope_embedding(Q, K, cos, sin, position_ids):
--> 178     Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids)
    179     K = Slow_RoPE_Embedding.apply(K, cos, sin, position_ids)
    180     return Q, K

File /usr/local/lib/python3.10/dist-packages/torch/autograd/function.py:539, in Function.apply(cls, *args, **kwargs)
    536 if not torch._C._are_functorch_transforms_active():
    537     # See NOTE: [functorch vjp and autograd interaction]
    538     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 539     return super().apply(*args, **kwargs)  # type: ignore[misc]
    541 if cls.setup_context == _SingleLevelFunction.setup_context:
    542     raise RuntimeError(
    543         "In order to use an autograd.Function with functorch transforms "
    544         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    545         "staticmethod. For more details, please see "
    546         "https://pytorch.org/docs/master/notes/extending.func.html"
    547     )

File /usr/local/lib/python3.10/dist-packages/unsloth/kernels/rope_embedding.py:154, in Slow_RoPE_Embedding.forward(ctx, Q, cos, sin, position_ids)
    152 half = Q.shape[-1]//2
    153 RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
--> 154 Q *= cos
    155 Q.addcmul_(RH_Q, sin)
    156 # RH_Q *= sin
    157 # Q += RH_Q

RuntimeError: The size of tensor a (48) must match the size of tensor b (1460) at non-singleton dimension 1
```
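
For context, here is a minimal sketch of the kind of setup that reaches this code path. Only the tokenizer call and the `model.generate(**inputs, max_new_tokens = 1024, use_cache = False)` line are taken from the cell in the traceback; the checkpoint name, `max_seq_length`, 4-bit loading, and the placeholder prompt are my assumptions, and the `_wrap_fast_inference` frame suggests `FastLanguageModel.for_inference` was enabled:

```python
import time
from unsloth import FastLanguageModel

# Assumed checkpoint and settings -- the traceback only shows the generate() cell itself.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Mistral-Small-Instruct-2409",  # assumed Mistral Small checkpoint
    max_seq_length = 4096,                               # assumption
    load_in_4bit = True,                                 # assumption
)
FastLanguageModel.for_inference(model)  # matches the _wrap_fast_inference frame above

# Alpaca-style template as in the Unsloth notebooks; the instruction text is a placeholder.
alpaca_prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{}

### Response:
{}"""

inputs = tokenizer(
    [alpaca_prompt.format("Summarize the following text.", "")],
    return_tensors = "pt",
).to("cuda")

start = time.time()
# This is the call that raises the RuntimeError above.
outputs = model.generate(**inputs, max_new_tokens = 1024, use_cache = False)
end = time.time()
print(f"Generation took {end - start:.2f}s")
```

Reading the last frames: with `use_cache = False`, `MistralForCausalLM_fast_forward` takes the regular `self.model(...)` branch rather than `LlamaModel_fast_forward_inference`, so RoPE goes through `inplace_rope_embedding` / `Slow_RoPE_Embedding`. The failure is the in-place `Q *= cos`, where the 48 in the error appears to be dimension 1 of `Q` (the attention heads) and 1460 looks like the tokenized prompt length, i.e. `cos`/`sin` seem to arrive with a shape that no longer broadcasts against `Q` of shape `[batch, heads, seq_len, head_dim]` on this path.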

danielhanchen commented 1 month ago

Will investigate this! Sorry about the issue!