TideDra / VL-RLHF

A RLHF Infrastructure for Vision-Language Models
Apache License 2.0
dpo llm lmm mllm rlhf vlm

VL-RLHF: A RLHF Infrastructure for Vision-Language Models





🎉 News

Research and development on preference learning for Vision-Language Models (VLMs, also called LVLMs or MLLMs) is difficult because the VLM community currently has no unified model architecture. State-of-the-art VLMs such as LLaVA, Qwen-VL and InternLM-XComposer are also implemented in different styles, which makes it hard to support them in a single training framework. VL-RLHF addresses this by abstracting these VLMs behind a common interface within one framework. Its features include:

🤖 Supported Models

βš™οΈ Supported Methods

πŸ› οΈ Installation

To install from source code (convenient for running the training and evaluation scripts), please run the following commands:

git clone https://github.com/TideDra/VL-RLHF.git
cd VL-RLHF
pip install -e .

We recommend installing FlashAttention for more efficient training and inference:

pip install flash-attn==2.5.8 --no-build-isolation
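To confirm that FlashAttention is visible from the Python environment you train in, a quick check like the following can help. It is only a sketch: it tests whether the `flash_attn` package (the import name of the `flash-attn` pip package) can be located, and it does not verify CUDA or GPU compatibility. Note that the example training command below passes `--use_flash_attention_2 False`; presumably that flag has to be set to `True` for FlashAttention to actually be used.

```python
import importlib.util


def flash_attn_available() -> bool:
    """Return True if the `flash_attn` package can be located by the import system.

    This only checks that the package is installed; it does not import it,
    so CUDA/GPU compatibility problems are not detected here.
    """
    return importlib.util.find_spec("flash_attn") is not None


if __name__ == "__main__":
    print("flash-attn installed:", flash_attn_available())
```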

🚀 Training

Training Scripts

You can run the following command to launch DPO training of Qwen-VL-Chat on the VLFeedback dataset:

# Model weights should exist in ckpts/Qwen-VL-Chat
bash scripts/dpo_qwenvl.sh

Or run the Python file directly:

accelerate launch --config_file accelerate_config/zero2.yaml --num_processes 8 \
        src/vlrlhf/dpo.py \
        --model_name_or_path ckpts/Qwen-VL-Chat \
        --output_dir ckpts/Qwen-VL-Chat-dpo/ \
        --dataset_name VLFeedback \
        --data_ratio 1.0 \
        --freeze_vision_tower True \
        --use_flash_attention_2 False \
        --use_lora True \
        --lora_r 64 \
        --lora_alpha 16 \
        --lora_dropout 0.05 \
        --lora_target_modules auto \
        --lora_bias "none" \
        --per_device_train_batch_size 4 \
        --per_device_eval_batch_size 4 \
        --gradient_accumulation_steps 8 \
        --num_train_epochs 1 \
        --adam_beta1 0.9 \
        --adam_beta2 0.98 \
        --adam_epsilon 1e-6 \
        --learning_rate 1e-5 \
        --weight_decay 0.05 \
        --warmup_ratio 0.1 \
        --lr_scheduler_type "cosine" \
        --gradient_checkpointing True \
        --bf16 True \
        --tf32 True \
        --remove_unused_columns False \
        --beta 0.1 \
        --max_length 1024 \
        --max_prompt_length 512 \
        --max_target_length 512 \
        --eval_strategy "steps" \
        --eval_steps 200 \
        --save_strategy "steps" \
        --save_steps 100 \
        --save_total_limit 10 \
        --logging_first_step False \
        --logging_steps 10 \
        --report_to wandb \
        --run_name  "bs256_lr1e-5" \
        --project_name "VL-RLHF" \
        --group_name "Qwen-VL-Chat-dpo"
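The run name `bs256_lr1e-5` reflects the effective (global) batch size implied by the flags above. As a sanity check when changing the number of GPUs or accumulation steps, the arithmetic can be sketched as follows. The LoRA line assumes the standard `lora_alpha / lora_r` scaling convention used by Hugging Face PEFT (without rank-stabilized LoRA); whether VL-RLHF applies exactly this scaling is an assumption.

```python
# Values copied from the accelerate launch command above.
NUM_PROCESSES = 8                # --num_processes
PER_DEVICE_TRAIN_BATCH_SIZE = 4  # --per_device_train_batch_size
GRADIENT_ACCUMULATION_STEPS = 8  # --gradient_accumulation_steps
LORA_R = 64                      # --lora_r
LORA_ALPHA = 16                  # --lora_alpha


def effective_batch_size(num_processes: int, per_device: int, accum_steps: int) -> int:
    """Number of training samples that contribute to one optimizer step."""
    return num_processes * per_device * accum_steps


def lora_scaling(lora_alpha: float, lora_r: int) -> float:
    """Scale applied to the low-rank update under the (assumed) alpha / r convention."""
    return lora_alpha / lora_r


if __name__ == "__main__":
    print(effective_batch_size(NUM_PROCESSES, PER_DEVICE_TRAIN_BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS))  # 256
    print(lora_scaling(LORA_ALPHA, LORA_R))  # 0.25
```

For example, when training on 4 GPUs instead of 8, doubling `--gradient_accumulation_steps` to 16 keeps the effective batch size at 4 × 4 × 16 = 256.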

To train other models or use other methods, refer to the corresponding scripts in the scripts/ directory.

Please refer to arguments.md for a detailed explanation of each argument used in the scripts.
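The `--beta 0.1` flag in the DPO command above is the temperature of the DPO objective. For orientation only, here is a minimal scalar sketch of the standard per-example DPO loss from the original DPO paper (Rafailov et al., 2023). It is not VL-RLHF's implementation; the inputs are assumed to be summed log-probabilities of each response under the trained policy and a frozen reference model.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
) -> float:
    """Standard DPO loss: -log sigmoid(beta * (chosen log-ratio - rejected log-ratio))."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(m)) == softplus(-m)
    return softplus(-margin)


if __name__ == "__main__":
    # Policy identical to the reference: the loss equals log(2).
    print(round(dpo_loss(-2.0, -2.0, -2.0, -2.0), 4))  # 0.6931
    # Policy raises the chosen response and lowers the rejected one: the loss drops.
    print(round(dpo_loss(-1.0, -3.0, -2.0, -2.0), 4))  # 0.5981
```

A larger beta makes the loss more sensitive to the log-ratio margin; in the DPO formulation it corresponds to a stronger implicit constraint keeping the policy close to the reference model.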

Data Preparation

VL-RLHF uses three arguments when processing a dataset; they appear in all of the example training scripts. Please make sure they are set properly in the script before running it:

Customized Dataset

For methods that require paired comparison data, e.g. DPO, DDPO and KTO (paired), please prepare your JSON data in the following format:

[
    {
        "image":"example.jpg",
        "prompt":"Describe this image in detail.",
        "chosen":"This is a cat.",
        "rejected":"This is a dog."
    },
    ...
]

Then set --dataset_name to plain_dpo in the training command.
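Before launching training, it can be worth validating the comparison file. This is a minimal sketch that assumes only the four keys shown in the example are required and that each value is a string; the helper names and the output filename `my_dpo_data.json` are hypothetical and not part of VL-RLHF.

```python
import json
from pathlib import Path

# Keys shown in the plain_dpo example above.
REQUIRED_KEYS = ("image", "prompt", "chosen", "rejected")


def validate_dpo_records(records: list) -> None:
    """Raise ValueError if any record is missing a key or holds a non-string value."""
    for i, rec in enumerate(records):
        missing = [k for k in REQUIRED_KEYS if k not in rec]
        if missing:
            raise ValueError(f"record {i} is missing keys: {missing}")
        for key in REQUIRED_KEYS:
            if not isinstance(rec[key], str):
                raise ValueError(f"record {i}: '{key}' must be a string")


def write_dpo_json(records: list, path: str) -> None:
    """Validate the records, then write them as a UTF-8 JSON array."""
    validate_dpo_records(records)
    Path(path).write_text(json.dumps(records, indent=4, ensure_ascii=False), encoding="utf-8")


if __name__ == "__main__":
    data = [{
        "image": "example.jpg",
        "prompt": "Describe this image in detail.",
        "chosen": "This is a cat.",
        "rejected": "This is a dog.",
    }]
    write_dpo_json(data, "my_dpo_data.json")  # hypothetical output path
```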

For SFT, please prepare your conversation data in the following format:

[
    {
        "image":"example.jpg",
        "conversations":[
            {
                "from": "user",
                "value": "<prompt>"
            },
            {
                "from": "assistant",
                "value": "<answer>"
            },
            ...
        ]
    },
    ...
]

Then set --dataset_name to vlquery_json in the training command.
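A similar check for SFT conversation data can be sketched as follows. It assumes, based only on the example above, that turns alternate between "user" and "assistant" starting with "user"; the actual requirements may be looser or stricter. The helper names are hypothetical.

```python
def make_sft_record(image: str, turns: list) -> dict:
    """Build one vlquery_json-style record from (user_text, assistant_text) pairs."""
    conversations = []
    for user_text, assistant_text in turns:
        conversations.append({"from": "user", "value": user_text})
        conversations.append({"from": "assistant", "value": assistant_text})
    return {"image": image, "conversations": conversations}


def validate_sft_record(rec: dict) -> None:
    """Raise ValueError unless the record matches the format above (alternation is assumed)."""
    if "image" not in rec or "conversations" not in rec:
        raise ValueError("record needs 'image' and 'conversations'")
    turns = rec["conversations"]
    if not turns:
        raise ValueError("'conversations' must not be empty")
    for i, turn in enumerate(turns):
        expected = "user" if i % 2 == 0 else "assistant"
        if turn.get("from") != expected:
            raise ValueError(f"turn {i}: expected from={expected!r}, got {turn.get('from')!r}")
        if not isinstance(turn.get("value"), str):
            raise ValueError(f"turn {i}: 'value' must be a string")


if __name__ == "__main__":
    record = make_sft_record("example.jpg", [("Describe this image in detail.", "This is a cat.")])
    validate_sft_record(record)
    print(len(record["conversations"]))  # 2
```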

Customized Model

You can add your own model to the VL-RLHF framework by implementing a few APIs. Please refer to CustomizedModel.md for details.

📊 Evaluation

VL-RLHF supports evaluating VLMs on popular multimodal benchmarks such as MME, MMVet, SEEDBench and MMBench. Please refer to the Evaluation Guide for details.

🥇 Performance

For reference, we report the performance of several models before and after DPO training on VLFeedback:

| Model | MMBench | MMVet | SEEDBench-Img | MMMU | MathVista |
|---|---|---|---|---|---|
| InternLM-Xcomposer2-VL-7b | 76.37 | 46.5 | 74.19 | 40.33 | 56.7 |
| InternLM-Xcomposer2-VL-7b-DPO | 78.18 | 49.7 | 75.18 | 39.67 | 56.6 |
| Qwen-VL-Chat | 56.53 | 48.5 | 59.63 | 35.67 | 35.6 |
| Qwen-VL-Chat-DPO | 57.56 | 49.1 | 60.67 | 37.89 | 35.6 |
| LLaVA-Next-Mistral-7b | 67.70 | 43.8 | 71.7 | 37.00 | 35.1 |
| LLaVA-Next-Mistral-7b-DPO | 68.30 | 44.2 | 71.7 | 36.89 | 36.2 |
| LLaVA-Next-Vicuna-7b | 62.71 | 38.2 | 68.17 | 34.00 | 31.3 |
| LLaVA-Next-Vicuna-7b-DPO | 64.52 | 44.1 | 69.75 | 33.11 | 32.0 |
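To read the table as per-benchmark changes, the deltas can be computed directly from the reported scores. This sketch only redoes that arithmetic on the numbers in the table; it adds no new results.

```python
BENCHMARKS = ["MMBench", "MMVet", "SEEDBench-Img", "MMMU", "MathVista"]

# (scores before DPO, scores after DPO), copied from the table above.
RESULTS = {
    "InternLM-Xcomposer2-VL-7b": ([76.37, 46.5, 74.19, 40.33, 56.7], [78.18, 49.7, 75.18, 39.67, 56.6]),
    "Qwen-VL-Chat": ([56.53, 48.5, 59.63, 35.67, 35.6], [57.56, 49.1, 60.67, 37.89, 35.6]),
    "LLaVA-Next-Mistral-7b": ([67.70, 43.8, 71.7, 37.00, 35.1], [68.30, 44.2, 71.7, 36.89, 36.2]),
    "LLaVA-Next-Vicuna-7b": ([62.71, 38.2, 68.17, 34.00, 31.3], [64.52, 44.1, 69.75, 33.11, 32.0]),
}


def dpo_deltas(before: list, after: list) -> dict:
    """Per-benchmark change (after minus before), rounded to two decimals."""
    return {name: round(a - b, 2) for name, b, a in zip(BENCHMARKS, before, after)}


if __name__ == "__main__":
    for model, (before, after) in RESULTS.items():
        deltas = dpo_deltas(before, after)
        improved = sum(d > 0 for d in deltas.values())
        print(f"{model}: {deltas} ({improved}/{len(BENCHMARKS)} improved)")
```

Laid out this way, it is also easy to see where DPO did not help; for example, MMMU decreases for three of the four models.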

❤️ References & Acknowledgements

We would like to express our gratitude to the following projects:

🎓 Citation

If you find this work helpful, please consider starring 🌟 this repo. Thanks for your support!

If you use VL-RLHF in your research, please use the following BibTeX entry.

@misc{vlrlhf,
  title = {VL-RLHF: A RLHF Infrastructure for Vision-Language Model},
  author = {Gongrui Zhang},
  howpublished = {\url{https://github.com/TideDra/VL-RLHF}},
  year = {2024}
}