Closed: m626zNq closed this issue 3 weeks ago
Yes, same issue here. Even a 7B model takes a LOT of memory - much higher than the <24 GB promised in the original repo. Is the repo's "activation checkpointing" equivalent to axolotl's "gradient checkpointing"? (See the note after the yaml below.) My yaml for Yi (also adapted for Mistral-Hermes):
```yaml
base_model: NousResearch/Nous-Hermes-2-Yi-34B
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: /workspace/axolotl/runpod/psychicOlierYidup.jsonl
    type: completion
dataset_prepared_path:
val_set_size: 0.0
output_dir: /workspace/axolotl/model

sequence_len: 1200
sample_packing: true
pad_to_sequence_len: true

adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:

wandb_project: huggingface
wandb_entity: singaporespprtsschool
wandb_watch:
wandb_run_id: PsychicYiGalore26Mar
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 7
optimizer: galore_adamw_8bit
optim_args:
  rank: 128
  update_proj_gap: 200
  scale: 0.25
  proj_type: std
optim_target_modules:
  - mlp
  - attn
lr_scheduler: cosine
learning_rate: 1e-4

train_on_inputs: false
group_by_length: false
bf16: false
fp16: true
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 30
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true

warmup_steps: 1000
evals_per_epoch:
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<|startoftext|>"
  eos_token: "<|im_end|>"
  pad_token: "<unk>"
  unk_token: "<unk>"
```
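(Note on the checkpointing question above: "activation checkpointing" and "gradient checkpointing" generally name the same recompute-activations-on-backward technique. A minimal sketch of how it is toggled in plain transformers, assuming axolotl's `gradient_checkpointing: true` simply passes this through:)

```python
# Minimal sketch: "gradient checkpointing" and "activation checkpointing" refer to
# the same trade-off (recompute activations during backward instead of storing them).
# In plain transformers it is switched on at the model level:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("NousResearch/Nous-Hermes-2-Yi-34B")
model.gradient_checkpointing_enable()  # saves activation memory at the cost of extra compute
```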
Did you try the galore 8bit variants?
I used the 8-bit optimiser, as seen above. Hermes 7B takes a shocking ~36 GB at seqlen 1200. And in theory Yi is supposed to fit on an H100 with GaLore - but it OOMs. How can the above yaml be optimised further?
@m626zNq @winglian I think I found the problem. In the axolotl README, I note that there are a number of layerwise optimisers:
```
# - galore_adamw_layerwise
# - galore_adamw_8bit_layerwise
# - galore_adafactor_layerwise
```
According to a remark I saw in the original HF PR, these layerwise optimisers are essential for the much higher memory savings, but they come with limitations, such as perhaps not working with multiple GPUs (see https://github.com/huggingface/transformers/pull/29588 and the original GaLore GitHub). In any case, with galore_adamw_8bit_layerwise I can train Hermes-M 7B in 20 GB with a batch size of 2800 tokens. So @m626zNq can try it and probably close this "bug".
But I do find the loss seems to fall much more slowly (if at all) with the layerwise optimiser compared to normal GaLore. I guess things are still quite unstable.
@jaredquekjz the 24GB from the paper is for a 7B parameter model. You're using yi-34B. That's still going to require much more VRAM, probably at least an 80GB A100 to full finetune with Galore.
@m626zNq set `flash_attention: true`; without flash attention it's going to OOM almost every time no matter what.
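For anyone reproducing this outside axolotl, a minimal sketch of enabling FlashAttention-2 when loading the model directly with transformers, assuming the flash-attn package is installed and an Ampere-or-newer GPU:

```python
# Minimal sketch (not the axolotl code path): loading with FlashAttention-2 enabled.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Nous-Hermes-2-Yi-34B",      # model from the config above
    torch_dtype=torch.bfloat16,               # 16-bit weights; FA2 requires fp16/bf16
    attn_implementation="flash_attention_2",  # requires the flash-attn package
)
```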
@winglian - thanks for the input. I used both Yi and Hermes-Mistral 7B for the trial. For both, we need the layerwise optimiser for the memory savings to be maximal (20 GB for Hermes, as reported). But as shared, the layerwise optimiser may not be working fully stably yet (no grad norm and no loss decrease), at least when I last trialled it. Yi can load on one H100 with layerwise, but it's very slow:
```
{'loss': 3.0955, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.099, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.0705, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.0608, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.1196, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.0307, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
```
Yeah. Layerwise also requires that you use a gradient accumulation steps value of 1
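Putting the two constraints together (layerwise optimizer plus gradient accumulation of 1), a minimal sketch of the equivalent plain-transformers arguments, reusing the rank/scale values from the yaml earlier in the thread; the output path is a placeholder:

```python
# Minimal sketch: layerwise 8-bit GaLore via plain transformers TrainingArguments
# (needs `pip install galore-torch` and a recent transformers release).
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./galore-layerwise-test",   # placeholder output path
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,          # layerwise GaLore requires GA = 1
    gradient_checkpointing=True,
    learning_rate=1e-4,
    optim="galore_adamw_8bit_layerwise",
    optim_target_modules=["attn", "mlp"],   # matched against module names
    optim_args="rank=128, update_proj_gap=200, scale=0.25",
)
```

The multi-GPU caveat mentioned above still applies to the layerwise variant.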
@winglian I have tried all of those and still get OOM. I have tried every GaLore optimizer, flash attention, deepspeed, etc. *sorry for the late response
Still OOM. No idea what is going wrong.
Just a thought (could be wrong) here, from a similar discussion I had: as far as I understand, GaLore is run completely in bfloat16 precision without any automatic mixed precision. My sense is that with accelerate under the hood, AMP is being used, which obviously requires more memory (the SVD is done in float32, IIRC) - not sure exactly though. Reference here
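If that is the issue, a minimal sketch of what avoiding fp16 AMP would look like here; the bf16 flags are an assumption, not something verified against axolotl's or GaLore's internals:

```python
# Minimal sketch of the commenter's point: keep the weights in bf16 from the start
# and train with bf16=True, rather than loading in fp32/fp16 and relying on AMP.
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.bfloat16,        # weights stored in bf16 from the start
)

args = TrainingArguments(
    output_dir="./galore-bf16-test",   # placeholder output path
    bf16=True,                         # bf16 training instead of fp16 AMP (assumption)
    optim="galore_adamw_8bit",
    optim_target_modules=["attn", "mlp"],
)
```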
I am encountering similar issues - way too much VRAM is being used for GaLore tuning of Llama 8B for me (280 GB on 8x A6000s!). Something definitely seems wrong here. If the paper gives 24 GB for a 7B model, presumably it should not take 280 GB for an 8B, even with a larger tokenizer?
Every time I've tried to fine-tune Mistral-7B on the SQuAD dataset, I've got an OOM error. I'm using 2x V100 (32 GB each).
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import datasets
from transformers import AutoTokenizer, AutoModelForCausalLM
import trl
from trl import SFTConfig, SFTTrainer

# Load the dataset
full_train_dataset = datasets.load_dataset('rajpurkar/squad_v2', split='train')
train_dataset = full_train_dataset.select(range(1000))

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1").half()
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Preprocess the dataset
def preprocess_function(examples):
    inputs = [q + " " + c for q, c in zip(examples["question"], examples["context"])]
    targets = [a["text"][0] if len(a["text"]) > 0 else "" for a in examples["answers"]]
    model_inputs = tokenizer(inputs, max_length=16, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=16, truncation=True, padding="max_length")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)

# Define training arguments
args = SFTConfig(
    output_dir="/home/IAIS/jdatta/teacher_model/test-galore",
    max_steps=1000,
    per_device_train_batch_size=1,
    fp16=True,
    dataset_text_field='input_ids',
    max_seq_length=16,
    learning_rate=0.0002,
    gradient_checkpointing=True,
    lr_scheduler_type="cosine",
    optim_args="rank=16, update_proj_gap=100, scale=2",
    warmup_ratio=0.1,
    weight_decay=0.0,
    gradient_accumulation_steps=1,
    optim="galore_adamw_8bit_layerwise",
    optim_target_modules=['linear', 'v_proj', 'mlp', 'q_proj', 'o_proj'],
    run_name='mistral_squad_finetune',
    report_to=[],
)

# Initialize the trainer
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

# Train the model
torch.cuda.empty_cache()
trainer.train()
```
Error:
```
OutOfMemoryError: CUDA out of memory. Tried to allocate 112.00 MiB. GPU
```
### Please check that this issue hasn't been reported before.

### Expected Behavior
Should start training without OOM, like LLaMA Factory does.

### Current behaviour
Causing an OOM issue on axolotl with my config. LLaMA Factory acted fine, but axolotl is hating on me. On LLaMA Factory I was able to do 16-bit, rank 1024, and 8k context, and it worked fine on the same GPU; axolotl won't even work with 8-bit and rank 128 at 4k context (out of memory).

I have tried:

### Steps to reproduce
1. Install GaLore: `pip install galore-torch`
2. Run the config posted below.
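A quick sanity check that the GaLore backend is importable after installation (a sketch; class names as listed in the GaLore repo's README):

```python
# Minimal sketch: verify galore-torch installed correctly before launching training.
from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor

print(GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor)
```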
### Config yaml

### Possible solution
No response

### Which Operating Systems are you using?

### Python Version
3.11

### axolotl branch-commit
main

### Acknowledgements