unslothai / unsloth

Finetune Llama 3, Mistral, Phi & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0
12.85k stars 837 forks source link

Error message when using ORPO fine-tuning #601

Open MRQJsfhf opened 1 month ago

MRQJsfhf commented 1 month ago

When using ORPO to fine-tune mistral-7b-instruct-v0.3-bnb-4bit, calling `orpo_trainer.train()` to start training raises the following error:

`-------------------------------------------------------------------------- NotImplementedError Traceback (most recent call last) Cell In[15], line 1 ----> 1 orpo_trainer.train()

File /usr/local/lib/python3.10/site-packages/transformers/trainer.py:1885, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs) 1883 hf_hub_utils.enable_progress_bars() 1884 else: -> 1885 return inner_training_loop( 1886 args=args, 1887 resume_from_checkpoint=resume_from_checkpoint, 1888 trial=trial, 1889 ignore_keys_for_eval=ignore_keys_for_eval, 1890 )

File :348, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File /usr/local/lib/python3.10/site-packages/transformers/trainer.py:3238, in Trainer.training_step(self, model, inputs) 3235 return loss_mb.reduce_mean().detach().to(self.args.device) 3237 with self.compute_loss_context_manager(): -> 3238 loss = self.compute_loss(model, inputs) 3240 del inputs 3241 torch.cuda.empty_cache()

File /usr/local/lib/python3.10/site-packages/trl/trainer/orpo_trainer.py:786, in ORPOTrainer.compute_loss(self, model, inputs, return_outputs) 783 compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext 785 with compute_loss_context_manager(): --> 786 loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") 788 # force log the metrics 789 self.store_metrics(metrics, train_eval="train")

File /usr/local/lib/python3.10/site-packages/trl/trainer/orpo_trainer.py:746, in ORPOTrainer.get_batch_loss_metrics(self, model, batch, train_eval) 737 """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" 738 metrics = {} 740 ( 741 policy_chosen_logps, 742 policy_rejected_logps, 743 policy_chosen_logits, 744 policy_rejected_logits, 745 policy_nll_loss, --> 746 ) = self.concatenated_forward(model, batch) 748 losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( 749 policy_chosen_logps, policy_rejected_logps 750 ) 751 # full ORPO loss

File /usr/local/lib/python3.10/site-packages/trl/trainer/orpo_trainer.py:686, in ORPOTrainer.concatenated_forward(self, model, batch) 676 len_chosen = batch["chosen_labels"].shape[0] 678 model_kwargs = ( 679 { 680 "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), (...) 683 else {} 684 ) --> 686 outputs = model( 687 concatenated_batch["concatenated_input_ids"], 688 attention_mask=concatenated_batch["concatenated_attention_mask"], 689 use_cache=False, 690 **model_kwargs, 691 ) 692 all_logits = outputs.logits 694 def cross_entropy_loss(logits, labels):

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, kwargs) 1509 return self._compiled_call_impl(*args, *kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(args, kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, *kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(args, **kwargs) 1522 try: 1523 result = None

File /usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py:822, in convert_outputs_to_fp32..forward(*args, kwargs) 821 def forward(*args, *kwargs): --> 822 return model_forward(args, kwargs)

File /usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py:810, in ConvertOutputsToFp32.call(self, *args, kwargs) 809 def call(self, *args, *kwargs): --> 810 return convert_to_fp32(self.model_forward(args, kwargs))

File /usr/local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator..decorate_autocast(*args, kwargs) 13 @functools.wraps(func) 14 def decorate_autocast(*args, *kwargs): 15 with autocast_instance: ---> 16 return func(args, kwargs)

File /usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py:822, in convert_outputs_to_fp32..forward(*args, kwargs) 821 def forward(*args, *kwargs): --> 822 return model_forward(args, kwargs)

File /usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py:810, in ConvertOutputsToFp32.call(self, *args, kwargs) 809 def call(self, *args, *kwargs): --> 810 return convert_to_fp32(self.model_forward(args, kwargs))

File /usr/local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator..decorate_autocast(*args, kwargs) 13 @functools.wraps(func) 14 def decorate_autocast(*args, *kwargs): 15 with autocast_instance: ---> 16 return func(args, kwargs)

File /usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py:822, in convert_outputs_to_fp32..forward(*args, kwargs) 821 def forward(*args, *kwargs): --> 822 return model_forward(args, kwargs)

File /usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py:810, in ConvertOutputsToFp32.call(self, *args, kwargs) 809 def call(self, *args, *kwargs): --> 810 return convert_to_fp32(self.model_forward(args, kwargs))

File /usr/local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator..decorate_autocast(*args, kwargs) 13 @functools.wraps(func) 14 def decorate_autocast(*args, *kwargs): 15 with autocast_instance: ---> 16 return func(args, kwargs)

File /usr/local/lib/python3.10/site-packages/unsloth/models/llama.py:883, in PeftModelForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, kwargs) 870 def PeftModelForCausalLM_fast_forward( 871 self, 872 input_ids=None, (...) 881 kwargs, 882 ): --> 883 return self.base_model( 884 input_ids=input_ids, 885 causal_mask=causal_mask, 886 attention_mask=attention_mask, 887 inputs_embeds=inputs_embeds, 888 labels=labels, 889 output_attentions=output_attentions, 890 output_hidden_states=output_hidden_states, 891 return_dict=return_dict, 892 **kwargs, 893 )

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, kwargs) 1509 return self._compiled_call_impl(*args, *kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(args, kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, *kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(args, **kwargs) 1522 try: 1523 result = None

File /usr/local/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:179, in BaseTuner.forward(self, *args, kwargs) 178 def forward(self, *args: Any, *kwargs: Any): --> 179 return self.model.forward(args, kwargs)

File /usr/local/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module..new_forward(module, *args, kwargs) 164 output = module._old_forward(*args, *kwargs) 165 else: --> 166 output = module._old_forward(args, kwargs) 167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/site-packages/unsloth/models/mistral.py:213, in MistralForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs) 205 outputs = LlamaModel_fast_forward_inference( 206 self, 207 input_ids, (...) 210 attention_mask = attention_mask, 211 ) 212 else: --> 213 outputs = self.model( 214 input_ids=input_ids, 215 causal_mask=causal_mask, 216 attention_mask=attention_mask, 217 position_ids=position_ids, 218 past_key_values=past_key_values, 219 inputs_embeds=inputs_embeds, 220 use_cache=use_cache, 221 output_attentions=output_attentions, 222 output_hidden_states=output_hidden_states, 223 return_dict=return_dict, 224 ) 225 pass 227 hidden_states = outputs[0]

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, kwargs) 1509 return self._compiled_call_impl(*args, *kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(args, kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, *kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(args, **kwargs) 1522 try: 1523 result = None

File /usr/local/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module..new_forward(module, *args, kwargs) 164 output = module._old_forward(*args, *kwargs) 165 else: --> 166 output = module._old_forward(args, kwargs) 167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/site-packages/unsloth/models/llama.py:651, in LlamaModel_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs) 648 past_key_value = past_key_values[idx] if past_key_values is not None else None 650 if offloaded_gradient_checkpointing: --> 651 hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply( 652 decoder_layer, 653 hidden_states, 654 causal_mask, 655 attention_mask, 656 position_ids, 657 past_key_values, 658 output_attentions, 659 use_cache, 660 )[0] 662 elif gradient_checkpointing: 663 def create_custom_forward(module):

File /usr/local/lib/python3.10/site-packages/torch/autograd/function.py:553, in Function.apply(cls, *args, *kwargs) 550 if not torch._C._are_functorch_transforms_active(): 551 # See NOTE: [functorch vjp and autograd interaction] 552 args = _functorch.utils.unwrap_dead_wrappers(args) --> 553 return super().apply(args, **kwargs) # type: ignore[misc] 555 if not is_setup_ctx_defined: 556 raise RuntimeError( 557 "In order to use an autograd.Function with functorch transforms " 558 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context " 559 "staticmethod. For more details, please see " 560 "https://pytorch.org/docs/master/notes/extending.func.html" 561 )

File /usr/local/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py:115, in custom_fwd..decorate_fwd(*args, *kwargs) 113 if cast_inputs is None: 114 args[0]._fwd_used_autocast = torch.is_autocast_enabled() --> 115 return fwd(args, **kwargs) 116 else: 117 autocast_context = torch.is_autocast_enabled()

File /usr/local/lib/python3.10/site-packages/unsloth/models/_utils.py:385, in Unsloth_Offloaded_Gradient_Checkpointer.forward(ctx, forward_function, hidden_states, args) 383 saved_hidden_states = hidden_states.to("cpu", non_blocking = True) 384 with torch.no_grad(): --> 385 output = forward_function(hidden_states, args) 386 ctx.save_for_backward(saved_hidden_states) 387 ctx.forward_function = forward_function

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, kwargs) 1509 return self._compiled_call_impl(*args, *kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(args, kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, *kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(args, **kwargs) 1522 try: 1523 result = None

File /usr/local/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module..new_forward(module, *args, kwargs) 164 output = module._old_forward(*args, *kwargs) 165 else: --> 166 output = module._old_forward(args, kwargs) 167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/site-packages/unsloth/models/llama.py:434, in LlamaDecoderLayer_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, *args, **kwargs) 432 residual = hidden_states 433 hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) --> 434 hidden_states, self_attn_weights, present_key_value = self.self_attn( 435 hidden_states=hidden_states, 436 causal_mask=causal_mask, 437 attention_mask=attention_mask, 438 position_ids=position_ids, 439 past_key_value=past_key_value, 440 output_attentions=output_attentions, 441 use_cache=use_cache, 442 padding_mask=padding_mask, 443 ) 444 hidden_states = residual + hidden_states 446 # Fully Connected

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, kwargs) 1509 return self._compiled_call_impl(*args, *kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(args, kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, *kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(args, **kwargs) 1522 try: 1523 result = None

File /usr/local/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module..new_forward(module, *args, kwargs) 164 output = module._old_forward(*args, *kwargs) 165 else: --> 166 output = module._old_forward(args, kwargs) 167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/site-packages/unsloth/models/mistral.py:129, in MistralAttention_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, *args, **kwargs) 126 pass 127 pass --> 129 A = xformers_attention(Q, K, V, attn_bias = causal_mask) 130 A = A.view(bsz, q_len, n_heads, head_dim) 132 elif HAS_FLASH_ATTENTION and attention_mask is None:

File /usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:268, in memory_efficient_attention(query, key, value, attn_bias, p, scale, op, output_dtype) 156 def memory_efficient_attention( 157 query: torch.Tensor, 158 key: torch.Tensor, (...) 165 output_dtype: Optional[torch.dtype] = None, 166 ) -> torch.Tensor: 167 """Implements the memory-efficient attention mechanism following 168 "Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>_. 169 (...) 266 :return: multi-head attention Tensor with shape [B, Mq, H, Kv] 267 """ --> 268 return _memory_efficient_attention( 269 Inputs( 270 query=query, 271 key=key, 272 value=value, 273 p=p, 274 attn_bias=attn_bias, 275 scale=scale, 276 output_dtype=output_dtype, 277 ), 278 op=op, 279 )

File /usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:387, in _memory_efficient_attention(inp, op) 382 def _memory_efficient_attention( 383 inp: Inputs, op: Optional[AttentionOp] = None 384 ) -> torch.Tensor: 385 # fast-path that doesn't require computing the logsumexp for backward computation 386 if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]): --> 387 return _memory_efficient_attention_forward( 388 inp, op=op[0] if op is not None else None 389 ) 391 output_shape = inp.normalize_bmhk() 392 return _fMHA.apply( 393 op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale 394 ).reshape(output_shape)

File /usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:403, in _memory_efficient_attention_forward(inp, op) 401 output_shape = inp.normalize_bmhk() 402 if op is None: --> 403 op = _dispatch_fw(inp, False) 404 else: 405 _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)

File /usr/local/lib/python3.10/site-packages/xformers/ops/fmha/dispatch.py:125, in _dispatch_fw(inp, needs_gradient) 116 def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]: 117 """Computes the best operator for forward 118 119 Raises: (...) 123 AttentionOp: The best operator for the configuration 124 """ --> 125 return _run_priority_list( 126 "memory_efficient_attention_forward", 127 _dispatch_fw_priority_list(inp, needs_gradient), 128 inp, 129 )

File /usr/local/lib/python3.10/site-packages/xformers/ops/fmha/dispatch.py:65, in _run_priority_list(name, priority_list, inp) 63 for op, not_supported in zip(priority_list, not_supported_reasons): 64 msg += "\n" + _format_not_supported_reasons(op, not_supported) ---> 65 raise NotImplementedError(msg)

NotImplementedError: No operator found for memory_efficient_attention_forward with inputs:
    query : shape=(8, 569, 8, 4, 128) (torch.bfloat16)
    key : shape=(8, 569, 8, 4, 128) (torch.bfloat16)
    value : shape=(8, 569, 8, 4, 128) (torch.bfloat16)
    attn_bias : <class 'xformers.ops.fmha.attn_bias.LowerTriangularMask'>
    p : 0.0
flshattF@0.0.0 is not supported because:
    xFormers wasn't build with CUDA support
    operator wasn't built - see python -m xformers.info for more info
cutlassF is not supported because:
    xFormers wasn't build with CUDA support
    operator wasn't built - see python -m xformers.info for more info
smallkF is not supported because:
    max(query.shape[-1] != value.shape[-1]) > 32
    xFormers wasn't build with CUDA support
    dtype=torch.bfloat16 (supported: {torch.float32})
    attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMask'>
    operator wasn't built - see python -m xformers.info for more info
    operator does not support BMGHK format
    unsupported embed per head: 128`

danielhanchen commented 1 month ago

Oh, you need to update xformers! Run `pip install --upgrade "xformers<0.0.26"` for torch 2.2 or lower, or `pip install --upgrade xformers` for torch 2.3 and above. If that does not work, try:

!pip install -U xformers --index-url https://download.pytorch.org/whl/cu121
!pip install "unsloth[kaggle-new] @ git+https://github.com/unslothai/unsloth.git"