bitsandbytes-foundation / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.
https://huggingface.co/docs/bitsandbytes/main/en/index

Llama3-8B - FSDP + QLoRA results in OOM with 4 A40s #1252

Closed MikaSie closed 5 months ago

MikaSie commented 5 months ago

Hardware:

CPU: Xeon® E5-2630 v2, but limited to 16 GB RAM as this is what the vast.ai instance has.
GPU: 4x A40 --> ~180 GB VRAM in total

OS: Linux
Python: 3.10
CUDA: 12.2

packages:

torch==2.3.1
transformers==4.41.2
peft==0.11.1
datasets==2.20.0
accelerate==0.31.0
evaluate==0.4.1
bitsandbytes==0.43.1
huggingface_hub==0.23.4
trl==0.9.4

Issue

Introduction

Hi! I'm trying to fine-tune Llama3-8B on a summarization dataset of about 1500 instances. The dataset contains long documents, often over 8K tokens. I want to use FSDP + QLoRA to fine-tune Llama3-8B. When following this guide I was very hopeful this would be possible on my setup, since I'm fine-tuning the 8B version instead of the 70B version.

I'm following these two guides as inspiration: the bitsandbytes Guide and Phil Schmid's Guide.

Phil Schmid's guide mentions the following expected memory usage:

Full fine-tuning with FSDP needs ~16x 80GB GPUs
FSDP + LoRA needs ~8x 80GB GPUs
FSDP + Q-LoRA needs ~2x 40GB GPUs
FSDP + Q-LoRA + CPU offloading needs 4x 24GB GPUs, with 22 GB/GPU and 127 GB CPU RAM, at a sequence length of 3072 and a batch size of 1

Note: to NOT use CPU offloading you need to change the value of fsdp and remove offload. This only works on >40GB GPUs since it requires more memory.
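As a rough sanity check on those numbers, here is a back-of-the-envelope estimate for my setup (a minimal sketch with my own approximations, not figures from either guide):

# Rough estimate of where the memory should go with FSDP + QLoRA on 4x A40 (approximations only).
n_params = 8e9                            # approximate parameter count of Llama-3-8B
nf4_weights_gb = n_params * 0.5 / 1e9     # NF4 stores ~0.5 bytes per parameter -> ~4 GB in total
per_gpu_weights_gb = nf4_weights_gb / 4   # FULL_SHARD splits the flat parameters across the 4 GPUs
print(f"quantized base weights: ~{nf4_weights_gb:.0f} GB total, ~{per_gpu_weights_gb:.0f} GB per GPU")
# The LoRA adapters (r=8 on 7 projections) add only tens of millions of bf16 parameters, so their
# weights, gradients and optimizer states are small. What remains, and dominates, is activation
# memory, which grows with sequence length -- the 8K packed context is the expensive part here.

So on paper the weights themselves should not be the problem; the question is whether the activations for 8K-token sequences fit.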

Accelerate config setup:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: false #Was true before
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Code

quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_storage=torch.bfloat16,
            )

model = AutoModelForCausalLM.from_pretrained(
            'meta-llama/Meta-Llama-3-8B', 
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
            use_cache=False
            )

tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B')
tokenizer.pad_token = tokenizer.eos_token

lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            task_type='CAUSAL_LM',
            bias='none',
        )

model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
            output_dir = os.path.join('results', model_id, 'output'),
            num_train_epochs = 40,
            per_device_train_batch_size = 1,
            per_device_eval_batch_size = 1, 
            gradient_accumulation_steps = True,  # NOTE: expects an int; True effectively behaves as 1 here
            warmup_ratio = args.warmup_ratio,
            weight_decay = args.weight_decay,
            logging_dir = os.path.join('results', model_id, 'logs'),
            remove_unused_columns = False,        
            load_best_model_at_end = True,
            metric_for_best_model = True,  # NOTE: expects a metric name string such as 'eval_loss'
            save_strategy= "epoch",
            save_total_limit= 2,
            evaluation_strategy = "epoch",
            label_names=["labels"],
            report_to = "wandb",
            logging_strategy = "epoch",
            run_name = model_id,
            eval_accumulation_steps = 1,
            hub_model_id = f"{model_id}",
            gradient_checkpointing= True,
            fp16= args.fp16,
            bf16= args.bf16,
            ddp_find_unused_parameters = True,
            gradient_checkpointing_kwargs= {'use_reentrant': False},
        )

trainer = SFTTrainer(
            model = model, 
            tokenizer = tokenizer, 
            args = training_args,
            train_dataset = dataset["train"],
            eval_dataset = dataset["validation"],
            max_seq_length = context_length_abstractive_model, #8192 
            callbacks = [EarlyStoppingCallback(early_stopping_patience = args.early_stopping_patience)],
            peft_config = lora_config,
            packing= True
            )

trainer.train()
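One small check I added while debugging (my own addition, not something from the guides): FSDP flattens the parameters of each wrapped block into a single buffer, which needs a uniform dtype, and that is the reason bnb_4bit_quant_storage is set to bfloat16 above. Right after get_peft_model I print which dtypes the model actually ends up with:

# Debugging aid: FSDP flat parameters need a uniform dtype per wrapped module, so ideally
# (almost) everything here reports torch.bfloat16 (LoRA weights plus the 4-bit storage buffers).
from collections import Counter

print(Counter(p.dtype for p in model.parameters()))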

Start training

accelerate launch training.py --bf16

Errors

First I followed the guides exactly and set fsdp_cpu_ram_efficient_loading to true. But when I do this, sometimes the OS kills the process with SIGKILL(9): (screenshot "Screenshot 2024-06-17 at 11:09:46" attached). This makes sense, as Phil Schmid also recommends a pretty hefty amount of CPU memory: 127 GB CPU RAM with a sequence length of 3072 and a batch size of 1.

But oddly enough, I can currently run the script with fsdp_cpu_ram_efficient_loading set to either true or false without receiving the SIGKILL(9) error. However, in both situations I do get the following OOM error:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/workspace/Thesis/training.py", line 703, in <module>
[rank1]:     trainer.train()
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 440, in train
[rank1]:     output = super().train(*args, **kwargs)
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1885, in train
[rank1]:     return inner_training_loop(
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
[rank1]:     tr_loss_step = self.training_step(model, inputs)
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/transformers/trainer.py", line 3250, in training_step
[rank1]:     self.accelerator.backward(loss)
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 2134, in backward
[rank1]:     loss.backward(**kwargs)
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.13 GiB. GPU  has a total capacity of 44.35 GiB of which 20.85 GiB is free. Process 787350 has 23.49 GiB memory in use. Of the allocated memory 18.22 GiB is allocated by PyTorch, and 4.84 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
W0617 09:10:40.805000 140644428781376 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 3244 closing signal SIGTERM

As you can see, it seems the model runs out of memory during the backward pass. I find this pretty odd, as I should (probably) have enough GPU memory to accommodate the 8B FSDP + QLoRA setup.
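The error message itself suggests PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True. One way to set it (just an idea based on that suggestion, I haven't verified that it solves the OOM) is at the very top of training.py, before anything touches CUDA:

# Follow the allocator's own suggestion from the OOM message: allow expandable segments to
# reduce "reserved but unallocated" fragmentation. Must be set before the first CUDA allocation.
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch  # only import torch (and anything that initializes CUDA) after setting the variable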

Possible limitations

The CPU has too little RAM. Offloading isn't possible because the instance only has 16 GB of CPU RAM. But following Phil Schmid's guide, not offloading to the CPU should still suffice, since we use 4 A40s. This makes it even more odd when you consider that I'm using the 8B version, instead of the 70B versions used in both guides.

Not using Flash Attention 2 could also be an issue, but as seen in Phil Schmid's guide, SDPA can also be used.

The sequence length is too long, causing the OOM. I tried setting max_seq_length to 512, but this didn't have any impact on the OOM issue.

Caveat

When I first dove into the rabbit hole of FSDP and QLoRA, I started out simple and just used the following code:

quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
            )

model = AutoModelForCausalLM.from_pretrained(
            'meta-llama/Meta-Llama-3-8B', 
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            device_map = 'auto',
            )

I launched the code with:

python3 training.py

This didn't result in an OOM error and I was able to train for 100 steps. However, this took quite long and would become too expensive for me, as the training would probably last over 200 hours. I could see that GPU memory was utilized pretty well and all GPUs were used up to 40 GB or so. Because this took so long, I wanted to use QLoRA. But I couldn't just use QLoRA and device_map='auto' together. That's why I resorted to FSDP in combination with QLoRA.
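For completeness, the pattern I have seen in the accelerate/PEFT documentation for combining 4-bit quantization with plain multi-GPU data parallelism (a sketch of that alternative, not code from my failing script) is to place the whole model on each rank's own GPU instead of using device_map='auto':

# Sketch of QLoRA + DDP (no FSDP): every rank holds a full copy of the 4-bit model on its own
# GPU, so device_map points at the local process's device instead of "auto".
from accelerate import PartialState

model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Meta-Llama-3-8B',
    quantization_config=quantization_config,        # the 4-bit BitsAndBytesConfig from above
    torch_dtype=torch.bfloat16,
    device_map={"": PartialState().process_index},  # one full 4-bit replica per rank
)

Launched with accelerate launch, this replicates the ~4-5 GB of 4-bit weights on every GPU rather than sharding them, trading memory for simplicity compared to FSDP.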

I don't really know why using QLoRA in combination with FSDP would then result in an OOM again, which makes me even more confused.

If you have any ideas, please let me know as I'm getting a bit frustrated after being stuck on this for a few days!

MikaSie commented 5 months ago

I fixed the issue! There were some things I did wrong.

Looking back, it makes sense that my script worked with:

python3 training.py

When it was run this way, we were performing DP instead of FSDP. That also explains why training would have taken 200 hours.

With my current setup I'm able to reduce training to roughly 60 hours. With a per_device_train_batch_size of 1 and gradient_accumulation_steps of 4, the memory of my GPUs is almost maxed out. I think this is due to the long sequence length that is used.
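For anyone who lands here later, this is roughly what the working combination looks like (a sketch of the settings described above, trimmed to the lines that changed; everything else stays as in the original script):

# Relevant TrainingArguments for the working FSDP + QLoRA run on 4x A40.
training_args = TrainingArguments(
    output_dir=os.path.join('results', model_id, 'output'),
    per_device_train_batch_size=1,     # 1 per GPU x 4 GPUs x 4 accumulation steps = effective batch size 16
    gradient_accumulation_steps=4,     # an int this time, not the boolean from the original snippet
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': False},
    bf16=True,
)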

If anyone has any recommendations on how to speed up the remainder of the training process, feel free to let me know!