unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen 2.5 & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

Unexpected OOM When Using use_gradient_checkpointing = "unsloth" #338

Open ansz42 opened 7 months ago

ansz42 commented 7 months ago

Hi! I followed the conda installation instructions and am using a Jupyter notebook in WSL2. System: 32GB RAM, RTX 3090 (24GB), Ryzen 5 5600X.

Error message:

RuntimeError                              Traceback (most recent call last)
Cell In[8], line 1
----> 1 trainer_stats = trainer.train()

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:361, in SFTTrainer.train(self, *args, **kwargs)
    358 if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:
    359     self.model = self._trl_activate_neftune(self.model)
--> 361 output = super().train(*args, **kwargs)
    363 # After training we make sure to retrieve back the original forward pass method
    364 # for the embedding layer by removing the forward post hook.
    365 if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/transformers/trainer.py:1780, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1778         hf_hub_utils.enable_progress_bars()
   1779 else:
-> 1780     return inner_training_loop(
   1781         args=args,
   1782         resume_from_checkpoint=resume_from_checkpoint,
   1783         trial=trial,
   1784         ignore_keys_for_eval=ignore_keys_for_eval,
   1785     )

File <string>:355, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/transformers/trainer.py:3036, in Trainer.training_step(self, model, inputs)
   3033     return loss_mb.reduce_mean().detach().to(self.args.device)
   3035 with self.compute_loss_context_manager():
-> 3036     loss = self.compute_loss(model, inputs)
   3038 if self.args.n_gpu > 1:
   3039     loss = loss.mean()  # mean() to average on multi-gpu parallel training

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/transformers/trainer.py:3059, in Trainer.compute_loss(self, model, inputs, return_outputs)
   3057 else:
   3058     labels = None
-> 3059 outputs = model(**inputs)
   3060 # Save past state if it exists
   3061 # TODO: this needs to be fixed and made cleaner later.
   3062 if self.args.past_index >= 0:

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/accelerate/utils/operations.py:825, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    824 def forward(*args, **kwargs):
--> 825     return model_forward(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/accelerate/utils/operations.py:813, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    812 def __call__(self, *args, **kwargs):
--> 813     return convert_to_fp32(self.model_forward(*args, **kwargs))

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     13 @functools.wraps(func)
     14 def decorate_autocast(*args, **kwargs):
     15     with autocast_instance:
---> 16         return func(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/unsloth/models/llama.py:882, in PeftModelForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
    869 def PeftModelForCausalLM_fast_forward(
    870     self,
    871     input_ids=None,
   (...)
    880     **kwargs,
    881 ):
--> 882     return self.base_model(
    883         input_ids=input_ids,
    884         causal_mask=causal_mask,
    885         attention_mask=attention_mask,
    886         inputs_embeds=inputs_embeds,
    887         labels=labels,
    888         output_attentions=output_attentions,
    889         output_hidden_states=output_hidden_states,
    890         return_dict=return_dict,
    891         **kwargs,
    892     )

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:161, in BaseTuner.forward(self, *args, **kwargs)
    160 def forward(self, *args: Any, **kwargs: Any):
--> 161     return self.model.forward(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/unsloth/models/mistral.py:213, in MistralForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    205     outputs = LlamaModel_fast_forward_inference(
    206         self,
    207         input_ids,
   (...)
    210         attention_mask = attention_mask,
    211     )
    212 else:
--> 213     outputs = self.model(
    214         input_ids=input_ids,
    215         causal_mask=causal_mask,
    216         attention_mask=attention_mask,
    217         position_ids=position_ids,
    218         past_key_values=past_key_values,
    219         inputs_embeds=inputs_embeds,
    220         use_cache=use_cache,
    221         output_attentions=output_attentions,
    222         output_hidden_states=output_hidden_states,
    223         return_dict=return_dict,
    224     )
    225 pass
    227 hidden_states = outputs[0]

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/unsloth/models/llama.py:650, in LlamaModel_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    647 past_key_value = past_key_values[idx] if past_key_values is not None else None
    649 if offloaded_gradient_checkpointing:
--> 650     hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
    651         decoder_layer,
    652         hidden_states,
    653         causal_mask,
    654         attention_mask,
    655         position_ids,
    656         past_key_values,
    657         output_attentions,
    658         use_cache,
    659     )
    661 elif gradient_checkpointing:
    662     def create_custom_forward(module):

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/torch/autograd/function.py:553, in Function.apply(cls, *args, **kwargs)
    550 if not torch._C._are_functorch_transforms_active():
    551     # See NOTE: [functorch vjp and autograd interaction]
    552     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553     return super().apply(*args, **kwargs)  # type: ignore[misc]
    555 if not is_setup_ctx_defined:
    556     raise RuntimeError(
    557         "In order to use an autograd.Function with functorch transforms "
    558         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    559         "staticmethod. For more details, please see "
    560         "https://pytorch.org/docs/master/notes/extending.func.html"
    561     )

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py:115, in custom_fwd.<locals>.decorate_fwd(*args, **kwargs)
    113 if cast_inputs is None:
    114     args[0]._fwd_used_autocast = torch.is_autocast_enabled()
--> 115     return fwd(*args, **kwargs)
    116 else:
    117     autocast_context = torch.is_autocast_enabled()

File ~/anaconda3/envs/unsloth_env/lib/python3.10/site-packages/unsloth/models/_utils.py:331, in Unsloth_Offloaded_Gradient_Checkpointer.forward(ctx, forward_function, hidden_states, *args)
    328 @staticmethod
    329 @torch.cuda.amp.custom_fwd
    330 def forward(ctx, forward_function, hidden_states, *args):
--> 331     saved_hidden_states = hidden_states.to("cpu", non_blocking = True)
    332     with torch.no_grad():
    333         (output,) = forward_function(hidden_states, *args)

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

danielhanchen commented 7 months ago

@ansz42 Sorry for the delay! Interesting, so using our new method actually makes it OOM? Weird.

ansz42 commented 7 months ago

> @ansz42 Sorry on the delay! Interesting so using our new method rather makes it OOM? Weird

No worries at all! I appreciate your help.

It works well on Colab, but somehow WSL2 + Jupyter notebook causes an OOM. I suspect this might not be an ordinary OOM, though, because RAM, shared VRAM, and VRAM usage all plateau well below the available amount. It never even tries to go over 24GB before the OOM error. Let me know if you need me to run anything for troubleshooting.

danielhanchen commented 7 months ago

@ansz42 Apologies, Llama 3 got the better of me! Hmm, with WSL, yes, the shared RAM could be an issue. I'm unsure whether WSL randomly restricts VRAM usage or something.

noviljohnson commented 5 months ago

I got the same error:

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

I am following this Colab notebook: https://colab.research.google.com/drive/1_yNCks4BTD5zOnjozppphh5GzMFaMKq_?usp=sharing#scrollTo=2ejIt2xSNKKp

Current memory stats:

```
GPU = NVIDIA RTX 6000 Ada Generation. Max memory = 47.988 GB.
6.135 GB of memory reserved.
```

Parameters:

```
Unsloth - 2x faster free finetuning | Num GPUs = 1
  Num examples = 1,131 | Num Epochs = 1
  Batch size per device = 2 | Gradient Accumulation steps = 2
  Total batch size = 4 | Total steps = 283
  Number of trainable parameters = 167,772,160
```

I am using WSL + a Jupyter notebook in VS Code.
danielhanchen commented 5 months ago

@noviljohnson load_in_4bit = False?

noviljohnson commented 5 months ago

Thanks, I'll try.

But I resolved it by changing the parameter values to per_device_train_batch_size = 2, gradient_accumulation_steps = 8.

Thank you!
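The "Total batch size" and "Total steps" lines in the log above follow directly from the parameters. A minimal single-GPU sketch of that arithmetic (the helper names are hypothetical, and the floor division is an assumption about how the Hugging Face Trainer counted steps at the time, though it reproduces the logged 283) also shows that moving to gradient_accumulation_steps = 8 leaves the per-device micro-batch at 2. Each forward/backward pass still processes the same number of sequences; what changes is the effective batch size and the number of optimizer steps.

```python
import math


def total_batch_size(per_device_batch: int, grad_accum: int, num_gpus: int = 1) -> int:
    """Examples contributing to each optimizer step."""
    return per_device_batch * grad_accum * num_gpus


def steps_per_epoch(num_examples: int, per_device_batch: int, grad_accum: int) -> int:
    """Optimizer steps per epoch on a single GPU.

    Assumptions: the last partial micro-batch is kept (drop_last=False), and
    leftover micro-batches that don't fill a whole accumulation window are
    not counted as an extra step (floor division).
    """
    micro_batches = math.ceil(num_examples / per_device_batch)
    return max(micro_batches // grad_accum, 1)


if __name__ == "__main__":
    # Settings reported in the thread: 1,131 examples, per-device batch of 2
    for ga in (2, 8):
        print(f"GA={ga}: total batch={total_batch_size(2, ga)}, "
              f"steps/epoch={steps_per_epoch(1131, 2, ga)}")
```

With the reported settings this gives 283 steps for gradient_accumulation_steps = 2 (matching the log) and 70 steps for 8.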

m0nsky commented 5 months ago

Running into the same issue.

Training with Unsloth (via LLaMA-Factory) in WSL successfully spills over to system RAM when the CUDA Sysmem Fallback Policy is enabled, allowing me to train a 16k-context 4-bit QLoRA on a 10GB RTX 3080.

After enabling "use_gradient_checkpointing": "unsloth", it always OOMs. It even OOMs in odd scenarios: a context size of 2048 worked, but 2047 resulted in an OOM, even though enough VRAM was available.

Disabling use_gradient_checkpointing works, but I would love to use it.

Edit 1: Tried load_in_4bit = False; no difference (except for much higher memory usage, of course).

Edit 2: Fun fact: it is actually possible to train a 16k context with load_in_4bit = False using Unsloth, as long as "use_gradient_checkpointing": "unsloth" is disabled. Extremely slow compared to 4-bit, but it works!

Edit 3: The last frame in the stack trace points to this line (/unsloth/models/_utils.py, line 388):

```
saved_hidden_states = hidden_states.to("cpu", non_blocking = True)
```

Changing this to

```
saved_hidden_states = hidden_states.to("cpu", non_blocking = False)
```

This stops the OOM, and it's now training. What are the consequences?

@danielhanchen I hope you don't mind the tag here, but this must be related to the issues we're experiencing.

vladrad commented 5 months ago

Hey! I ran into this in WSL2 as well. I posted in the other thread https://github.com/unslothai/unsloth/issues/600#issuecomment-2181298507, but I think this is due to pinned memory in WSL + CUDA. When you don't use Unsloth's gradient checkpointing, I don't think it pins anything in memory (or at least nothing massive). With 95GB of RAM, for some reason WSL only allowed 210MB of pinned memory. To turn it off, you can just set use_gradient_checkpointing=True (which I guess uses the HF implementation).
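To put the 210MB figure in perspective: the checkpointer in the traceback copies each decoder layer's input `hidden_states` to the CPU, so the host-side footprint grows with layers × batch × sequence length × hidden size. My understanding, worth verifying against the PyTorch docs, is that a `non_blocking=True` CUDA-to-CPU copy allocates its destination in pinned (page-locked) memory while a blocking copy uses ordinary pageable memory, which would tie this observation to m0nsky's `non_blocking = False` workaround. Below is a back-of-the-envelope sketch, assuming 32 layers with hidden size 4096 (Mistral-7B / Llama-3-8B sized), 2-byte bf16/fp16 activations, and that every layer's copy stays alive until the backward pass reaches it; `ModelDims` and the helper functions are hypothetical names for illustration.

```python
from dataclasses import dataclass

MIB = 1024 ** 2


@dataclass
class ModelDims:
    num_layers: int
    hidden_size: int
    bytes_per_element: int = 2  # assumption: bf16/fp16 activations


def offloaded_bytes_per_layer(dims: ModelDims, batch_size: int, seq_len: int) -> int:
    """Size of one layer's input hidden_states tensor (batch x seq x hidden)."""
    return batch_size * seq_len * dims.hidden_size * dims.bytes_per_element


def total_offloaded_bytes(dims: ModelDims, batch_size: int, seq_len: int) -> int:
    """Host memory held if one copy per decoder layer is kept until backward."""
    return dims.num_layers * offloaded_bytes_per_layer(dims, batch_size, seq_len)


if __name__ == "__main__":
    dims = ModelDims(num_layers=32, hidden_size=4096)  # assumed model size
    for seq_len in (2048, 16384):
        per_layer = offloaded_bytes_per_layer(dims, 1, seq_len) / MIB
        total = total_offloaded_bytes(dims, 1, seq_len) / MIB
        print(f"seq_len={seq_len}: {per_layer:.0f} MiB per layer, {total:.0f} MiB total")
```

Under these assumptions a batch of one at 2048 tokens already needs about 512 MiB of host memory for the saved activations (16 MiB per layer), and 16k tokens needs about 4096 MiB, both well above the 210MB of pinned memory reported above.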

danielhanchen commented 5 months ago

@m0nsky Ooh, interesting, so non_blocking = False works?? Hmm, maybe I should make a new method called "unsloth-wsl" for WSL users that uses blocking calls. Sadly, you will get some slowdown, since the transfer of activations to system RAM will now block the GPU.
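The "unsloth-wsl" idea amounts to choosing the copy mode at runtime. A minimal sketch of one way to make that choice, assuming the common heuristic that WSL kernels include "microsoft" in their release string (e.g. `5.15.153.1-microsoft-standard-WSL2` on WSL2, `4.4.0-19041-Microsoft` on WSL1). `is_wsl` and `offload_non_blocking` are hypothetical helpers, not part of Unsloth's API.

```python
import platform
from typing import Optional


def is_wsl(kernel_release: Optional[str] = None) -> bool:
    """Heuristic WSL check: WSL kernels report 'Microsoft'/'microsoft' in their release."""
    release = platform.uname().release if kernel_release is None else kernel_release
    return "microsoft" in release.lower()


def offload_non_blocking(kernel_release: Optional[str] = None) -> bool:
    """Hypothetical policy: blocking CPU copies under WSL, asynchronous copies elsewhere."""
    return not is_wsl(kernel_release)


if __name__ == "__main__":
    # The result would feed the copy quoted earlier in the thread, e.g.
    # hidden_states.to("cpu", non_blocking=offload_non_blocking())
    print("non_blocking =", offload_non_blocking())
```

Making the choice automatic keeps the faster asynchronous path on native Linux while trading speed for reliability only where the pinned-memory limit has been observed.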

danielhanchen commented 5 months ago

@vladrad Oh yes, use_gradient_checkpointing = True uses the normal HF implementation.

m0nsky commented 5 months ago

> @m0nsky OOO interesting so non_blocking = False works?? Hmm maybe I should make a new method called "unsloth-wsl" for WSL people, to use blocking calls. You will get some slowdowns sadly, since now the transfer of activations to system RAM will block with the GPU

I tried it last night and it did seem a lot slower; I'm not sure it's going to be worth it. :(

vladrad commented 5 months ago

I guess an unsloth-wsl option would solve this issue. I am interested in seeing whether we can make it work with Unsloth's method. I'm curious what it's doing at that step; if you can elaborate, maybe I can poke around. I'm guessing it has to be the pinned memory, since I have enough VRAM (A6000 Ada, 12GB out of 48GB used). I'm assuming the other options are HF checkpointing or none at all. Is there an option to put it into regular RAM rather than pinned RAM? Is there a difference? Some people using Unsloth have it working, but I think it depends on the combination of model size and training data size, all sharing the same small memory space.

divakaivan commented 3 weeks ago

I wonder if there has been an update, or if the author or other participants have found a way around it? 🙏

Edit: For me, with use_gradient_checkpointing = "unsloth" I always get the same error as OP. The other options (True, False, unsloth-wsl) all result in a plain OOM. I am using WSL2 with an A6000 48GB and trying out meta-llama-3.1-8b-instruct-4bit with a sequence length of 16_384.