haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0

[Usage] Pretrain does not converge #368

Closed tfzhou closed 1 year ago

tfzhou commented 1 year ago

Describe the issue

Hi Haotian, thanks for your work on the project. I am trying to reproduce the pretraining stage but have gotten stuck. I have tried training from several language models (vicuna-7b-v1.3/v1.5, Llama-2-7b-chat-hf) using DeepSpeed with ZeRO-2 or ZeRO-3 configurations. Unfortunately, none of these runs went well: the training loss fails to converge, and the LR schedule does not follow the 'cosine' schedule specified in the command. I am unfamiliar with DeepSpeed and am not sure whether it is the cause. More details are provided below; I would appreciate your help.

By the way, I used 4 A100 GPUs (40GB memory) for these experiments.

Command:

# Uncomment and set the following variables correspondingly to run this script:

#MODEL_VERSION=vicuna-7b-v1.3
MODEL_VERSION=Llama-2-7b-chat-hf

########### DO NOT CHANGE ###########
########### USE THIS FOR BOTH ###########
PROMPT_VERSION=plain
########### DO NOT CHANGE ###########

deepspeed llava/train/train_mem.py \
    --model_name_or_path llama/$MODEL_VERSION \
    --version $PROMPT_VERSION \
    --data_path dataset/llava/LLaVA-CC3M-Pretrain-595K/chat.json \
    --image_folder dataset/llava/LLaVA-CC3M-Pretrain-595K/images \
    --vision_tower openai/clip-vit-large-patch14 \
    --tune_mm_mlp_adapter True \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --bf16 True \
    --output_dir ./checkpoints/llava-$MODEL_VERSION-pretrain \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 24000 \
    --save_total_limit 1 \
    --learning_rate 2e-3 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb \
    --deepspeed scripts/zero2.json

scripts/zero2.json (unchanged):

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto"
    }
}
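For reference, the "auto" entries in this config are normally filled in from the command-line arguments when training goes through the Hugging Face Trainer's DeepSpeed integration, and DeepSpeed expects the global batch size to equal micro-batch size × gradient accumulation steps × number of GPUs. The sketch below works out that number for the setup described in this post; the function name is only illustrative and is not part of LLaVA or DeepSpeed.

```python
# Values taken from the command and hardware described in the post above.
per_device_train_batch_size = 16   # --per_device_train_batch_size
gradient_accumulation_steps = 2    # --gradient_accumulation_steps
num_gpus = 4                       # four A100s


def effective_batch_size(micro_batch, grad_accum, world_size):
    """Number of samples consumed per optimizer step across all GPUs."""
    return micro_batch * grad_accum * world_size


global_batch = effective_batch_size(
    per_device_train_batch_size, gradient_accumulation_steps, num_gpus
)
print(f"effective global batch size: {global_batch}")
```

With these values the run uses 128 samples per optimizer step, so its total step count will differ from a run with a different GPU count or accumulation setting.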

Screenshots: Screenshot 2023-08-12 at 14 13 05 Screenshot 2023-08-12 at 14 13 25

tfzhou commented 1 year ago

I just found that the issue stems from flash-attention. After switching to train.py, pretraining works properly.

@haotian-liu Which version of flash-attn are you using? Mine is 2.0.4; should I probably use 1.x instead? Beyond this, do you observe any performance difference between train.py and train_mem.py?

haotian-liu commented 1 year ago

@tfzhou

I have tested this again locally on 2x 3090 with a per-device batch size of 16, on llama-2-7b-chat. train.py and train_mem.py behave similarly for me. I am using flash-attention 2.0.4, PyTorch 2.0.1, and CUDA 11.7. One thing to check is that the CUDA version of your PyTorch build matches the nvcc used when you compile flash-attention. (Please kindly let me know if this is the case, so that other community members can benefit as well :)

You may also choose to downgrade to flash-attention 1.x; our code base currently supports both 1.x and 2.x on A100s.

Also, I have attached the log of the first 35 training steps on 2x 3090 (total bs: 16x2=32). It seems that your LR is not decayed correctly, as the warmup should only cover 3% of the steps.
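A minimal sketch of that check, assuming `nvcc` is on the PATH and PyTorch is installed. `torch.version.cuda` reports the CUDA version PyTorch was built against; the helper names here are hypothetical and not part of flash-attention or LLaVA.

```python
import re
import shutil
import subprocess


def parse_nvcc_release(nvcc_output):
    """Extract the 'major.minor' CUDA release from `nvcc --version` output."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def major_minor(version):
    """Reduce a version string such as '11.7.99' to '11.7'."""
    return ".".join(version.split(".")[:2]) if version else None


def cuda_versions_match(torch_cuda, nvcc_cuda):
    """True only when both versions are known and agree on major.minor."""
    return torch_cuda is not None and major_minor(torch_cuda) == major_minor(nvcc_cuda)


if __name__ == "__main__":
    try:
        import torch
        torch_cuda = torch.version.cuda  # None for CPU-only builds
    except ImportError:
        torch_cuda = None

    nvcc_cuda = None
    if shutil.which("nvcc"):
        result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True)
        nvcc_cuda = parse_nvcc_release(result.stdout)

    print(f"PyTorch CUDA: {torch_cuda}, nvcc CUDA: {nvcc_cuda}")
    print("match" if cuda_versions_match(torch_cuda, nvcc_cuda) else "MISMATCH or unknown")
```

If the two versions differ, rebuilding flash-attention with an nvcc that matches the PyTorch build is the fix suggested in this thread.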

train_mem.py
```
{'loss': 8.0312, 'learning_rate': 3.5778175313059034e-06, 'epoch': 0.0}
{'loss': 7.9531, 'learning_rate': 7.155635062611807e-06, 'epoch': 0.0}
{'loss': 7.7969, 'learning_rate': 1.073345259391771e-05, 'epoch': 0.0}
{'loss': 8.125, 'learning_rate': 1.4311270125223614e-05, 'epoch': 0.0}
{'loss': 7.8594, 'learning_rate': 1.7889087656529517e-05, 'epoch': 0.0}
{'loss': 7.7656, 'learning_rate': 2.146690518783542e-05, 'epoch': 0.0}
{'loss': 7.4219, 'learning_rate': 2.5044722719141324e-05, 'epoch': 0.0}
{'loss': 7.0, 'learning_rate': 2.8622540250447228e-05, 'epoch': 0.0}
{'loss': 6.8438, 'learning_rate': 3.2200357781753134e-05, 'epoch': 0.0}
{'loss': 6.5469, 'learning_rate': 3.5778175313059034e-05, 'epoch': 0.0}
{'loss': 6.4688, 'learning_rate': 3.935599284436494e-05, 'epoch': 0.0}
{'loss': 6.3594, 'learning_rate': 4.293381037567084e-05, 'epoch': 0.0}
{'loss': 6.2031, 'learning_rate': 4.651162790697674e-05, 'epoch': 0.0}
{'loss': 6.1875, 'learning_rate': 5.008944543828265e-05, 'epoch': 0.0}
{'loss': 6.0625, 'learning_rate': 5.366726296958855e-05, 'epoch': 0.0}
{'loss': 6.0938, 'learning_rate': 5.7245080500894455e-05, 'epoch': 0.0}
{'loss': 5.8281, 'learning_rate': 6.082289803220036e-05, 'epoch': 0.0}
{'loss': 5.9531, 'learning_rate': 6.440071556350627e-05, 'epoch': 0.0}
{'loss': 5.6719, 'learning_rate': 6.797853309481217e-05, 'epoch': 0.0}
{'loss': 5.5781, 'learning_rate': 7.155635062611807e-05, 'epoch': 0.0}
{'loss': 5.4688, 'learning_rate': 7.513416815742398e-05, 'epoch': 0.0}
{'loss': 5.4219, 'learning_rate': 7.871198568872988e-05, 'epoch': 0.0}
{'loss': 5.4062, 'learning_rate': 8.228980322003578e-05, 'epoch': 0.0}
{'loss': 5.4375, 'learning_rate': 8.586762075134168e-05, 'epoch': 0.0}
{'loss': 5.3594, 'learning_rate': 8.94454382826476e-05, 'epoch': 0.0}
{'loss': 5.2031, 'learning_rate': 9.302325581395348e-05, 'epoch': 0.0}
{'loss': 5.1406, 'learning_rate': 9.660107334525938e-05, 'epoch': 0.0}
{'loss': 4.9531, 'learning_rate': 0.0001001788908765653, 'epoch': 0.0}
{'loss': 4.9844, 'learning_rate': 0.0001037567084078712, 'epoch': 0.0}
{'loss': 5.0938, 'learning_rate': 0.0001073345259391771, 'epoch': 0.0}
{'loss': 4.8594, 'learning_rate': 0.00011091234347048301, 'epoch': 0.0}
{'loss': 4.9688, 'learning_rate': 0.00011449016100178891, 'epoch': 0.0}
{'loss': 4.9844, 'learning_rate': 0.00011806797853309481, 'epoch': 0.0}
{'loss': 4.875, 'learning_rate': 0.00012164579606440072, 'epoch': 0.0}
{'loss': 4.9219, 'learning_rate': 0.0001252236135957066, 'epoch': 0.0}
0%|▎| 35/18606
```

train.py
```
{'loss': 7.9062, 'learning_rate': 3.5778175313059034e-06, 'epoch': 0.0}
{'loss': 7.5625, 'learning_rate': 7.155635062611807e-06, 'epoch': 0.0}
{'loss': 7.8438, 'learning_rate': 1.073345259391771e-05, 'epoch': 0.0}
{'loss': 7.8906, 'learning_rate': 1.4311270125223614e-05, 'epoch': 0.0}
{'loss': 7.75, 'learning_rate': 1.7889087656529517e-05, 'epoch': 0.0}
{'loss': 7.5312, 'learning_rate': 2.146690518783542e-05, 'epoch': 0.0}
{'loss': 7.4844, 'learning_rate': 2.5044722719141324e-05, 'epoch': 0.0}
{'loss': 7.1562, 'learning_rate': 2.8622540250447228e-05, 'epoch': 0.0}
{'loss': 6.875, 'learning_rate': 3.2200357781753134e-05, 'epoch': 0.0}
{'loss': 6.7188, 'learning_rate': 3.5778175313059034e-05, 'epoch': 0.0}
{'loss': 6.6875, 'learning_rate': 3.935599284436494e-05, 'epoch': 0.0}
{'loss': 6.5781, 'learning_rate': 4.293381037567084e-05, 'epoch': 0.0}
{'loss': 6.3594, 'learning_rate': 4.651162790697674e-05, 'epoch': 0.0}
{'loss': 6.4219, 'learning_rate': 5.008944543828265e-05, 'epoch': 0.0}
{'loss': 6.1094, 'learning_rate': 5.366726296958855e-05, 'epoch': 0.0}
{'loss': 6.1719, 'learning_rate': 5.7245080500894455e-05, 'epoch': 0.0}
{'loss': 5.9219, 'learning_rate': 6.082289803220036e-05, 'epoch': 0.0}
{'loss': 6.0, 'learning_rate': 6.440071556350627e-05, 'epoch': 0.0}
{'loss': 5.7969, 'learning_rate': 6.797853309481217e-05, 'epoch': 0.0}
{'loss': 5.6562, 'learning_rate': 7.155635062611807e-05, 'epoch': 0.0}
{'loss': 5.5625, 'learning_rate': 7.513416815742398e-05, 'epoch': 0.0}
{'loss': 5.5156, 'learning_rate': 7.871198568872988e-05, 'epoch': 0.0}
{'loss': 5.4062, 'learning_rate': 8.228980322003578e-05, 'epoch': 0.0}
{'loss': 5.5, 'learning_rate': 8.586762075134168e-05, 'epoch': 0.0}
{'loss': 5.2344, 'learning_rate': 8.94454382826476e-05, 'epoch': 0.0}
{'loss': 5.0781, 'learning_rate': 9.302325581395348e-05, 'epoch': 0.0}
{'loss': 5.1562, 'learning_rate': 9.660107334525938e-05, 'epoch': 0.0}
{'loss': 4.9531, 'learning_rate': 0.0001001788908765653, 'epoch': 0.0}
{'loss': 4.8906, 'learning_rate': 0.0001037567084078712, 'epoch': 0.0}
{'loss': 5.0469, 'learning_rate': 0.0001073345259391771, 'epoch': 0.0}
{'loss': 4.875, 'learning_rate': 0.00011091234347048301, 'epoch': 0.0}
{'loss': 4.9219, 'learning_rate': 0.00011449016100178891, 'epoch': 0.0}
{'loss': 4.8594, 'learning_rate': 0.00011806797853309481, 'epoch': 0.0}
{'loss': 4.7188, 'learning_rate': 0.00012164579606440072, 'epoch': 0.0}
{'loss': 4.8594, 'learning_rate': 0.0001252236135957066, 'epoch': 0.0}
0%|▎| 35/18606
```

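The learning rates in these logs can be cross-checked against a linear-warmup plus cosine-decay schedule. The sketch below assumes the peak LR of 2e-3 and warmup ratio of 0.03 from the command earlier in the thread, the 18606 total steps shown in the progress bar, and a warmup length rounded up to a whole step; `expected_lr` is a hypothetical helper, not a function from transformers or LLaVA.

```python
import math


def expected_lr(step, peak_lr=2e-3, total_steps=18606, warmup_ratio=0.03):
    """Learning rate after `step` optimizer steps: linear warmup, then cosine decay."""
    # Assumption: warmup length is the ratio times total steps, rounded up.
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear ramp from 0 up to peak_lr.
        return peak_lr * step / max(1, warmup_steps)
    # Half-cosine decay from peak_lr down to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (1, 10, 35):
    print(step, expected_lr(step))
```

Under these assumptions the values for steps 1, 10, and 35 agree with the logged learning rates, so a healthy run should still be ramping up linearly at step 35 and should only start decaying after roughly 3% of the steps.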
haotian-liu commented 1 year ago

Also, please check the transformers version:

    "deepspeed==0.9.5",
    "peft==0.4.0",
    "transformers==4.31.0",
    "accelerate==0.21.0",
    "bitsandbytes==0.41.0",
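
One way to compare an environment against these pins is a small script like the sketch below. It uses only the standard library's `importlib.metadata`; the `PINNED` dict simply restates the list above, and the helper names are illustrative rather than part of the LLaVA repository.

```python
from importlib import metadata

# Versions listed in the comment above.
PINNED = {
    "deepspeed": "0.9.5",
    "peft": "0.4.0",
    "transformers": "4.31.0",
    "accelerate": "0.21.0",
    "bitsandbytes": "0.41.0",
}


def installed_version(package):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pinned, lookup=installed_version):
    """Map each package whose version differs from the pin to (wanted, found)."""
    mismatches = {}
    for name, wanted in pinned.items():
        found = lookup(name)
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(PINNED).items():
        print(f"{name}: expected {wanted}, found {found}")
```

An empty result means the environment matches the listed versions; any printed line points at a package worth reinstalling before debugging further.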
tfzhou commented 1 year ago

Thanks @haotian-liu.

> One thing is to make sure that the CUDA version of PyTorch and the nvcc used when you compile flash-attention are the same.

I am pretty sure that different versions are used in my setup. I will try to fix this and let you know.

tfzhou commented 1 year ago

By the way, after switching to train.py, training works as expected and LR decay is no longer an issue.

haotian-liu commented 1 year ago

@tfzhou I see. The only drawback of using train.py is that it is slower and uses more memory, which becomes more noticeable when you switch to the finetuning stage.

tfzhou commented 1 year ago

After recompiling flash-attn with a matching nvcc, the issue is fixed. Thanks @haotian-liu

wizyoung commented 1 year ago

@haotian-liu Can you post your full train log in the pre-training stage for reference?