huggingface / optimum-neuron

Easy, fast and very cheap training and inference on AWS Trainium and Inferentia chips.
Apache License 2.0

training loss while fine-tuning llama 3.1 with lora is very high compared to rtx 3090 #721

Open anilozlu opened 1 hour ago

anilozlu commented 1 hour ago

System Info

using Huggingface AMI from AWS marketplace with Ubuntu 22.04
optimum-neuron 0.0.25
transformers 4.45.2
peft 0.13.0
trl 0.11.4
accelerate 0.29.2
torch 2.1.2

Who can help?

@michaelbenayoun

Information

Tasks

Reproduction (minimal, reproducible, runnable)

I am following the tutorial at https://huggingface.co/docs/optimum-neuron/en/training_tutorials/sft_lora_finetune_llm and using the training script from https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/sft_lora_finetune_llm.py. I used a trn1.2xlarge instance with 2 Neuron cores to fine-tune Llama 3.1 8B with LoRA, using tensor parallelism with a degree of 2. However, the training loss is much higher than for the same model trained with the same parameters on a single RTX 3090. The training losses look like this:

[Image: combined training-loss curves for the Trainium and RTX 3090 runs]

I ran these experiments with databricks/databricks-dolly-15k and timdettmers/openassistant-guanaco. I also changed the tokenize function under `_prepare_non_packed_dataloader` in trl/trainer/sft_trainer so that it pads every sample to max_length, matching optimum-neuron's behavior (a sketch of this change follows the training script below).

My training script for the trn1.2xlarge instance (shown for the dolly dataset; for the openassistant dataset I change the formatting function so that it just returns examples["text"] directly):

train.py

```python
from dataclasses import dataclass, field

from datasets import load_from_disk, load_dataset, Dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
)

from optimum.neuron import NeuronHfArgumentParser as HfArgumentParser
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainingArguments
from optimum.neuron.distributed import lazy_load_for_parallelism
import torch

from huggingface_hub import login
import os

os.environ["WANDB_PROJECT"] = "my_project"
os.environ["WANDB_LOG_MODEL"] = "false"
os.environ["WANDB_WATCH"] = "all"


def format_dolly(examples):
    output_text = []
    for i in range(len(examples["instruction"])):
        instruction = f"### Instruction\n{examples['instruction'][i]}"
        context = f"### Context\n{examples['context'][i]}" if len(examples["context"][i]) > 0 else None
        response = f"### Answer\n{examples['response'][i]}"
        prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
        output_text.append(prompt)
    return output_text


def training_function(script_args, training_args):
    dataset = load_dataset("databricks/databricks-dolly-15k")

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id)
    tokenizer.pad_token = tokenizer.eos_token

    config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )

    args = training_args.to_dict()
    sft_config = NeuronSFTConfig(
        # max_seq_length=1024,
        # packing=False,
        **args,
    )

    with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
        model = AutoModelForCausalLM.from_pretrained(script_args.model_id)

    trainer = NeuronSFTTrainer(
        args=sft_config,
        model=model,
        peft_config=config,
        tokenizer=tokenizer,
        train_dataset=dataset,
        formatting_func=format_dolly,
    )

    # Start training
    # print(trainer.evaluate())
    trainer.train()

    trainer.save_model()  # Saves the tokenizer too for easy upload


@dataclass
class ScriptArguments:
    model_id: str = field(
        default="meta-llama/Llama-3.1-8B",
        metadata={"help": "The model that you want to train from the Hugging Face hub."},
    )


def main():
    parser = HfArgumentParser([ScriptArguments, NeuronTrainingArguments])
    script_args, training_args = parser.parse_args_into_dataclasses()

    # set seed
    set_seed(training_args.seed)

    # run training function
    training_function(script_args, training_args)


if __name__ == "__main__":
    main()
```
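For clarity, the padding change mentioned above amounts to something like the sketch below. This is only an illustration, not the exact trl internals: the standalone tokenizer setup, the `"text"` field name, and the fixed `max_seq_length` are assumptions; the real change lives in the `tokenize` closure inside `_prepare_non_packed_dataloader`.

```python
# Illustrative sketch of the padding change: tokenize the way trl's
# _prepare_non_packed_dataloader does, but with padding="max_length" so
# every sample is padded to max_seq_length, matching optimum-neuron.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
tokenizer.pad_token = tokenizer.eos_token
max_seq_length = 1024  # assumed; matches the value used in the GPU run


def tokenize(element):
    outputs = tokenizer(
        element["text"],               # "text" field name is an assumption
        add_special_tokens=True,
        truncation=True,
        padding="max_length",          # upstream default is padding=False
        max_length=max_seq_length,
        return_overflowing_tokens=False,
        return_length=False,
    )
    return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}
```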

My bash script for graph compilation:

compile.sh

```bash
#!/bin/bash
set -ex

MODEL_NAME="meta-llama/Llama-3.1-8B"
huggingface-cli download $MODEL_NAME --exclude "original/*" --token TOKEN

export NEURON_FUSE_SOFTMAX=1
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3
export MALLOC_ARENA_MAX=64
export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training --enable-saturate-infinity --cache_dir=/home/ubuntu/cache_dir_neuron/"

PROCESSES_PER_NODE=2
NUM_EPOCHS=1
TP_DEGREE=2
PP_DEGREE=1
BS=1
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=10
OUTPUT_DIR="trn1.2xlarge_databricks-dolly-15k"
MAX_STEPS=25

XLA_USE_BF16=1 neuron_parallel_compile torchrun --nproc_per_node $PROCESSES_PER_NODE train.py \
  --model_id $MODEL_NAME \
  --num_train_epochs $NUM_EPOCHS \
  --do_train \
  --learning_rate 5e-5 \
  --warmup_ratio 0.03 \
  --max_steps $MAX_STEPS \
  --per_device_train_batch_size $BS \
  --per_device_eval_batch_size $BS \
  --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
  --gradient_checkpointing true \
  --bf16 \
  --zero_1 false \
  --tensor_parallel_size $TP_DEGREE \
  --pipeline_parallel_size $PP_DEGREE \
  --logging_steps $LOGGING_STEPS \
  --save_total_limit 1 \
  --output_dir $OUTPUT_DIR \
  --lr_scheduler_type "constant" \
  --overwrite_output_dir \
  --report_to "none"

rm -rf $OUTPUT_DIR
```

and my bash script for training:

train.sh

```bash
#!/bin/bash
set -ex

MODEL_NAME="meta-llama/Llama-3.1-8B"
HF_TOKEN="TOKEN"

export NEURON_FUSE_SOFTMAX=1
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3
export MALLOC_ARENA_MAX=64
export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training --enable-saturate-infinity --cache_dir=/home/ubuntu/cache_dir_neuron/"

PROCESSES_PER_NODE=2
NUM_EPOCHS=1
TP_DEGREE=2
PP_DEGREE=1
BS=1
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=10
OUTPUT_DIR="trn1.2xlarge_databricks-dolly-15k"
MAX_STEPS=200

XLA_USE_BF16=1 torchrun --nproc_per_node $PROCESSES_PER_NODE train.py \
  --model_id $MODEL_NAME \
  --num_train_epochs $NUM_EPOCHS \
  --do_train \
  --learning_rate 5e-5 \
  --warmup_ratio 0.03 \
  --max_steps $MAX_STEPS \
  --per_device_train_batch_size $BS \
  --per_device_eval_batch_size $BS \
  --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
  --gradient_checkpointing true \
  --bf16 \
  --zero_1 false \
  --tensor_parallel_size $TP_DEGREE \
  --pipeline_parallel_size $PP_DEGREE \
  --logging_steps $LOGGING_STEPS \
  --save_total_limit 1 \
  --output_dir $OUTPUT_DIR \
  --lr_scheduler_type "constant" \
  --overwrite_output_dir \
  --report_to "wandb" \
  --run_name $OUTPUT_DIR
```

The script I use to train on the RTX 3090:

train.py

```python
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["WANDB_PROJECT"] = "my_project"
os.environ["WANDB_LOG_MODEL"] = "false"
os.environ["WANDB_WATCH"] = "all"

from datasets import load_dataset
from peft import LoraConfig, TaskType, AutoPeftModelForCausalLM
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer, TrainingArguments, AutoModelForCausalLM
from accelerate import init_empty_weights
import torch

model_path = "meta-llama/Llama-3.1-8B"
dataset_path = "databricks/databricks-dolly-15k"
output_dir = model_path.split("/")[-1] + "-" + dataset_path.split("/")[-1]
run_name = "3090" + "_" + model_path.split("/")[-1] + "_" + dataset_path.split("/")[-1]

tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token


def format_dolly(example):
    instruction = f"### Instruction\n{example['instruction']}"
    context = f"### Context\n{example['context']}" if len(example["context"]) > 0 else None
    response = f"### Answer\n{example['response']}"
    prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
    return {"text": prompt}


dataset = load_dataset(dataset_path, split="train")
dataset = dataset.map(format_dolly)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

sft_config = SFTConfig(
    do_train=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    save_total_limit=1,
    bf16=True,
    max_seq_length=1024,
    output_dir=run_name,
    dataset_text_field="text",
    learning_rate=5e-05,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    gradient_checkpointing=True,
    logging_steps=10,
    report_to="wandb",
    run_name=run_name,
    num_train_epochs=1,
    max_steps=200,
)

with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16)

trainer = SFTTrainer(
    model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=sft_config,
    peft_config=lora_config,
    packing=False,
)

trainer.train()
```

Disabling embedding parallelization on the Trainium instance lowers the training loss, but it is still consistently higher than the loss on the RTX 3090. Also, with embedding parallelization enabled the model is saved incorrectly: the trained model contains the additional tensors base_model.model.lm_head.weight and base_model.model.model.embed_tokens.weight, and only half of base_model.model.model.embed_tokens.weight is saved (its shape is (64128, 4096) instead of (128256, 4096)). Perhaps that should be a separate issue.
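If it helps with triage, a quick way to check the saved tensor shapes is to load the saved state dict and print the embedding/lm_head entries. This is only a sketch under assumptions: the file path is a placeholder, and the checkpoint is assumed to be a safetensors file written under the output directory by trainer.save_model().

```python
# Sketch: print the shapes of the embedding/lm_head tensors in a saved
# checkpoint to see whether only one TP shard (64128 rows) was written.
# The path is a placeholder; the actual layout of the output directory
# produced by trainer.save_model() may differ.
from safetensors.torch import load_file

state_dict = load_file("trn1.2xlarge_databricks-dolly-15k/model.safetensors")
for name, tensor in state_dict.items():
    if "embed_tokens" in name or "lm_head" in name:
        # Full Llama 3.1 vocabulary should give (128256, 4096);
        # (64128, 4096) means only half of the embedding was saved.
        print(name, tuple(tensor.shape))
```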

Expected behavior

I expect the training loss to be much closer to the loss I get when training the model on an RTX 3090 rather than on two Trainium Neuron cores.

anilozlu commented 1 hour ago

Sorry if this is a double ping, but I think I made a typo in your handle the first time. @michaelbenayoun, can you help with this?