💬 CLIP-FlanT5: Multimodal Encoder-Decoder Language Model for VQAScore

Evaluating text-to-image generation using VQAScore with CLIP-FlanT5! This codebase contains the training code for CLIP-FlanT5.

[Project Page] [Code for evaluation] [Data] [Model Zoo]

Evaluating Text-to-Visual Generation with Image-to-Text Generation (arXiv) [Paper]
Zhiqiu Lin, Deepak Pathak, Baiqi Li, Jiayao Li, Xide Xia, Graham Neubig, Pengchuan Zhang*, Deva Ramanan*

Release

Usage and License Notices: The data and checkpoints are intended and licensed for research use only. They are also restricted to uses that follow the license agreements of LLaMA, Vicuna, FlanT5, and GPT-4. The dataset is released under CC BY-NC 4.0 (allowing only non-commercial use), and models trained on the dataset should not be used outside of research purposes.

Contents

Install

We follow LLaVA-1.5 for installation. If you have already installed the LLaVA environment, there is no need to create a new one.

  1. Clone this repository and navigate to the CLIP-FlanT5 folder

    git clone https://github.com/linzhiqiu/CLIP-FlanT5.git
    cd CLIP-FlanT5
  2. Install the package (if you do not already have the llava environment installed)

    conda create -n llava python=3.10 -y
    conda activate llava
    pip install --upgrade pip  # enable PEP 660 support
    pip install -e .
  3. Install additional packages for training (if you do not already have the llava environment installed)

    pip install -e ".[train]"
    pip install flash-attn --no-build-isolation
  4. Install huggingface_hub

    python -m pip install huggingface_hub

Upgrade to the latest code base

git pull
pip install -e .

CLIP-FlanT5 Weights

Please check out our Model Zoo for public CLIP-FlanT5 checkpoints.
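
If you installed huggingface_hub in step 4 above, a checkpoint can be pulled directly from the Hugging Face Hub. A minimal sketch, assuming the XXL checkpoint is published under an ID such as zhiqiulin/clip-flant5-xxl (check the Model Zoo for the exact repository name):

    # Hypothetical repository ID -- replace with the exact name listed in the Model Zoo
    huggingface-cli download zhiqiulin/clip-flant5-xxl --local-dir checkpoints/clip-flant5-xxl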

Train

CLIP-FlanT5 training consists of two stages: (1) feature alignment stage: use the LLaVA-1.5 558K subset of the LAION-CC-SBU dataset to connect a frozen pretrained vision encoder to a frozen FlanT5; (2) VQA finetuning stage: use 150K LLaVA chat data and around 515K VQA data from academic-oriented tasks.

CLIP-FlanT5 is trained on 8 A100 GPUs with 80GB memory. To train on fewer GPUs, reduce per_device_train_batch_size and increase gradient_accumulation_steps accordingly, always keeping the global batch size the same: per_device_train_batch_size x gradient_accumulation_steps x num_gpus.
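
For example, here is a minimal sketch of how the two flags could be rescaled when going from 8 to 4 GPUs while keeping the stage-2 global batch size of 96 (the flag names follow the Hugging Face Trainer arguments used in LLaVA-style training scripts; the per-device values are illustrative, not necessarily the defaults in our scripts):

    # 8 GPUs: 4 (per device) x 3 (accumulation) x 8 (GPUs) = 96
    --per_device_train_batch_size 4 --gradient_accumulation_steps 3
    # 4 GPUs: double the accumulation steps to keep the product at 96
    --per_device_train_batch_size 4 --gradient_accumulation_steps 6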

Hyperparameters

We use a similar set of hyperparameters to LLaVA-1.5. The hyperparameters used in pretraining and finetuning are provided below.

  1. Pretraining

| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
| --- | --- | --- | --- | --- | --- |
| CLIP-FlanT5 | 256 | 1e-2 | 1 | 2048 | 0 |

  2. Finetuning

| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
| --- | --- | --- | --- | --- | --- |
| CLIP-FlanT5 | 96 | 2e-5 | 1 | 2048 | 0 |
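
These values are set inside the provided training scripts rather than passed manually. As a rough guide, the finetuning row above corresponds to trainer flags along the lines of the sketch below (the exact values and flag set live in clip-flant5-xxl.sh and may differ):

    --learning_rate 2e-5 \
    --num_train_epochs 1 \
    --model_max_length 2048 \
    --weight_decay 0.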

Download FlanT5 checkpoints (automatically)

The base model, FlanT5 (a strong QA model developed by Google), will be downloaded automatically when you run our provided training scripts. No action is needed.
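
If your training nodes do not have internet access, you can optionally pre-fetch the base model into the local Hugging Face cache from a login node first. A minimal sketch, assuming the XXL variant of FlanT5:

    # Optional: warm the Hugging Face cache with the base FlanT5 weights
    huggingface-cli download google/flan-t5-xxl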

Pretrain (feature alignment)

You can download the 558K subset of the LAION-CC-SBU dataset with BLIP captions used in the LLaVA-1.5 paper here, and unzip/put the files under "playground/data/LLaVA-Pretrain". The final folder structure should look like:

playground/data/
β”œβ”€β”€ LLaVA-Pretrain
│   β”œβ”€β”€ blip_laion_cc_sbu_558k.json
│   └── images
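
A minimal download sketch, assuming the data is fetched from the LLaVA-Pretrain dataset on the Hugging Face Hub (the 558K subset released with LLaVA-1.5; verify the exact source against the link above):

    huggingface-cli download liuhaotian/LLaVA-Pretrain --repo-type dataset \
        --local-dir playground/data/LLaVA-Pretrain
    cd playground/data/LLaVA-Pretrain
    # The image archive may already contain the expected subfolders; check after unzipping
    unzip images.zip -d images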

Pretraining takes around 5 hours for CLIP-FlanT5-XXL on 8x A100 (80G) at an image resolution of 336px. It takes around 2 hours for LLaVA-v1.5-7B.

Training script with DeepSpeed ZeRO-2: clip-flant5-xxl-stage-1.sh.

If you are using a Slurm environment, you can also use the Slurm script provided in clip-flant5-xxl-stage-1.slurm (change the default partition name to your own: #SBATCH --partition={your_own_partition}).
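
A minimal Slurm sketch (the partition name is a placeholder):

    # Inside clip-flant5-xxl-stage-1.slurm, point the job at your own partition:
    #SBATCH --partition=your_partition_name
    # Then submit the job from the repository root:
    sbatch clip-flant5-xxl-stage-1.slurm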

Finetune (training for VQA)

  1. Prepare data

We flattened the LLaVA-1.5 mixture of data (please download llava_v1_5_mix665k_flattened_multi_turn.json), and also download the images from the constituent datasets (COCO train2017, GQA, OCR-VQA, TextVQA, and Visual Genome).

After downloading all of them, organize the data as follows in ./playground/data:

β”œβ”€β”€ coco
│   └── train2017
β”œβ”€β”€ gqa
│   └── images
β”œβ”€β”€ ocr_vqa
│   └── images
β”œβ”€β”€ textvqa
│   └── train_images
└── vg
    β”œβ”€β”€ VG_100K
    └── VG_100K_2
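
A minimal sketch for creating this layout, using COCO as an example (the URL is the standard COCO 2017 training split; the other datasets are downloaded analogously from their official sources, as described in the LLaVA-1.5 data instructions):

    cd playground/data
    mkdir -p coco gqa ocr_vqa textvqa vg
    # Example: COCO train2017 images (unzips into coco/train2017)
    wget http://images.cocodataset.org/zips/train2017.zip
    unzip train2017.zip -d coco && rm train2017.zip
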
  2. Start training!

You may download the stage-1 pretrained projectors from the Model Zoo.

Stage-2 VQA training takes around 80 hours for CLIP-FlanT5-XXL on 8x A100 (80G), due to the increased resolution (336px) and the flattening of multi-turn conversations into single-turn examples. It takes around 60 hours for CLIP-FlanT5-XL on 8x A100 (40G).

Training script with DeepSpeed ZeRO-3: clip-flant5-xxl.sh. Optionally, if you use Slurm, you may use clip-flant5-xxl.slurm (make sure to change the default Slurm partition).
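
A minimal launch sketch (the script paths are relative to wherever the training scripts live in your checkout; adjust them accordingly):

    # Local 8-GPU run
    bash clip-flant5-xxl.sh
    # Or on a Slurm cluster, after editing the default partition in the .slurm file
    sbatch clip-flant5-xxl.slurm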

New options to note:

Evaluation

Please refer to the t2v_metrics repo, which contains the evaluation code for VQAScore using CLIP-FlanT5.

Citation

If you find this codebase useful for your research and applications, please cite it using this BibTeX:

@article{lin2024evaluating,
    title={Evaluating Text-to-Visual Generation with Image-to-Text Generation},
    author={Lin, Zhiqiu and Pathak, Deepak and Li, Baiqi and Li, Jiayao and Xia, Xide and Neubig, Graham and Zhang, Pengchuan and Ramanan, Deva},
    journal={arXiv preprint arXiv:2404.01291},
    year={2024}
}

Acknowledgement