unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

OOM when accessing model().loss #365

Closed: ant-pls-dev closed this issue 6 months ago

ant-pls-dev commented 6 months ago

Hello,

I would like to access the model's loss on an RTX 3050, for example to compute perplexity. Regular inference works great, but accessing `model().loss` triggers an OOM:

```python
from unsloth import FastLanguageModel
import torch

MAX_SEQ_LEN = 8192

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit",
    max_seq_length = MAX_SEQ_LEN,
    dtype = None,
    load_in_4bit = True,
)

FastLanguageModel.for_inference(model)

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

a=model(input_ids).loss
```

Result:

```
Traceback (most recent call last):
  File "/workspace/llama3.py", line 38, in <module>
    a=model(input_ids).loss
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/unsloth/models/llama.py", line 813, in _CausalLM_fast_forward
    outputs = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/unsloth/models/llama.py", line 680, in LlamaModel_fast_forward
    layer_outputs = decoder_layer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/unsloth/models/llama.py", line 413, in LlamaDecoderLayer_fast_forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/unsloth/models/llama.py", line 308, in LlamaAttention_fast_forward
    Q, K, V = self.apply_qkv(self, hidden_states)
  File "/opt/conda/lib/python3.10/site-packages/unsloth/models/llama.py", line 63, in original_apply_qkv
    Q = self.q_proj(X)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 468, in forward
    out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
  File "/opt/conda/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 579, in matmul_4bit
    return MatMul4Bit.apply(A, B, out, bias, quant_state)
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 509, in forward
    output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
  File "/opt/conda/lib/python3.10/site-packages/bitsandbytes/functional.py", line 1349, in dequantize_4bit
    out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 0 has a total capacity of 7.78 GiB of which 20.44 MiB is free. Process 17600 has 7.57 GiB memory in use. Of the allocated memory 7.45 GiB is allocated by PyTorch, and 17.27 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
```

I get the same result with `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`.
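For reference, the allocator setting can also be applied from Python, as long as this happens before PyTorch initializes CUDA. The minimal sketch below sets it before `import torch`. Keep in mind that it only mitigates fragmentation: the OOM message reports just 17.27 MiB reserved-but-unallocated against 7.45 GiB actually allocated by PyTorch, so no improvement should be expected in this case.

```python
import os

# The CUDA caching allocator reads this setting when it initializes, so set it
# before torch is imported (or at least before the first CUDA allocation).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # noqa: E402  (deliberately imported after configuring the allocator)

print(os.environ["PYTORCH_CUDA_ALLOC_CONF"], torch.__version__)
```

Exporting the variable in the shell before launching the script is equivalent.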

I don't know whether this is a bug, an unsupported use case, or simply not enough memory on a 3050 for this use case, nor how to work around it.

Thank you

danielhanchen commented 6 months ago

Maybe wrapping the call in `torch.no_grad()` will prevent the OOM :)
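A likely explanation for why this helps: outside `torch.no_grad()`, autograd records every differentiable operation in the forward pass and may keep intermediate tensors alive for a later backward pass. Inside `torch.no_grad()`, no graph is built, so those tensors can be freed as soon as each layer is done. Note also that the original repro called `model(input_ids)` without `labels`. Hugging Face-style causal LMs return `loss = None` in that case, so the labels have to be passed explicitly. The sketch below is illustrative only. `loss_without_grad` is a hypothetical helper, not an Unsloth API, and a toy `torch.nn.Linear` stands in for the real model to show the mechanism.

```python
import torch


@torch.no_grad()
def loss_without_grad(model, input_ids):
    """Hypothetical helper: forward pass with labels, no autograd graph.

    Assumes a Hugging Face-style causal LM that returns an object with a
    `.loss` attribute when `labels` is given (it is None otherwise).
    """
    outputs = model(input_ids, labels=input_ids)
    return outputs.loss.item()


# Toy demonstration of the mechanism with a plain linear layer.
layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16)

y = layer(x)  # default mode: autograd records the op (grad_fn is set)
print(y.requires_grad, y.grad_fn is not None)  # True True

with torch.no_grad():
    z = layer(x)  # no graph is recorded, nothing is kept for backward
print(z.requires_grad, z.grad_fn is None)  # False True
```

With the real model, `loss_without_grad(model, input_ids)` would replace the bare `model(input_ids).loss` call from the original snippet.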

ant-pls-dev commented 6 months ago
```python
with torch.no_grad():
    a = model(input_ids, labels=input_ids).loss
```

It just works!
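Since the goal was perplexity: with `labels=input_ids`, Hugging Face-style causal LMs shift the labels by one position internally and return the mean next-token cross-entropy, L = -(1/N) Σ log p(x_{t+1} | x_≤t). Perplexity is then simply exp(L). The sketch below assumes Unsloth's patched forward follows the same convention. It reproduces the computation from raw logits using random tensors rather than a real model, and checks the result against a by-hand log-probability sum. `perplexity_from_logits` is a hypothetical name, and the vocabulary size and sequence length are arbitrary.

```python
import math

import torch
import torch.nn.functional as F


def perplexity_from_logits(logits, input_ids):
    """Mean next-token cross-entropy and its exponential (perplexity).

    logits:    (batch, seq_len, vocab) raw model outputs
    input_ids: (batch, seq_len) token ids, also used as labels
    """
    # Position t predicts token t + 1: drop the last logit and the first label.
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )
    return loss.item(), math.exp(loss.item())


torch.manual_seed(0)
vocab, seq_len = 32, 10
logits = torch.randn(1, seq_len, vocab)
input_ids = torch.randint(0, vocab, (1, seq_len))

loss, ppl = perplexity_from_logits(logits, input_ids)

# Same quantity by hand: negative mean log-probability of each next token.
log_probs = logits[:, :-1].log_softmax(-1)
token_lp = log_probs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
manual = -token_lp.mean().item()
print(loss, manual, ppl)
```

With the real model, the equivalent is `math.exp(a.item())`, where `a` is the loss returned under `torch.no_grad()`.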

For the record, this is my setup:

```
==((====))==  Unsloth: Fast Llama patching release 2024.4
   \\   /|    GPU: NVIDIA GeForce RTX 3050. Max memory: 7.777 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = True.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
```
pip freeze:

```
accelerate==0.29.3
aiohttp==3.9.5
aiosignal==1.3.1
archspec @ file:///croot/archspec_1709217642129/work
asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work
astunparse==1.6.3
async-timeout==4.0.3
attrs @ file:///croot/attrs_1695717823297/work
beautifulsoup4 @ file:///croot/beautifulsoup4-split_1681493039619/work
bitsandbytes==0.43.1
boltons @ file:///croot/boltons_1677628692245/work
Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work
certifi @ file:///croot/certifi_1707229174982/work/certifi
cffi @ file:///croot/cffi_1700254295673/work
chardet @ file:///home/builder/ci_310/chardet_1640804867535/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
click @ file:///croot/click_1698129812380/work
conda @ file:///croot/conda_1696257509808/work
conda-build @ file:///croot/conda-build_1710789183177/work
conda-content-trust @ file:///croot/conda-content-trust_1693490622020/work
conda-libmamba-solver @ file:///croot/conda-libmamba-solver_1691418897561/work/src
conda-package-handling @ file:///croot/conda-package-handling_1690999929514/work
conda_index @ file:///croot/conda-index_1706633791028/work
conda_package_streaming @ file:///croot/conda-package-streaming_1690987966409/work
cryptography @ file:///croot/cryptography_1710350347627/work
datasets==2.19.0
decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
dill==0.3.8
distro @ file:///croot/distro_1701455004953/work
dnspython==2.6.1
docstring_parser==0.16
einops==0.7.0
exceptiongroup @ file:///croot/exceptiongroup_1706031385326/work
executing @ file:///opt/conda/conda-bld/executing_1646925071911/work
expecttest==0.2.1
filelock @ file:///croot/filelock_1700591183607/work
flash-attn==2.5.7
frozenlist==1.4.1
fsspec==2024.3.1
gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455533097/work
huggingface-hub==0.22.2
hypothesis==6.99.13
idna @ file:///croot/idna_1666125576474/work
ipython @ file:///croot/ipython_1704833016303/work
jedi @ file:///tmp/build/80754af9/jedi_1644315229345/work
Jinja2 @ file:///croot/jinja2_1706733616596/work
jsonpatch @ file:///tmp/build/80754af9/jsonpatch_1615747632069/work
jsonpointer==2.1
jsonschema @ file:///croot/jsonschema_1699041609003/work
jsonschema-specifications @ file:///croot/jsonschema-specifications_1699032386549/work
libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work
libmambapy @ file:///croot/mamba-split_1698782620632/work/libmambapy
markdown-it-py==3.0.0
MarkupSafe @ file:///croot/markupsafe_1704205993651/work
matplotlib-inline @ file:///opt/conda/conda-bld/matplotlib-inline_1662014470464/work
mdurl==0.1.2
menuinst @ file:///croot/menuinst_1706732933928/work
mkl-fft @ file:///croot/mkl_fft_1695058164594/work
mkl-random @ file:///croot/mkl_random_1695059800811/work
mkl-service==2.4.0
more-itertools @ file:///croot/more-itertools_1700662129964/work
mpmath @ file:///croot/mpmath_1690848262763/work
multidict==6.0.5
multiprocess==0.70.16
networkx @ file:///croot/networkx_1690561992265/work
ninja==1.11.1.1
numpy @ file:///croot/numpy_and_numpy_base_1708638617955/work/dist/numpy-1.26.4-cp310-cp310-linux_x86_64.whl#sha256=d8cd837ed43e87f77e6efaa08e8de927ca030a1c9c5d04624432d6fb9a74a5ee
optree==0.11.0
packaging @ file:///croot/packaging_1710807400464/work
pandas==2.2.2
parso @ file:///opt/conda/conda-bld/parso_1641458642106/work
peft==0.10.0
pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
pillow @ file:///croot/pillow_1707233021655/work
pkginfo @ file:///croot/pkginfo_1679431160147/work
platformdirs @ file:///croot/platformdirs_1692205439124/work
pluggy @ file:///tmp/build/80754af9/pluggy_1648024709248/work
prompt-toolkit @ file:///croot/prompt-toolkit_1704404351921/work
protobuf==3.20.3
psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
pyarrow==16.0.0
pyarrow-hotfix==0.6
pycosat @ file:///croot/pycosat_1696536503704/work
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
Pygments @ file:///croot/pygments_1684279966437/work
pyOpenSSL @ file:///croot/pyopenssl_1708380408460/work
PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work
python-dateutil==2.9.0.post0
python-etcd==0.4.5
pytz @ file:///croot/pytz_1695131579487/work
PyYAML @ file:///croot/pyyaml_1698096049011/work
referencing @ file:///croot/referencing_1699012038513/work
regex==2024.4.16
requests @ file:///croot/requests_1707355572290/work
rich==13.7.1
rpds-py @ file:///croot/rpds-py_1698945930462/work
ruamel.yaml @ file:///croot/ruamel.yaml_1666304550667/work
ruamel.yaml.clib @ file:///croot/ruamel.yaml.clib_1666302247304/work
safetensors==0.4.3
sentencepiece==0.2.0
shtab==1.7.1
six @ file:///tmp/build/80754af9/six_1644875935023/work
sortedcontainers==2.4.0
soupsieve @ file:///croot/soupsieve_1696347547217/work
stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
sympy @ file:///croot/sympy_1701397643339/work
tokenizers==0.19.1
tomli @ file:///opt/conda/conda-bld/tomli_1657175507142/work
toolz @ file:///croot/toolz_1667464077321/work
torch==2.2.2
torchaudio==2.2.2
torchelastic==0.2.2
torchvision==0.17.2
tqdm @ file:///croot/tqdm_1679561862951/work
traitlets @ file:///croot/traitlets_1671143879854/work
transformers==4.40.0
triton==2.2.0
trl==0.8.5
truststore @ file:///croot/truststore_1695244293384/work
types-dataclasses==0.6.6
typing_extensions==4.10.0
tyro==0.8.3
tzdata==2024.1
unsloth @ git+https://github.com/unslothai/unsloth.git@ec19e61c854dcf9104386fa63fc6c4f2944d4f35
urllib3 @ file:///croot/urllib3_1707770551213/work
wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
xformers==0.0.25.post1
xxhash==3.4.1
yarl==1.9.4
zstandard @ file:///croot/zstandard_1677013143055/work
```

Thank you for this amazing library and for the support!