haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0

[Question] Can not reproduce the Acc. on Science QA datasets #107

Open Unrealluver opened 1 year ago

Unrealluver commented 1 year ago

Question

Thanks for the script for SQA finetuning. After fine-tuning on SQA, I find that the result on the SQA test set (Total: 4241, Correct: 3670, Accuracy: 86.54%) is not as good as the result reported in the paper (Accuracy: 90.92%). Could you please share some advice on how to fix this mismatch?
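To put the gap in concrete terms, the arithmetic can be sketched as follows. It uses only the figures quoted above (4241 test questions, 3670 correct, 90.92% reported); the variable and function names are hypothetical and are not part of the LLaVA codebase.

```python
# Figures quoted in this issue.
TOTAL = 4241          # size of the SQA test set
CORRECT = 3670        # correct answers in the reproduced run
REPORTED_PCT = 90.92  # accuracy reported in the paper


def accuracy_pct(correct: int, total: int) -> float:
    """Return accuracy as a percentage."""
    return 100.0 * correct / total


reproduced_pct = accuracy_pct(CORRECT, TOTAL)

# Gap between the paper and the reproduction, in percentage points.
gap_points = REPORTED_PCT - reproduced_pct

# Approximate number of correct answers implied by the reported accuracy.
expected_correct = round(REPORTED_PCT / 100.0 * TOTAL)
missing = expected_correct - CORRECT

print(f"reproduced accuracy: {reproduced_pct:.2f}%")
print(f"gap: {gap_points:.2f} points, about {missing} questions")
```

In other words, the reproduced run is roughly 4.4 points short, which corresponds to a little under two hundred test questions.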

Here are my running scripts:

torchrun --nnodes=1 --nproc_per_node=8 --master_port=25001 \
    llava/train/train_mem.py \
    --model_name_or_path /share/project/lianghuizhu/vicuna-13b-v0 \
    --data_path /share/project/lianghuizhu/science_qa/ScienceQA/data/scienceqa/llava_train_QCM-LEPA.json \
    --image_folder /share/project/lianghuizhu/science_qa/ScienceQA/data/scienceqa/images/train \
    --vision_tower /home/zhulianghui/ProjectC_ChatGPT/llava/reference/clip-vit-large-patch14 \
    --pretrain_mm_mlp_adapter /home/zhulianghui/ProjectC_ChatGPT/llava/reference/LLaVA-13b-pretrain-projector-v0-CC3M-595K-original_caption-no_im_token.bin \
    --mm_vision_select_layer -2 \
    --output_dir ./checkpoints/llava-13b-finetune-8x40g-a100-sqa-no_im_start_end_token \
    --num_train_epochs 12 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 5000 \
    --save_total_limit 3 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --bf16 True \
    --fsdp "full_shard auto_wrap offload" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --run_name "llava-13b-finetune-8x40g-a100-sqa-no_im_start_end_token" \
    --lazy_preprocess True \
    --report_to mlflow
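For anyone comparing runs, the global batch size implied by this command can be worked out as below. This is a minimal sketch: the three batch-related values and the epoch count come from the command above, while `num_examples` is a hypothetical parameter, since the size of the training JSON is not stated in this thread.

```python
import math

# Values taken from the command above.
NPROC_PER_NODE = 8               # --nproc_per_node
PER_DEVICE_TRAIN_BATCH_SIZE = 4  # --per_device_train_batch_size
GRADIENT_ACCUMULATION_STEPS = 1  # --gradient_accumulation_steps
NUM_TRAIN_EPOCHS = 12            # --num_train_epochs


def global_batch_size(nproc: int, per_device: int, grad_accum: int) -> int:
    """Number of samples consumed per optimizer step across all GPUs."""
    return nproc * per_device * grad_accum


def total_optimizer_steps(num_examples: int, batch: int, epochs: int) -> int:
    """Rough step count, assuming the last partial batch of each epoch is kept."""
    return math.ceil(num_examples / batch) * epochs


batch = global_batch_size(
    NPROC_PER_NODE, PER_DEVICE_TRAIN_BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS
)
print(f"global batch size: {batch}")
```

If a reference run used a different GPU count or accumulation setting, the global batch size would differ, and the learning-rate schedule would then play out over a different number of steps.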

The --model_name_or_path /share/project/lianghuizhu/vicuna-13b-v0 is the checkpoint obtained by applying the official Vicuna delta to LLaMA-13b.

--pretrain_mm_mlp_adapter /home/zhulianghui/ProjectC_ChatGPT/llava/reference/LLaVA-13b-pretrain-projector-v0-CC3M-595K-original_caption-no_im_token.bin is the projection layer provided in this repo that does not contain the im token.

Finally, I ran the multi-GPU generation scripts in this repo to generate and gather the results.

haotian-liu commented 1 year ago

Hi @Unrealluver, thank you for your feedback. I ran a test training locally and did notice this performance drop with the latest code release. However, when I re-ran the commit that I used to get the results during development, I was able to reproduce them. I am currently investigating this and will let you know soon.

Sorry about the confusion in the released code, and thank you again for the feedback.

Unrealluver commented 1 year ago

@haotian-liu Thanks, I am waiting for your further reply.

haotian-liu commented 1 year ago

Hi @Unrealluver, it has been fixed now. An index was not being updated, which caused issues for datasets with mixed image-text and text-only content. Please pull the latest code base and it should work now. Thanks!
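The thread does not show the actual change, so the following is only a toy illustration of how a stale index can misalign a mixed batch; it is not the LLaVA code. It assumes, hypothetically, that the batch carries one image-feature slot per sample (a placeholder for text-only samples), and all names are invented for the example.

```python
def pair_buggy(has_image_flags, image_feats):
    """Pair each sample with its image feature, forgetting to advance the index."""
    idx = 0
    paired = []
    for has_image in has_image_flags:
        if not has_image:
            paired.append(None)
            continue  # BUG: idx is not advanced past the placeholder slot
        paired.append(image_feats[idx])
        idx += 1
    return paired


def pair_fixed(has_image_flags, image_feats):
    """Same pairing, but the index advances for every sample."""
    paired = []
    for idx, has_image in enumerate(has_image_flags):
        paired.append(image_feats[idx] if has_image else None)
    return paired


# One text-only sample sits between two image samples.
flags = [True, False, True]
feats = ["img0", "placeholder1", "img2"]

print(pair_buggy(flags, feats))
print(pair_fixed(flags, feats))
```

In this toy setup, every image sample that follows a text-only sample is paired with the wrong feature in the buggy version, which is the sort of silent misalignment that would hurt a dataset like ScienceQA, where only some questions come with images.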

haotian-liu commented 1 year ago

Please also re-download this checkpoint, thanks. I used this checkpoint to verify the ScienceQA finetuning; I am not sure why I uploaded the wrong version. Sorry again for the confusion.