chunmeifeng / SPRC

【ICLR 2024, Spotlight】Sentence-level Prompts Benefit Composed Image Retrieval



Prerequisites

The following commands will create a local Anaconda environment with the necessary packages installed.

conda create -n cir_sprc -y python=3.9
conda activate cir_sprc
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
pip install -r requirements.txt
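
As an optional sanity check after installation, the pinned versions above can be compared with what is actually installed. The following is a minimal sketch, not part of the repository: the helper name `check_versions` and its `lookup` parameter are hypothetical, and it only reads package metadata without importing torch.

```python
from importlib import metadata

# Versions pinned by the install commands above.
EXPECTED = {"torch": "2.0.1", "torchvision": "0.15.2", "torchaudio": "2.0.2"}


def check_versions(expected=EXPECTED, lookup=metadata.version):
    """Return {package: (expected, found)} for each mismatch; found is None if missing."""
    problems = {}
    for name, wanted in expected.items():
        try:
            found = lookup(name)
        except metadata.PackageNotFoundError:
            found = None
        # Local builds may report e.g. "2.0.1+cu117"; compare the part before "+".
        if found is None or found.split("+")[0] != wanted:
            problems[name] = (wanted, found)
    return problems


if __name__ == "__main__":
    for name, (wanted, found) in check_versions().items():
        print(f"{name}: expected {wanted}, found {found}")
```

Running the script inside the activated `cir_sprc` environment prints one line per mismatched or missing package and nothing when all three pins match.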

Data Preparation

To work properly with the codebase, the FashionIQ and CIRR datasets should have the following structure:

project_base_path
└───  fashionIQ_dataset
      └─── captions
            | cap.dress.test.json
            | cap.dress.train.json
            | cap.dress.val.json
            | ...

      └───  images
            | B00006M009.jpg
            | B00006M00B.jpg
            | B00006M6IH.jpg
            | ...

      └─── image_splits
            | split.dress.test.json
            | split.dress.train.json
            | split.dress.val.json
            | ...

└───  cirr_dataset  
       └─── train
            └─── 0
                | train-10108-0-img0.png
                | train-10108-0-img1.png
                | train-10108-1-img0.png
                | ...

            └─── 1
                | train-10056-0-img0.png
                | train-10056-0-img1.png
                | train-10056-1-img0.png
                | ...

            ...

       └─── dev
            | dev-0-0-img0.png
            | dev-0-0-img1.png
            | dev-0-1-img0.png
            | ...

       └─── test1
            | test1-0-0-img0.png
            | test1-0-0-img1.png
            | test1-0-1-img0.png 
            | ...

       └─── cirr
            └─── captions
                | cap.rc2.test1.json
                | cap.rc2.train.json
                | cap.rc2.val.json

            └─── image_splits
                | split.rc2.test1.json
                | split.rc2.train.json
                | split.rc2.val.json
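
Before training, it may help to confirm that the layout above is in place. The sketch below is not part of the repository; the function name `missing_entries` is hypothetical, and it only checks the directories and the files that are named explicitly in the tree (entries elided with `...` are not checked).

```python
from pathlib import Path

# Directories from the tree above, relative to project_base_path.
REQUIRED_DIRS = [
    "fashionIQ_dataset/captions",
    "fashionIQ_dataset/images",
    "fashionIQ_dataset/image_splits",
    "cirr_dataset/train",
    "cirr_dataset/dev",
    "cirr_dataset/test1",
    "cirr_dataset/cirr/captions",
    "cirr_dataset/cirr/image_splits",
]

# Annotation files that the tree lists by name.
REQUIRED_FILES = [
    f"fashionIQ_dataset/captions/cap.dress.{split}.json"
    for split in ("test", "train", "val")
] + [
    f"fashionIQ_dataset/image_splits/split.dress.{split}.json"
    for split in ("test", "train", "val")
] + [
    f"cirr_dataset/cirr/captions/cap.rc2.{split}.json"
    for split in ("test1", "train", "val")
] + [
    f"cirr_dataset/cirr/image_splits/split.rc2.{split}.json"
    for split in ("test1", "train", "val")
]


def missing_entries(base):
    """Return the relative paths from the lists above that do not exist under base."""
    base = Path(base)
    missing = [d for d in REQUIRED_DIRS if not (base / d).is_dir()]
    missing += [f for f in REQUIRED_FILES if not (base / f).is_file()]
    return missing


if __name__ == "__main__":
    # "." assumes the script is run from project_base_path.
    for entry in missing_entries("."):
        print("missing:", entry)
```

An empty result means every directory and named file was found. Image files are not checked individually.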

Training

python src/blip_fine_tune_2.py \
   --dataset {'CIRR' or 'FashionIQ'} \
   --blip-model-name 'blip2_cir_align_prompt' \
   --num-epochs {'50' for CIRR, '30' for FashionIQ} \
   --num-workers 4 \
   --learning-rate {'1e-5' for CIRR, '2e-5' for FashionIQ} \
   --batch-size 128 \
   --transform targetpad \
   --target-ratio 1.25  \
   --save-training \
   --save-best \
   --validation-frequency 1 
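
The braces in the command above are placeholders, so each dataset needs its own epoch count and learning rate. The sketch below shows one way to fill them in programmatically. It is not part of the repository: the function name `training_command` is hypothetical, and the flag names and values are copied from the command above without checking them against the script's argument parser.

```python
import shlex

# Per-dataset values stated in the training command above.
HPARAMS = {
    "CIRR": {"num_epochs": "50", "learning_rate": "1e-5"},
    "FashionIQ": {"num_epochs": "30", "learning_rate": "2e-5"},
}


def training_command(dataset):
    """Build the argument list for src/blip_fine_tune_2.py for one dataset."""
    hp = HPARAMS[dataset]  # raises KeyError for any other dataset name
    return [
        "python", "src/blip_fine_tune_2.py",
        "--dataset", dataset,
        "--blip-model-name", "blip2_cir_align_prompt",
        "--num-epochs", hp["num_epochs"],
        "--num-workers", "4",
        "--learning-rate", hp["learning_rate"],
        "--batch-size", "128",
        "--transform", "targetpad",
        "--target-ratio", "1.25",
        "--save-training",
        "--save-best",
        "--validation-frequency", "1",
    ]


if __name__ == "__main__":
    # Print a copy-pasteable shell command for each dataset.
    for name in HPARAMS:
        print(shlex.join(training_command(name)))
```

The returned list can also be passed directly to `subprocess.run` from the repository root.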

Evaluation

python src/blip_validate.py \
   --dataset {'CIRR' or 'FashionIQ'} \
   --blip-model-name {trained model name} \
   --model-path {path to the trained model checkpoint}

CIRR Testing

python src/cirr_test_submission.py \
   --blip-model-name {trained model name} \
   --model-path {path to the trained model checkpoint}

Checkpoints

OneDrive: sprc_cirr.pt, sprc_fiq.pt

BaiduCloud: https://pan.baidu.com/s/18196NRV0Cdbn5uPc3LIgwg, password: t1at

Models based on ViT-L: sprc_cirr_vitl.pt, sprc_fiq_vitl.pt

Todo

Code and pre-trained weights for the rerank model

Acknowledgement

Our implementation is based on CLIP4Cir and LAVIS.