unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen 2.5 & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

Unsloth error with trl 0.11.4 #1231

Closed mohit-raghavendra closed 2 weeks ago

mohit-raghavendra commented 2 weeks ago

I have the following configuration for unsloth

trl                               0.11.4
transformers                      4.45.2
unsloth                           2024.10.7
unsloth_zoo                       2024.11.0

I am doing DPO with Llama 3.2 on a custom dataset with QLoRA (only ~83M trainable parameters), on an A40 with 45 GB VRAM. The dataset is a small instruction-following dataset with <1000 tokens. It was working fine, but after I updated unsloth it stopped working with the following error.

  File "/nethome/mraghavendra6/flash/miniforge3/envs/tuning/lib/python3.12/site-packages/unsloth/kernels/utils.py", line 412, in matmul_lora
    out += (X @ A.to(dtype)) @ (s * B.to(dtype))
           ~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~
RuntimeError: CUDA driver error: invalid argument
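
For context, a minimal sketch of this kind of unsloth + trl QLoRA DPO setup (the model name, dataset path, and hyperparameters below are illustrative, not the exact script):

    from unsloth import FastLanguageModel, PatchDPOTrainer
    from trl import DPOConfig, DPOTrainer
    from datasets import load_dataset

    PatchDPOTrainer()  # patch trl's DPOTrainer with unsloth's optimized paths

    # 4-bit base model + LoRA adapters (QLoRA); model name is illustrative
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/Llama-3.2-3B-Instruct",
        max_seq_length=1024,
        load_in_4bit=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        use_gradient_checkpointing="unsloth",
    )

    # custom preference dataset with prompt/chosen/rejected columns (path illustrative)
    dataset = load_dataset("json", data_files="dpo_pairs.json", split="train")

    trainer = DPOTrainer(
        model=model,
        ref_model=None,  # with a PEFT adapter, the base weights serve as the reference model
        args=DPOConfig(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            max_length=1024,
            max_prompt_length=512,
            output_dir="outputs",
        ),
        train_dataset=dataset,
        tokenizer=tokenizer,  # newer trl releases rename this to processing_class
    )
    trainer.train()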

I don't think this is caused by this. Can you please take a look at whether unsloth is supposed to work with the above configuration? Or, which trl + transformers + unsloth version combination is right?

Erland366 commented 2 weeks ago

[screenshot] Seems to be working fine for me on an RTX 4090.

I'll try to investigate it further.

Erland366 commented 2 weeks ago

https://discuss.pytorch.org/t/torch-prod-produces-runtimeerror-cuda-driver-error-invalid-argument/179054/24

It looks like this person had the same problem. Can you try their method of erasing the cache?
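
Something like this should do it (a sketch; the path is the default kernel-cache location that thread mentions):

    import shutil
    from pathlib import Path

    # torch caches compiled CUDA kernels here; the linked thread suggests the cache
    # can go stale after a driver/toolkit change, so deleting it forces a rebuild
    kernel_cache = Path.home() / ".cache" / "torch" / "kernels"
    if kernel_cache.exists():
        shutil.rmtree(kernel_cache)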

mohit-raghavendra commented 2 weeks ago

I don't have a directory named ~/.cache/torch/kernels/.

Is there any more information I can provide to make this easier to debug?

Erland366 commented 2 weeks ago

I feel like it might be a CUDA driver error? Can you try a fresh environment installation?

mohit-raghavendra commented 2 weeks ago

I tried with a fresh conda env and installed unsloth exactly as described in the installation instructions:

conda create --name unsloth_env \
    python=3.11 \
    pytorch-cuda=12.1 \
    pytorch cudatoolkit xformers -c pytorch -c nvidia -c xformers \
    -y
conda activate unsloth_env

pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps trl peft accelerate bitsandbytes

Ran it again -

  File "/nethome/mraghavendra6/flash/miniforge3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/kernels/uti
ls.py", line 412, in matmul_lora                                                                                 
    out += (X @ A.to(dtype)) @ (s * B.to(dtype))                                                                 
           ~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~                                                                 
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 216.00 MiB. GPU 0 has a total capacity of 44.38 GiB of which 34.62 MiB is free. Including non-PyTorch memory, this process has 44.35 GiB memory in use. Of the allocated memory 43.95 GiB is allocated by PyTorch, and 77.80 MiB is reserved by PyTorch but unallocated. If reserved 
but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variab
les)

This is after updating trl to 0.12.0

transformers                      4.46.1
trl                               0.12.0

I am able to run SFT without any issues
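
(As an aside, the allocator setting suggested in the OOM message can be enabled like this; a minimal sketch, set before torch initializes CUDA:)

    import os

    # From the OOM message: expandable segments can reduce fragmentation.
    # Must be set before the first CUDA allocation, so put it at the top of the script.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    import torch  # imported after setting the env var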

danielhanchen commented 2 weeks ago

That's saying you're going out of memory. Try reducing your batch size or sequence length.
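
For example, with trl's DPOConfig the relevant knobs look roughly like this (values are illustrative, not a recommendation):

    from trl import DPOConfig

    # Illustrative values: smaller batch, more accumulation, shorter sequences
    args = DPOConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        max_length=512,          # cap on prompt + completion length
        max_prompt_length=256,
        output_dir="outputs",
    )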

mohit-raghavendra commented 2 weeks ago

I am not sure it is purely a memory issue, because the exact same code with the exact same data used to work, and even after reducing the batch size to 1 it still gives an error.

Managed to get it working by explicitly clearing the torch cache, calling this right before train():

    import gc, torch  # imports needed for the calls below
    torch.cuda.empty_cache()  # release unoccupied memory held by PyTorch's caching allocator
    gc.collect()              # collect unreferenced Python objects
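
In context the calls sit immediately before the training call, roughly like this (trainer stands for the DPOTrainer instance built earlier):

    import gc
    import torch

    torch.cuda.empty_cache()  # clear cached CUDA blocks left over from setup
    gc.collect()              # drop unreferenced Python objects
    trainer.train()           # then start DPO training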