Lightning-AI / litgpt

Pretrain, finetune, deploy 20+ LLMs on your own data. Uses state-of-the-art techniques: flash attention, FSDP, 4-bit, LoRA, and more.
https://lightning.ai
Apache License 2.0

OOM Error: CUDA out of memory when finetuning llama3-8b #1358

Closed: zhaosheng-thu closed this issue 1 week ago.

zhaosheng-thu commented 1 week ago

When I finetune Llama 3 8B with `finetune/lora.py`, I get a CUDA out-of-memory (OOM) error. My training and dataset parameters:

The parameters and the config:

```bash
  --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B \
  --precision 'bf16-true' \
  --train.global_batch_size 8 \
  --train.max_seq_length 2048 \
  --data JSON \
  --data.prompt_style 'llama3' \
  --data.json_path /root/szhao/ES-Lora/litllama/ExTES/ExTES.json \
  --data.val_split_fraction 0.1 \
  --data.mask_prompt True \
  --out_dir out/llama3-esconv-test
```

In this command, the `llama3` prompt style is one I defined myself. Running it produced the following error:

The error shown in the terminal:

```
{'checkpoint_dir': PosixPath('checkpoints/meta-llama/Meta-Llama-3-8B'), 'data': JSON(json_path=PosixPath('/root/szhao/ES-Lora/litllama/ExTES/ExTES.json'), mask_prompt=True, val_split_fraction=0.1, prompt_style=, ignore_index=-100, seed=42, num_workers=4), 'devices': 1, 'eval': EvalArgs(interval=100, max_new_tokens=100, max_iters=100), 'logger_name': 'csv', 'lora_alpha': 16, 'lora_dropout': 0.05, 'lora_head': False, 'lora_key': False, 'lora_mlp': False, 'lora_projection': False, 'lora_query': True, 'lora_r': 8, 'lora_value': True, 'out_dir': PosixPath('out/llama3-esconv-test'), 'precision': 'bf16-true', 'quantize': None, 'seed': 1337, 'train': TrainArgs(save_interval=1000, log_interval=1, global_batch_size=8, micro_batch_size=1, lr_warmup_steps=100, lr_warmup_fraction=None, epochs=5, max_tokens=None, max_steps=None, max_seq_length=2048, tie_embeddings=None, learning_rate=0.0003, weight_decay=0.02, beta1=0.9, beta2=0.95, max_norm=None, min_lr=6e-05)}
Seed set to 1337
Number of trainable parameters: 3,407,872
Number of non-trainable parameters: 8,030,261,248
The longest sequence length in the train data is 1923, the model's maximum sequence length is 1923 and context length is 8192
Validating ...
Epoch 1 | iter 1 step 0 | loss train: 2.891, val: n/a | iter time: 956.90 ms
Epoch 1 | iter 2 step 0 | loss train: 2.950, val: n/a | iter time: 524.16 ms
Epoch 1 | iter 3 step 0 | loss train: 3.036, val: n/a | iter time: 547.41 ms
Epoch 1 | iter 4 step 0 | loss train: 2.932, val: n/a | iter time: 653.34 ms
Epoch 1 | iter 5 step 0 | loss train: 2.988, val: n/a | iter time: 395.61 ms
Epoch 1 | iter 6 step 0 | loss train: 3.014, val: n/a | iter time: 516.08 ms
Epoch 1 | iter 7 step 0 | loss train: 3.029, val: n/a | iter time: 741.14 ms
Epoch 1 | iter 8 step 1 | loss train: 3.025, val: n/a | iter time: 513.20 ms (step)
Epoch 1 | iter 9 step 1 | loss train: 3.058, val: n/a | iter time: 645.27 ms
Epoch 1 | iter 10 step 1 | loss train: 3.028, val: n/a | iter time: 693.90 ms
Epoch 1 | iter 11 step 1 | loss train: 2.986, val: n/a | iter time: 656.02 ms
Epoch 1 | iter 12 step 1 | loss train: 2.994, val: n/a | iter time: 643.70 ms
Epoch 1 | iter 13 step 1 | loss train: 2.952, val: n/a | iter time: 469.85 ms
Epoch 1 | iter 14 step 1 | loss train: 2.910, val: n/a | iter time: 649.19 ms
Epoch 1 | iter 15 step 1 | loss train: 2.868, val: n/a | iter time: 500.48 ms
Epoch 1 | iter 16 step 2 | loss train: 2.869, val: n/a | iter time: 547.06 ms (step)
Epoch 1 | iter 17 step 2 | loss train: 2.841, val: n/a | iter time: 638.36 ms
Epoch 1 | iter 18 step 2 | loss train: 2.899, val: n/a | iter time: 398.46 ms
Epoch 1 | iter 19 step 2 | loss train: 2.903, val: n/a | iter time: 649.55 ms
Epoch 1 | iter 20 step 2 | loss train: 2.925, val: n/a | iter time: 520.43 ms
Epoch 1 | iter 21 step 2 | loss train: 2.927, val: n/a | iter time: 689.26 ms
Epoch 1 | iter 22 step 2 | loss train: 2.942, val: n/a | iter time: 525.35 ms
Epoch 1 | iter 23 step 2 | loss train: 2.933, val: n/a | iter time: 462.21 ms
Epoch 1 | iter 24 step 3 | loss train: 2.916, val: n/a | iter time: 654.28 ms (step)
Epoch 1 | iter 25 step 3 | loss train: 2.930, val: n/a | iter time: 537.02 ms
Epoch 1 | iter 26 step 3 | loss train: 2.911, val: n/a | iter time: 476.17 ms
Epoch 1 | iter 27 step 3 | loss train: 2.913, val: n/a | iter time: 545.19 ms
Epoch 1 | iter 28 step 3 | loss train: 2.882, val: n/a | iter time: 528.47 ms
Epoch 1 | iter 29 step 3 | loss train: 2.921, val: n/a | iter time: 463.99 ms
Epoch 1 | iter 30 step 3 | loss train: 2.899, val: n/a | iter time: 484.63 ms
Epoch 1 | iter 31 step 3 | loss train: 2.927, val: n/a | iter time: 390.68 ms
Epoch 1 | iter 32 step 4 | loss train: 2.922, val: n/a | iter time: 691.96 ms (step)
Epoch 1 | iter 33 step 4 | loss train: 2.919, val: n/a | iter time: 461.07 ms
Epoch 1 | iter 34 step 4 | loss train: 2.867, val: n/a | iter time: 690.43 ms
Epoch 1 | iter 35 step 4 | loss train: 2.828, val: n/a | iter time: 540.38 ms
Traceback (most recent call last):
  File "/root/szhao/ES-Lora/litgpt/litgpt/finetune/lora.py", line 432, in <module>
    CLI(setup)
  File "/root/szhao/ES-Lora/litgpt/litgpt/utils.py", line 412, in CLI
    return CLI(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/jsonargparse/_cli.py", line 96, in CLI
    return _run_component(components, cfg_init)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/jsonargparse/_cli.py", line 196, in _run_component
    return component(**cfg)
  File "/root/szhao/ES-Lora/litgpt/litgpt/finetune/lora.py", line 143, in setup
    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 866, in launch
    return self._wrap_and_launch(function, self, *args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 952, in _wrap_and_launch
    return to_run(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 957, in _wrap_with_setup
    return to_run(*args, **kwargs)
  File "/root/szhao/ES-Lora/litgpt/litgpt/finetune/lora.py", line 196, in main
    fit(
  File "/root/szhao/ES-Lora/litgpt/litgpt/finetune/lora.py", line 276, in fit
    logits = model(input_ids, lm_head_chunk_size=128)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/lightning/fabric/wrappers.py", line 143, in forward
    output = self._forward_module(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/szhao/ES-Lora/litgpt/litgpt/lora.py", line 545, in forward
    x = block(x, cos, sin, mask, input_pos)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/szhao/ES-Lora/litgpt/litgpt/model.py", line 187, in forward
    x = self.mlp(self.norm_2(x)) + x
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/szhao/ES-Lora/litgpt/litgpt/model.py", line 311, in forward
    x_fc_1 = self.fc_1(x)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/szhao/ES-Lora/litgpt/litgpt/lora.py", line 168, in forward
    pretrained = self.linear(x)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 48.00 MiB. GPU
```
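The log shows `(step)` at iterations 8, 16, 24 and 32, i.e. one optimizer step per 8 forward/backward passes. A minimal sketch of that arithmetic, assuming the usual relation accumulation iterations = `global_batch_size / (devices × micro_batch_size)`; the helper names are hypothetical and not LitGPT's API:

```python
def gradient_accumulation_iters(global_batch_size: int, micro_batch_size: int, devices: int = 1) -> int:
    """Forward/backward passes accumulated per optimizer step (assumed relation)."""
    per_device = global_batch_size // devices
    if per_device % micro_batch_size != 0:
        raise ValueError("per-device global batch must be divisible by micro_batch_size")
    return per_device // micro_batch_size


def optimizer_steps_at(iter_num: int, accumulation_iters: int) -> bool:
    """True if the optimizer steps after this 1-based iteration."""
    return iter_num % accumulation_iters == 0


# Values from the config printed in this issue: global_batch_size=8, micro_batch_size=1, devices=1.
accum = gradient_accumulation_iters(global_batch_size=8, micro_batch_size=1, devices=1)
print(accum)
print([i for i in range(1, 36) if optimizer_steps_at(i, accum)])
```

The takeaway for OOM debugging: `global_batch_size` only controls how many micro-batches are accumulated. Peak memory of each pass is governed by `micro_batch_size` and the sequence length of the sample being processed, so lowering `global_batch_size` alone would not be expected to fix this error.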

I find this odd, because I previously finetuned llama2-7b with the lit-llama repository on the same dataset with almost the same training config, and everything went smoothly then. Can you help me? Thanks.

rasbt commented 1 week ago

Hm, it could be related to the slightly larger model size.

> llama2-7b by the lit-llama

I think Llama 2 is not supported by lit-llama. Did you perhaps mean Llama 7B in lit-llama, or Llama 2 7B in LitGPT?

If you meant lit-llama, I am curious: does the Llama 2 7B model work for you in LitGPT?

In any case, you could perhaps try QLoRA or a smaller sequence length to make it work.
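A smaller sequence length helps because activation memory grows with the number of tokens in each micro-batch. The sketch below estimates only one piece of it: the MLP intermediate outputs (`fc_1` and `fc_2`; the traceback fails inside `fc_1`), which it assumes are saved for the backward pass. It further assumes Llama 3 8B's shapes (32 layers, intermediate size 14336), bf16 storage and `micro_batch_size=1`. Attention, norms and all other saved tensors are ignored, so treat the result as a rough lower bound, not a measurement.

```python
GIB = 2**30


def mlp_activation_bytes(
    seq_len: int,
    micro_batch_size: int = 1,
    n_layer: int = 32,              # Llama 3 8B
    intermediate_size: int = 14336, # Llama 3 8B MLP width
    tensors_per_layer: int = 2,     # assumption: fc_1 and fc_2 outputs are kept
    bytes_per_elem: int = 2,        # bf16
) -> int:
    """Rough lower bound on MLP intermediate activations held for backward."""
    per_tensor = micro_batch_size * seq_len * intermediate_size * bytes_per_elem
    return n_layer * tensors_per_layer * per_tensor


# 1923 is the longest training sample reported in the log; 512 is a smaller cap.
for seq_len in (1923, 512):
    print(seq_len, round(mlp_activation_bytes(seq_len) / GIB, 2), "GiB")
```

Under these assumptions this single term is about 3.3 GiB for the 1923-token sample versus about 0.9 GiB at 512 tokens. Because the estimate is linear in `seq_len`, the ratio is just 1923 / 512 ≈ 3.8.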

yirending commented 1 week ago

With `--quantize bnb.nf4`, I am able to finetune Llama 3 8B without any problem on a single A10 GPU.
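To put the `--quantize bnb.nf4` result in perspective, here is back-of-the-envelope arithmetic for the frozen base weights only, using the non-trainable parameter count from the log above. It assumes 16 bits per weight in bf16 and 4 bits per weight in NF4, and it ignores quantization constants, any layers left unquantized, the LoRA parameters, optimizer state and activations, so real usage is higher.

```python
GIB = 2**30

# From the log in this issue: "Number of non-trainable parameters: 8,030,261,248"
NON_TRAINABLE_PARAMS = 8_030_261_248


def weight_gib(n_params: int, bits_per_param: float) -> float:
    """Storage for n_params weights at the given bit width, in GiB."""
    return n_params * bits_per_param / 8 / GIB


bf16_gib = weight_gib(NON_TRAINABLE_PARAMS, 16)  # unquantized, precision 'bf16-true'
nf4_gib = weight_gib(NON_TRAINABLE_PARAMS, 4)    # idealized 4-bit storage

print(f"bf16 base weights: {bf16_gib:.2f} GiB")
print(f"nf4 base weights:  {nf4_gib:.2f} GiB")
```

Roughly 15 GiB versus under 4 GiB: on a 24 GB card such as the A10, 4-bit weights leave far more room for activations, which is consistent with QLoRA being suggested above.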

zhaosheng-thu commented 1 week ago

Thanks for all the help. I found that the OOM error goes away when I use a smaller `max_seq_length`. I believe my dataset samples are simply too long: LoRA on the Alpaca-2k dataset used 20.5 GB of memory, but on my dataset, without a tighter `max_seq_length` limit, it ran out of memory whether or not I used `--quantize bnb.nf4`. The issue was resolved when I limited `max_seq_length` to 512.
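Before settling on a cap such as 512, it can help to check how many samples it would cut. A minimal sketch, assuming each sample's length in tokens (prompt plus response) has already been computed with the model's tokenizer; that step is not shown, and `truncation_report` is a hypothetical helper, not part of LitGPT:

```python
from typing import Dict, Sequence


def truncation_report(token_lengths: Sequence[int], max_seq_length: int) -> Dict[str, float]:
    """Summarize how a max_seq_length cap would affect a dataset."""
    n = len(token_lengths)
    truncated = sum(1 for length in token_lengths if length > max_seq_length)
    kept_tokens = sum(min(length, max_seq_length) for length in token_lengths)
    total_tokens = sum(token_lengths)
    return {
        "samples": n,
        "truncated_samples": truncated,
        "truncated_fraction": truncated / n if n else 0.0,
        "tokens_kept_fraction": kept_tokens / total_tokens if total_tokens else 1.0,
        "longest": max(token_lengths, default=0),
    }


# Toy lengths for illustration only; these are NOT measured on the ExTES dataset.
print(truncation_report([300, 480, 650, 1200, 1923], max_seq_length=512))
```

One caveat: if truncation keeps the beginning of each sample, long prompts combined with `mask_prompt=True` can leave little or none of the response to train on, so a cap that truncates a large fraction of samples may be worth revisiting.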