
OOM On Galore Axolotl #1448

Open m626zNq opened 5 months ago

m626zNq commented 5 months ago

Please check that this issue hasn't been reported before.

Expected Behavior

Training should start without an OOM error, as it does in LLaMA Factory.

Current behaviour

My config causes an OOM error in axolotl. LLaMA Factory handled the same setup fine, but axolotl is not cooperating: in LLaMA Factory I could train with 16-bit precision, rank 1024, and 8k context on the same GPU, while axolotl runs out of memory even with 8-bit, rank 128, and 4k context.

I have tried:

Steps to reproduce

1. Install GaLore: pip install galore-torch
2. Run the config posted below.

Config yaml

base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
  - path: Walmart-the-bag/alpaca-ingen
    type:
      field_instruction: instruction
      field_output: output
      format: "\n### Instruction:\n{instruction}\n### Response:\n"
      no_input_format: "\n### Instruction:\n{instruction}\n### Response:\n"
dataset_prepared_path:
val_set_size: 0.05
output_dir: /notebooks/output

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
optim_args:
  rank: 128
  update_proj_gap: 200
  scale: 0.25
  proj_type: std
optim_target_modules:
  - q_proj
  - v_proj
  - linear
  - mlp
  - attn
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: galore_adafactor
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: false
fp16: true
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false

warmup_steps: 10
evals_per_epoch: 0
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

Possible solution

No response

Which Operating Systems are you using?

Python Version

3.11

axolotl branch-commit

main

Acknowledgements

jaredquekjz commented 5 months ago

Yes, same issue here. Even a 7B model takes a lot of memory, much more than the <24 GB promised in the original repo. Is the repo's "activation checkpointing" equivalent to "gradient checkpointing" in axolotl? My YAML for Yi (also adapted for Mistral-Hermes):


base_model: NousResearch/Nous-Hermes-2-Yi-34B
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: /workspace/axolotl/runpod/psychicOlierYidup.jsonl
    type: completion
dataset_prepared_path: 
val_set_size: 0.0
output_dir: /workspace/axolotl/model

sequence_len: 1200
sample_packing: true
pad_to_sequence_len: true

adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:

wandb_project: huggingface
wandb_entity: singaporespprtsschool
wandb_watch:
wandb_run_id: PsychicYiGalore26Mar
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 7
optimizer: galore_adamw_8bit 
optim_args:
  rank: 128
  update_proj_gap: 200
  scale: 0.25
  proj_type: std
optim_target_modules:
  - mlp
  - attn
lr_scheduler: cosine
learning_rate: 1e-4

train_on_inputs: false
group_by_length: false
bf16: false
fp16: true
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 30
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true

warmup_steps: 1000
evals_per_epoch: 
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed: 
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<|startoftext|>"
  eos_token: "<|im_end|>"
  pad_token: "<unk>"
  unk_token: "<unk>"
winglian commented 5 months ago

Did you try the galore 8bit variants?

jaredquekjz commented 5 months ago

I used the 8-bit optimiser, as shown above. Hermes 7B takes a surprising ~36 GB at sequence length 1200. And in theory Yi is supposed to fit on an H100 with GaLore, but it will OOM. How can the above YAML be optimised further?

jaredquekjz commented 5 months ago

@m626zNq @winglian I think I found the problem. In the axolotl README, I note that there are a number of layerwise optimisers:

# - galore_adamw_layerwise
# - galore_adamw_8bit_layerwise
# - galore_adafactor_layerwise

According to a remark I saw in the original HF PR, these layerwise optimisers are essential for achieving the much larger memory savings, but they come with limitations, such as possibly not working with multiple GPUs (see https://github.com/huggingface/transformers/pull/29588 and the original GaLore GitHub repo). In any case, with galore_adamw_8bit_layerwise I can train Hermes-M 7B in 20 GB with a batch size of 2800 tokens. So @m626zNq can try that and probably close this "bug".

But I find the loss seems to fall much more slowly (if at all) with the layerwise optimiser compared to the normal GaLore optimiser. I guess things are still quite unstable.
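
For reference, a minimal sketch of what that change looks like against the first config in this issue: only the optimizer name moves to the layerwise 8-bit variant, the optim_args and optim_target_modules are left exactly as posted, and I have not verified that this exact file fits in memory.

optimizer: galore_adamw_8bit_layerwise
optim_args:
  rank: 128
  update_proj_gap: 200
  scale: 0.25
  proj_type: std
optim_target_modules:
  - q_proj
  - v_proj
  - linear
  - mlp
  - attn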

winglian commented 5 months ago

@jaredquekjz the 24 GB from the paper is for a 7B-parameter model. You're using Yi-34B, which is still going to require much more VRAM, probably at least an 80 GB A100, to full-finetune with GaLore.
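
As a rough back-of-envelope estimate (ignoring activations and any GaLore savings on optimizer state): 34e9 parameters at 2 bytes each in fp16/bf16 is already about 68 GB for the weights alone, and same-dtype gradients add a comparable amount unless the layerwise variant releases them per layer, so a single 80 GB card leaves very little headroom.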

winglian commented 5 months ago

@m626zNq set flash_attention: true; without flash attention it's going to OOM almost every time, no matter what.
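
Relative to the first config in this issue, that is a one-line change (it assumes the flash-attn package is installed):

flash_attention: true   # was false in the config that OOMed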

jaredquekjz commented 5 months ago

@winglian thanks for the input. I trialled both Yi and Hermes-Mistral 7B. For both, we need the layerwise optimiser to get the maximum memory savings (20 GB for Hermes, as reported). But as shared, the layerwise optimiser may not yet be working stably (zero grad norm and no loss decrease), at least when I last tried it. Yi can load on one H100 with layerwise, but very slowly:

{'loss': 3.0955, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.099, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.0705, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.0608, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.1196, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.0307, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
winglian commented 5 months ago

Yeah. Layerwise also requires that you use a gradient accumulation steps value of 1.
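
In config terms, again relative to the first YAML in this issue, that constraint looks like:

optimizer: galore_adamw_8bit_layerwise
gradient_accumulation_steps: 1   # layerwise GaLore applies updates per layer, so gradients cannot be accumulated across steps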

m626zNq commented 5 months ago

@winglian I have tried all of those and still get OOM. I have tried every GaLore optimizer, flash attention, deepspeed, etc. (Sorry for the late response.)

m626zNq commented 5 months ago

> @m626zNq @winglian I think I found the problem. In the axolotl README, I note that there are a number of layerwise optimisers ... So @m626zNq can try that and probably close this "bug".

Still OOM. No idea what is going wrong.

nelaturuharsha commented 4 months ago

Just a thought (could be wrong) based on a similar discussion I had: as far as I understand, GaLore is run entirely in bfloat16 precision without any automatic mixed precision. My sense is that axolotl, using accelerate under the hood, enables AMP, which obviously requires more memory (the SVD is done in float32, IIRC). Not sure exactly, though. Reference here
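
If AMP is the culprit, one low-cost experiment (assuming bf16-capable hardware, which the configs above do not state) is to flip the precision flags that already appear in those configs from fp16 to bf16, closer to how the GaLore reference code runs; whether this alone removes the extra memory depends on how axolotl/accelerate cast the weights.

bf16: true   # instead of fp16: true; requires e.g. Ampere or newer
fp16: false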

e-p-armstrong commented 3 months ago

I am encountering similar issues: far too much VRAM is being used when GaLore-tuning Llama 8B (280 GB across 8x A6000s for me!). Something definitely seems wrong here. If the paper reports 24 GB for a 7B model, it presumably should not take 280 GB for an 8B, even with a larger tokenizer?

j-datta commented 2 months ago

Every time I've tried to fine-tune Mistral-7B on the SQuAD dataset, I've got an OOM error. I'm using 2x V100 (32 GB each).

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import datasets
from transformers import AutoTokenizer, AutoModelForCausalLM
import trl
from trl import SFTConfig, SFTTrainer

# Load the dataset
full_train_dataset = datasets.load_dataset('rajpurkar/squad_v2', split='train')
train_dataset = full_train_dataset.select(range(1000))

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1").half()
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Preprocess the dataset
def preprocess_function(examples):
    inputs = [q + " " + c for q, c in zip(examples["question"], examples["context"])]
    targets = [a["text"][0] if len(a["text"]) > 0 else "" for a in examples["answers"]]
    model_inputs = tokenizer(inputs, max_length=16, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=16, truncation=True, padding="max_length")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)

# Define training arguments
args = SFTConfig(
    output_dir="/home/IAIS/jdatta/teacher_model/test-galore",
    max_steps=1000,
    per_device_train_batch_size=1,
    fp16=True,
    dataset_text_field='input_ids',
    max_seq_length=16,
    learning_rate=0.0002,
    gradient_checkpointing=True,
    lr_scheduler_type="cosine",
    optim_args="rank=16, update_proj_gap=100, scale=2",
    warmup_ratio=0.1,
    weight_decay=0.0,
    gradient_accumulation_steps=1,
    optim="galore_adamw_8bit_layerwise",
    optim_target_modules=['linear', 'v_proj', 'mlp', 'q_proj', 'o_proj'],
    run_name='mistral_squad_finetune',
    report_to=[],
)

# Initialize the trainer
trainer = SFTTrainer(
    model=model, 
    args=args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

# Train the model
torch.cuda.empty_cache()
trainer.train()

error: OutOfMemoryError: CUDA out of memory. Tried to allocate 112.00 MiB. GPU