lm-sys / FastChat

An open platform for training, serving, and evaluating large language models. Release repo for Vicuna and Chatbot Arena.

Llama 2 70B QLoRA training not converging #2578

Open alwayshalffull opened 10 months ago

alwayshalffull commented 10 months ago

Hi folks,

I'm running into an issue fine-tuning the Llama 2 70B model with 4-bit QLoRA using the FastChat package, and I'm wondering if anyone else has encountered similar issues or has suggestions for a fix. Briefly, here's my training command, based on the train_lora.sh script included with FastChat:

deepspeed fastchat/train/train_lora.py \
    --model_name_or_path meta-llama/Llama-2-70b-hf  \
    --lora_r 8 \
    --lora_alpha 16 \
    --lora_dropout 0.05 \
    --data_path ~/data.json \
    --output_dir ~/.checkpoints \
    --num_train_epochs 3 \
    --bf16 True \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 16 \
    --evaluation_strategy "steps" \
    --eval_steps 100 \
    --save_strategy "steps" \
    --save_steps 200 \
    --save_total_limit 2 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_strategy "steps" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 4096 \
    --q_lora True \
    --deepspeed playground/deepspeed_config_s2.json \
    --gradient_checkpointing True \
    --flash_attn True \
    --lazy_preprocess True
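
For anyone unfamiliar with the q_lora flag: as far as I understand it, passing --q_lora True makes train_lora.py load the base model 4-bit quantized via bitsandbytes and wrap it with a PEFT LoRA adapter, roughly along these lines (a sketch of the standard transformers/peft QLoRA recipe, not the exact FastChat code):

# Rough sketch of what --q_lora True does inside train_lora.py, as I understand
# it (standard transformers + peft QLoRA recipe, not the exact FastChat code):
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # 4-bit base weights
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # matches --bf16 True
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    quantization_config=bnb_config,
)
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=8,                                    # --lora_r 8
    lora_alpha=16,                          # --lora_alpha 16
    lora_dropout=0.05,                      # --lora_dropout 0.05
    target_modules=["q_proj", "v_proj"],    # FastChat's defaults, if I recall correctly
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)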

One notable change is that in train_lora.py, I import llama2_flash_attn_monkey_patch instead of llama_flash_attn_monkey_patch.
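
Concretely, the swap looks like this (a sketch of my local edit; replace_llama_attn_with_flash_attn is, as far as I can tell, the function both monkey-patch modules expose):

# In fastchat/train/train_lora.py: swap the Llama flash-attention patch for the
# Llama 2 one (my local edit, sketched from memory):
from fastchat.train.llama2_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

# Called before the model is loaded so attention uses the flash-attn kernels.
replace_llama_attn_with_flash_attn()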

When I ran the training job, the model's output was quite poor. Among other things, it wasn't properly stopping at the end of Assistant messages at inference time, and would instead continue generating entire conversations after a single User message. The loss also didn't converge as well as in the non-QLoRA jobs I've run: it oscillated around 0.6-1.0 during epochs 2 and 3, whereas it usually decreases to around 0.2-0.3 by the end of 3 epochs.
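
For reference, this is roughly how I check whether the fine-tuned adapter ever emits the end-of-turn token at inference time (a hypothetical diagnostic script, not part of FastChat; the paths and prompt format are placeholders):

# Hypothetical check: does the QLoRA fine-tune ever emit the EOS token after an
# Assistant reply? The paths and prompt format below are placeholders.
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

base = "meta-llama/Llama-2-70b-hf"
adapter = os.path.expanduser("~/.checkpoints")  # placeholder: LoRA adapter dir

tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(
    base,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)
model = PeftModel.from_pretrained(model, adapter)

prompt = "USER: Hello, who are you?\nASSISTANT:"  # placeholder prompt format
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=256)
text = tokenizer.decode(out[0], skip_special_tokens=False)

# If training worked, generation should stop at the EOS token instead of rolling
# straight into a new USER turn.
print(text)
print("contains EOS:", tokenizer.eos_token in text)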

Has anyone encountered similar issues? If so, how did you solve them? Thanks in advance!

alwayshalffull commented 10 months ago

The environment was CUDA 12.1 on 4x H100 GPUs, with these package versions:

accelerate==0.23.0
bitsandbytes==0.41.1
deepspeed==0.10.3
flash-attn==2.3.0
fschat==0.2.25
peft==0.5.0
tokenizers==0.13.3
torch==2.1.0
torchaudio==2.1.0
torchvision==0.16.0
transformers==4.34.dev0