axolotl-ai-cloud / axolotl

https://axolotl-ai-cloud.github.io/axolotl/
Apache License 2.0

Llama 3 8b OOM with GaLore on 2x A100s (Mistral 7b is fine?) #1641

Open e-p-armstrong opened 3 months ago

e-p-armstrong commented 3 months ago

Please check that this issue hasn't been reported before.

Expected Behavior

Llama 3 8b, with only about one billion more parameters than Mistral 7b, should presumably be able to train with GaLore on at least 2x A100s (Mistral v0.2 can train on 1x A100).

Current behaviour

Llama 3 8b OOMs almost immediately when tuned with GaLore at 8k sequence length, even when obscene amounts of compute are thrown at it.

Steps to reproduce

  1. Rent 2x A100s on Vast.ai or any other provider
  2. Run a training run with the provided config (any pretraining and finetuning data can stand in for the unavailable datasets; see the launch note after this list)
  3. Observe near-immediate OOM.
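
(For step 2, the exact launch command isn't stated in the issue; with the official Docker image it is presumably something along the lines of accelerate launch -m axolotl.cli.train config.yaml, with the yaml below saved under whatever filename you pass in.)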

Config yaml

base_model: meta-llama/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: json
    data_files: pretraining_vision.json
    ds_type: json
    type: completion
  - path: json
    data_files: simplified_data_rag_VISION.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: simplified_data_rag_VISION.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: pretraining_wiki.json
    ds_type: json
    type: completion
  - path: json
    data_files: simplified_data_rag_WIKI.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: simplified_data_no_rag_WIKI.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: pretraining_api.json
    ds_type: json
    type: completion
  - path: json
    data_files: simplified_data_rag_API.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: simplified_data_no_rag_API.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: pretraining_docs.json
    ds_type: json
    type: completion
  - path: json
    data_files: simplified_data_rag_DOCS.jsonl
    ds_type: json
    type: sharegpt
  - path: json
    data_files: simplified_data_no_rag_DOCS.jsonl
    ds_type: json
    type: sharegpt
dataset_prepared_path: last_run_prepared
output_dir: ./verus-out

sequence_len: 8100
sample_packing: true
pad_to_sequence_len: true

wandb_project: verus-llama-experiment-2
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 6
eval_batch_size: 6
num_epochs: 5
optimizer: galore_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0000035
cosine_min_lr_ratio: 0
weight_decay: 0 # no weight decay to maximize fact memorization (thanks cgato!)
# adamw hyperparams
adam_beta1: 0.9
adam_beta2: 0.999
adam_epsilon: 0.00000001
# Gradient clipping max norm
max_grad_norm: 1.0
noisy_embedding_alpha: 0 # no noisy embedding to ensure maximal memorization 

optim_args:
# For Galore Optimizers the following optim_args are available
    rank: 256 # type: int
    update_proj_gap: 200  # type: int
    scale: 0.25  # type: float
    proj_type: "std" # type: str, default = std

optim_target_modules: 
  - mlp
  - self_attn
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint: 
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
auto_resume_from_checkpoints: false
eval_steps: 10
saves_per_epoch: 1
eval_sample_packing: false
save_total_limit: 2
debug:
deepspeed: deepspeed_configs/zero2.json
special_tokens:
  pad_token: "<|end_of_text|>"

Possible solution

I was able to finetune Llama 3 8b Instruct by reducing the sequence length to around 2000 tokens. However, unless I'm missing something big that makes Llama 3 very inefficient with GaLore, I should presumably be able to do a full finetune at 8000 sequence length if I use 2 whole A100s.
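
Concretely, the change amounts to the line below (2048 is just an illustrative stand-in for the "around 2000 tokens" figure; everything else in the config stays the same):

sequence_len: 2048  # reduced from 8100; fits in memory, but defeats the point of an 8k run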

Which Operating Systems are you using?

Python Version

3.10.14

axolotl branch-commit

Whatever the official Docker image is on.

Acknowledgements

e-p-armstrong commented 3 months ago

Edit: it also OOMs on 4x A100s; something is definitely wrong here.

(screenshot attached)

winglian commented 3 months ago

How much VRAM does Mistral 7B use w/ GaLore on your setup? Keep in mind that Llama 3 has a much larger embedding layer than Mistral (128k vs 32k vocabulary), which significantly increases VRAM use.
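
To put rough numbers on that (back-of-envelope, bf16 weights assumed): Llama 3's 128,256-token vocabulary at hidden size 4096 is about 1 GB for the embedding matrix alone, plus an lm_head of the same shape, versus roughly 0.26 GB each for Mistral's 32k vocabulary. And since optim_target_modules in the config only covers mlp and self_attn, those large matrices also carry full-size (non-GaLore) gradients and optimizer states.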

I would start by decreasing the batch sizes and using the unsloth gradient checkpointing.

micro_batch_size: 4
eval_batch_size: 4
gradient_checkpointing: unsloth
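
For scale: with the posted config, each GPU packs up to 6 × 8100 ≈ 48,600 tokens of activations per step (pad_to_sequence_len guarantees full-length sequences), so cutting micro_batch_size shrinks activation memory roughly proportionally, and the unsloth checkpointing variant additionally (as I understand it) offloads checkpointed activations to system RAM.
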
e-p-armstrong commented 3 months ago

VRAM usage w/ Mistral 7b seems to be only about 58 gigabytes on 1x A100 (screenshot attached).

As for Llama 3: I decreased the batch size to 1, set gradient checkpointing to unsloth, increased gradient accumulation steps to 6, and rented out 8x A6000s. It still OOMs, even with 383.9 GB of VRAM available for finetuning an 8b model. This can't be right, can it?

(screenshots attached)

Edit: this may be related to issue #1448

Edit 2: I was using the wrong config -- it does not OOM on 8x A6000s. However, it IS using 273 GB of VRAM. Seems a bit high?

e-p-armstrong commented 3 months ago

Update: I'm seeing 343 GB of VRAM usage when finetuning Llama 3 8b. This has got to be wrong, but I have no idea how to even begin to address it. (screenshot attached)

Abhis-123 commented 2 months ago

@e-p-armstrong did you find any workaround?

e-p-armstrong commented 4 weeks ago

@Abhis-123 GaLore and multi-GPU setups do not play nicely together. You have to use a single GPU, or multiple GPUs w/ paged AdamW and DeepSpeed (look at config.qmd for the options, I think).
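
For reference, a minimal sketch of that second option against the config posted above, using stock optimizer names (the GaLore-specific keys are dropped):

optimizer: paged_adamw_8bit  # paged 8-bit AdamW instead of galore_adamw_8bit
deepspeed: deepspeed_configs/zero2.json  # keep ZeRO-2 for the multi-GPU case
# remove optim_args and optim_target_modules, which only apply to GaLore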