haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0

[Usage] OOM when using single A100-40G x8 node (Visual Instruction Tuning) #934

Open zilunzhang opened 10 months ago

zilunzhang commented 10 months ago

Hi Haotian,

OOM happened when I ran "finetune.sh" from scripts/v1_5. I used a single node with A100-40G x8, without NVLink, to fine-tune the 7B LLaVA-1.5.

The estimated training time is ~24 hours with the default settings. The slowdown is understandable (compared with ~10 hours on A100-40G x8 with NVLink). However, after around 100 steps, an OOM occurred.

The warning below pops up frequently during training.

pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time

Since fine-tuning the 7B model on A100-40G x8 with NVLink worked on your machine, I wonder whether the missing NVLink is what causes this warning and the OOM. For example, maybe the cache can be drained more quickly over NVLink than over PCIe, which would reduce the memory pressure.

Have you tested fine-tuning the 7B model on a machine with A100-40G x8 but without NVLink?
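
In case it is useful context, here is a minimal, untested sketch of how I am thinking about adding the get_accelerator().empty_cache() call that the warning suggests, via a transformers TrainerCallback. The EmptyCacheCallback name and the every_n_steps argument are my own and not part of the LLaVA code.

from transformers import TrainerCallback
from deepspeed.accelerator import get_accelerator


class EmptyCacheCallback(TrainerCallback):
    """Flush the allocator cache at a fixed step interval on every rank."""

    def __init__(self, every_n_steps: int = 1):
        self.every_n_steps = every_n_steps

    def on_step_end(self, args, state, control, **kwargs):
        # Flushing on all ranks at the same point in the loop keeps the
        # flushes synchronized, as the DeepSpeed warning recommends.
        if state.global_step % self.every_n_steps == 0:
            get_accelerator().empty_cache()
        return control


# Usage, e.g. in llava/train/train.py after the trainer is created:
#   trainer.add_callback(EmptyCacheCallback(every_n_steps=1))

On CUDA devices this should be equivalent to calling torch.cuda.empty_cache(); I have not verified whether it makes the cache-flush warning or the OOM go away.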

The training script is attached below.

#!/bin/bash

deepspeed llava/train/train_mem.py \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path ./checkpoints/vicuna-7b-v1.5 \
    --version v1 \
    --data_path ./playground/data/llava_v1_5_mix665k.json \
    --image_folder /data/tmp/shz/data/dataset \
    --vision_tower ./checkpoints/clip-vit-large-patch14-336 \
    --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-7b-pretrain/mm_projector.bin \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir ./checkpoints/llava-v1.5-7b \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 5000 \
    --save_total_limit 5 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb

Thanks,

Zilun

FaltingsA commented 7 months ago

@zilunzhang Have you solved this problem?