NVlabs / VILA

VILA - a multi-image visual language model with training, inference and evaluation recipe, deployable from cloud to edge (Jetson Orin and laptops)

Problem training on zero2.json #82

Open Davidup1 opened 5 days ago

Davidup1 commented 5 days ago

Thank you for the great work! I ran into a problem when training with zero2.json: the loss and learning rate are both 0.0, and I'm not sure how to fix it.

yaolug commented 5 days ago

Could you provide your full environment and command line?

Davidup1 commented 5 days ago

I set up the environment following the README, and my script is as follows:

#!/bin/bash

export NCCL_P2P_DISABLE="1"
export NCCL_IB_DISABLE="1"

deepspeed --master_port=$((RANDOM + 10000)) --include localhost:0,1,2,3 /home/lyj/data/VILA/llava/train/train_mem.py \
    --deepspeed /home/lyj/data/VILA/scripts/zero2.json \
    --model_name_or_path /home/lyj/data/data/VILA1.5-3b/ \
    --version v1 \
    --data_mixture levir_cc \
    --vision_tower google/siglip-so400m-patch14-384 \
    --mm_vision_select_feature cls_patch \
    --mm_projector mlp_downsample \
    --tune_vision_tower False \
    --tune_mm_projector True \
    --tune_cc_projector True \
    --tune_single_projector True \
    --tune_language_model False \
    --mm_vision_select_layer -2 \
    --vision_resolution 576 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --chg True \
    --chg_type Chg2Cap \
    --from_origin True \
    --cc_n_layers 3 \
    --cc_head 8 \
    --cc_dropout 0.1 \
    --bf16 True \
    --output_dir /home/lyj/data/data/checkpoints_1/ \
    --num_train_epochs 10 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 0.5 \
    --save_total_limit 5 \
    --learning_rate 1e-4 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 4096 \
    --gradient_checkpointing True \
    --dataloader_num_workers 16 \
    --lazy_preprocess True \
    --vflan_no_system_prompt True \
    --report_to wandb

The script works well with zero3.json but fails with zero2.json.

I added some modules and settings to the model structure for my specific task. The main modification is that I expanded the structure of the mm_projector.
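As a quick sanity check for this symptom, here is a minimal sketch (not VILA code; model stands for whatever train_mem.py builds) that counts trainable parameters before the optimizer is created. If the trainable count is 0, for example because the expanded projector modules never get requires_grad=True under this combination of tune flags, the optimizer has nothing to update and the logged loss/lr can sit at 0.0.

# Minimal sketch: verify the expanded projector is registered and trainable.
def summarize_trainable(model):
    trainable, frozen = 0, 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable += param.numel()
        else:
            frozen += param.numel()
        if "mm_projector" in name:
            # Print the projector entries to confirm the added layers exist.
            print(f"{name}: requires_grad={param.requires_grad}, shape={tuple(param.shape)}")
    print(f"trainable params: {trainable:,} | frozen params: {frozen:,}")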

I don't know the exact cause, but after changing how the pretrained parameters are loaded, the model now trains normally under zero2.json. I am still curious about the mechanism that triggers the problem.
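To illustrate the kind of change meant by "changing how the pretrained parameters are loaded", a simplified sketch is below: loading the projector weights explicitly into the full, unpartitioned module before handing the model to DeepSpeed, instead of relying on the generic checkpoint loading to populate the added layers. The checkpoint path and attribute names are placeholders, not the actual VILA code.

import torch

# Hypothetical sketch: explicitly load projector weights on CPU before
# DeepSpeed initialization. strict=False leaves the newly added
# (randomly initialized) layers untouched when they have no counterpart
# in the checkpoint.
def load_projector_weights(model, ckpt_path="mm_projector.bin"):
    state = torch.load(ckpt_path, map_location="cpu")
    missing, unexpected = model.get_model().mm_projector.load_state_dict(state, strict=False)
    print("missing keys:", missing)
    print("unexpected keys:", unexpected)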

Lyken17 commented 13 hours ago

This is usually caused by specific transformers and deepspeed versions. Could you run environment_setup.sh to reinstall the env?
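To compare the installed versions against the ones pinned by environment_setup.sh, a quick check (minimal sketch) is:

# Print the installed versions for comparison with the pinned ones.
import torch
import transformers
import deepspeed

print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("deepspeed:", deepspeed.__version__)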