CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
MIT License

OOM error with PEFT LoRA on Llama2-7B #601

Open arpaiva opened 2 hours ago

arpaiva commented 2 hours ago

🐛 Describe the bug

I'm trying to fine-tune Llama2-7B (to reproduce the experiments in a paper) using PEFT LoRA, which leaves only 0.124% of the parameters trainable. However, training fails with an out-of-memory (OOM) error on a 32GB V100 GPU. Using multiple GPUs, or setting the Accelerate DeepSpeed config to offload optimizer states and weights to the CPU, doesn't help and yields the same OOM error. It seems that once the initial model fits on one GPU, everything else is assumed to fit as well. Or maybe everything is indeed supposed to fit, but GPU memory is being wasted somewhere. Any clues on how to fix this?
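As a rough sanity check on whether everything should fit, the weight memory can be estimated by hand. The sketch below is only back-of-the-envelope arithmetic: the parameter count of about 6.74 billion is an approximation, and it is an assumption (not confirmed by this report) that the frozen base weights may be held in fp32 on the GPU. Activations and temporary buffers are not counted.

```python
# Back-of-the-envelope estimate; every number here is an approximation.
GIB = 1024 ** 3


def weights_gib(n_params: float, bytes_per_param: int) -> float:
    """Memory needed to store n_params values, in GiB."""
    return n_params * bytes_per_param / GIB


n_params = 6.74e9             # approximate Llama2-7B parameter count (assumption)
trainable_fraction = 0.00124  # 0.124% trainable, as stated in the report

base_fp32 = weights_gib(n_params, 4)  # frozen weights if kept in fp32
base_fp16 = weights_gib(n_params, 2)  # frozen weights if kept in fp16

# For the trainable LoRA parameters only: fp32 weight copy is already counted,
# so add an fp32 gradient plus two fp32 Adam states (3 values of 4 bytes each).
lora_overhead = weights_gib(n_params * trainable_fraction, 3 * 4)

print(f"base weights, fp32: {base_fp32:.1f} GiB")
print(f"base weights, fp16: {base_fp16:.1f} GiB")
print(f"LoRA grads + Adam states: {lora_overhead:.2f} GiB")
```

Under these assumptions the LoRA gradients and optimizer states are negligible, so offloading the optimizer would free very little. If the base weights are resident in fp32, they alone would take most of a 32GB card before any activations are allocated.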

Accelerate config:

compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
  train_batch_size: NUM_PROCS
  train_micro_batch_size_per_gpu: 1
downcast_bf16: no
dynamo_config: {}
fsdp_config: {}
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: fp16
num_machines: 1
num_processes: NUM_PROCS
rdzv_backend: static
same_network: true
use_cpu: false

(NUM_PROCS gets replaced with the number of GPUs)
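The substitution can be sketched as follows. The helper name `render_config` and the shortened template are hypothetical and only illustrate the idea; the actual script used for the replacement is not shown in this report. The check at the end reflects DeepSpeed's requirement that `train_batch_size` equal the micro-batch size times the gradient accumulation steps times the number of processes.

```python
def render_config(template: str, num_procs: int) -> str:
    """Replace every NUM_PROCS placeholder with the process count."""
    if num_procs < 1:
        raise ValueError("num_procs must be at least 1")
    return template.replace("NUM_PROCS", str(num_procs))


# Shortened, illustrative excerpt of the config above (not the full file).
template = """\
deepspeed_config:
  gradient_accumulation_steps: 1
  train_batch_size: NUM_PROCS
  train_micro_batch_size_per_gpu: 1
num_processes: NUM_PROCS
"""

num_procs = 4  # hypothetical GPU count
rendered = render_config(template, num_procs)
print(rendered)

# Consistency check for the values used in this config.
micro_batch, grad_accum = 1, 1
train_batch_size = num_procs
assert train_batch_size == micro_batch * grad_accum * num_procs
```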

Error:

  File "examples/mmlu_sft.py", line 335, in <module>
    main(hparams)
  File "examples/mmlu_sft.py", line 324, in main
    trainer = trlx.train(
  File "/scratch/repo/trlx/trlx.py", line 129, in train
    trainer.learn()
  File "/scratch/repo/trlx/trainer/accelerate_base_trainer.py", line 768, in learn
    loss, stats = self.loss(microbatch)
  File "/scratch/repo/trlx/trainer/accelerate_sft_trainer.py", line 70, in loss
    loss = self.model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=labels).loss
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1040, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/operations.py", line 581, in forward
    return model_forward(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/operations.py", line 569, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/usr/local/lib/python3.8/dist-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/peft/peft_model.py", line 918, in forward
    return self.base_model(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/peft/tuners/tuners_utils.py", line 94, in forward
    return self.model.forward(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 809, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 697, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 426, in forward
    hidden_states = self.mlp(hidden_states)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 220, in forward
    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 86.00 MiB (GPU 0; 31.74 GiB total capacity; 30.93 GiB already allocated; 59.31 MiB free; 31.02 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
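
The error message suggests `max_split_size_mb` only when reserved memory is much larger than allocated memory. A quick way to check that condition is to pull the figures out of the message itself. The helper `parse_oom` below is a hypothetical illustration, not part of PyTorch, and its regular expressions assume the message format shown above.

```python
import re


def parse_oom(message: str) -> dict:
    """Extract the GiB figures reported in a CUDA OOM message of the format above."""
    patterns = {
        "total": r"([\d.]+) GiB total capacity",
        "allocated": r"([\d.]+) GiB already allocated",
        "reserved": r"([\d.]+) GiB reserved in total",
    }
    stats = {}
    for name, pattern in patterns.items():
        match = re.search(pattern, message)
        if match is None:
            raise ValueError(f"could not find '{name}' in the message")
        stats[name] = float(match.group(1))
    return stats


message = (
    "CUDA out of memory. Tried to allocate 86.00 MiB (GPU 0; 31.74 GiB total "
    "capacity; 30.93 GiB already allocated; 59.31 MiB free; 31.02 GiB reserved "
    "in total by PyTorch)"
)

stats = parse_oom(message)
# Memory held by the caching allocator but not backing live tensors.
slack = stats["reserved"] - stats["allocated"]
print(f"reserved - allocated = {slack:.2f} GiB")
```

Here the gap between reserved and allocated memory is under 0.1 GiB, so fragmentation is probably not the main cause. Nearly all of the card appears to be occupied by live tensors.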

Which trlX version are you using?

0.7.0

Additional system and package information

Python=3.8.10; transformers=4.32.1; Linux x64

arpaiva commented 2 hours ago

By the way, I've tried many variations of the DeepSpeed settings (I believe I even tried running without Accelerate when using only 1 GPU) and got an OOM error every time.