Chiangsonw / CaLa

The official code of "CaLa: Complementary Association Learning for Augmenting Composed Image Retrieval"
MIT License

CaLa

CaLa (ACM SIGIR 2024) is a new composed image retrieval framework that exploits two complementary associations in the task. CaLa introduces TBIA (text-based image alignment) and CTR (complementary text reasoning) to augment composed image retrieval.

We highlight the contributions of this paper as follows:

• We present a new perspective on composed image retrieval: each annotated triplet is viewed as a graph node, and two complementary association clues are identified to enhance composed image retrieval.

• A hinge-based attention and a twin-attention-based visual compositor are proposed to effectively incorporate the new associations into network learning.

• Competitive performance on the CIRR and FashionIQ benchmarks. CaLa benefits several baselines with different backbones and architectures, indicating that it is a broadly beneficial module for composed image retrieval.

More details can be found in our paper: CaLa: Complementary Association Learning for Augmenting Composed Image Retrieval

This is the workflow of our CaLa framework.

News

Models and Weights

Usage

Prerequisites

The following commands will create a local Anaconda environment with the necessary packages installed.

conda create -n cala -y python=3.8
conda activate cala
conda install -y -c pytorch pytorch=1.11.0 torchvision=0.12.0
conda install -y -c anaconda pandas=1.4.2
pip install comet-ml==3.21.0
pip install git+https://github.com/openai/CLIP.git
pip install salesforce-lavis
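After creating the environment, one quick way to confirm that the pinned versions were actually installed is to query package metadata. The following is a minimal sketch using only the standard library; the helper `check_versions` and the `PINNED` table are hypothetical (not part of the CaLa repository), and the distribution names (`torch` for conda's `pytorch`, `comet_ml` for `comet-ml`) are assumptions about how these packages register themselves.

```python
import sys
from importlib import metadata
from typing import Dict, Optional, Tuple

# Versions pinned by the commands above. ASSUMPTION: conda's "pytorch" package
# registers as the "torch" distribution, and comet-ml is found as "comet_ml".
PINNED: Dict[str, str] = {
    "torch": "1.11.0",
    "torchvision": "0.12.0",
    "pandas": "1.4.2",
    "comet_ml": "3.21.0",
}


def check_versions(pinned: Dict[str, str] = PINNED) -> Dict[str, Tuple[str, Optional[str], bool]]:
    """Map each package to (pinned version, installed version or None, match flag)."""
    report: Dict[str, Tuple[str, Optional[str], bool]] = {}
    for name, wanted in pinned.items():
        try:
            installed: Optional[str] = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Ignore local version labels such as "+cu113" when comparing.
        matches = installed is not None and installed.split("+")[0] == wanted
        report[name] = (wanted, installed, matches)
    return report


if __name__ == "__main__":
    print(f"python {sys.version_info.major}.{sys.version_info.minor} (the environment above uses 3.8)")
    for name, (wanted, installed, ok) in check_versions().items():
        status = "OK" if ok else "MISMATCH"
        print(f"{status:8} {name}: pinned {wanted}, installed {installed}")
```

Missing packages are reported as `None` rather than raising, so the script can be run before and after the install steps.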

Data Preparation

To work with the codebase, the FashionIQ and CIRR datasets should be organized in the following structure:

project_base_path
└───  CaLa
      └─── src
            | blip_fine_tune.py
            | data_utils.py
            | utils.py
            | ...

└───  fashionIQ_dataset
      └─── captions
            | cap.dress.test.json
            | cap.dress.train.json
            | cap.dress.val.json
            | ...

      └───  images
            | B00006M009.jpg
            | B00006M00B.jpg
            | B00006M6IH.jpg
            | ...

      └─── image_splits
            | split.dress.test.json
            | split.dress.train.json
            | split.dress.val.json
            | ...

└───  cirr_dataset  
       └─── train
            └─── 0
                | train-10108-0-img0.png
                | train-10108-0-img1.png
                | train-10108-1-img0.png
                | ...

            └─── 1
                | train-10056-0-img0.png
                | train-10056-0-img1.png
                | train-10056-1-img0.png
                | ...

            ...

       └─── dev
            | dev-0-0-img0.png
            | dev-0-0-img1.png
            | dev-0-1-img0.png
            | ...

       └─── test1
            | test1-0-0-img0.png
            | test1-0-0-img1.png
            | test1-0-1-img0.png 
            | ...

       └─── cirr
            └─── captions
                | cap.rc2.test1.json
                | cap.rc2.train.json
                | cap.rc2.val.json

            └─── image_splits
                | split.rc2.test1.json
                | split.rc2.train.json
                | split.rc2.val.json
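To catch path mistakes before a long training run, the layout can be checked programmatically. This is a sketch, assuming `project_base_path` is the directory that contains `CaLa`, `fashionIQ_dataset` and `cirr_dataset`; the helper name `missing_paths` is hypothetical, and only the files listed explicitly in the tree are checked (entries elided with "..." are not guessed).

```python
import sys
from pathlib import Path
from typing import List

# Directories and files copied from the tree above, relative to project_base_path.
EXPECTED_DIRS = [
    "CaLa/src",
    "fashionIQ_dataset/captions",
    "fashionIQ_dataset/images",
    "fashionIQ_dataset/image_splits",
    "cirr_dataset/train",
    "cirr_dataset/dev",
    "cirr_dataset/test1",
    "cirr_dataset/cirr/captions",
    "cirr_dataset/cirr/image_splits",
]
EXPECTED_FILES = [
    "fashionIQ_dataset/captions/cap.dress.test.json",
    "fashionIQ_dataset/captions/cap.dress.train.json",
    "fashionIQ_dataset/captions/cap.dress.val.json",
    "fashionIQ_dataset/image_splits/split.dress.test.json",
    "fashionIQ_dataset/image_splits/split.dress.train.json",
    "fashionIQ_dataset/image_splits/split.dress.val.json",
    "cirr_dataset/cirr/captions/cap.rc2.test1.json",
    "cirr_dataset/cirr/captions/cap.rc2.train.json",
    "cirr_dataset/cirr/captions/cap.rc2.val.json",
    "cirr_dataset/cirr/image_splits/split.rc2.test1.json",
    "cirr_dataset/cirr/image_splits/split.rc2.train.json",
    "cirr_dataset/cirr/image_splits/split.rc2.val.json",
]


def missing_paths(project_base_path) -> List[str]:
    """Return every expected directory or file that does not exist under the base path."""
    base = Path(project_base_path)
    missing = [d for d in EXPECTED_DIRS if not (base / d).is_dir()]
    missing += [f for f in EXPECTED_FILES if not (base / f).is_file()]
    return missing


if __name__ == "__main__":
    base = sys.argv[1] if len(sys.argv) > 1 else "."
    problems = missing_paths(base)
    print("layout OK" if not problems else "missing:\n  " + "\n  ".join(problems))
```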

Adjustments for dependencies

To fine-tune the BLIP-2 encoders, comment out the following line in the LAVIS package installed in your conda environment.

# In lavis/models/blip2_models/blip2_qformer.py, line 367:
# @torch.no_grad()  # comment out this line

Removing this decorator allows gradients to be computed through the BLIP-2 model, so that its parameters can be updated during fine-tuning.
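Because line numbers can shift between LAVIS releases, it can help to confirm which `@torch.no_grad()` decorators are still active in the installed file. The sketch below locates `blip2_qformer.py` through `importlib.util.find_spec` without importing LAVIS, then lists uncommented decorators. The helper names are hypothetical, and the sketch assumes the decorator appears alone on its line.

```python
import importlib.util
from pathlib import Path
from typing import List, Optional


def qformer_source_path() -> Optional[Path]:
    """Locate lavis/models/blip2_models/blip2_qformer.py without importing lavis."""
    spec = importlib.util.find_spec("lavis")
    if spec is None or not spec.submodule_search_locations:
        return None  # LAVIS is not installed in the active environment
    package_dir = Path(list(spec.submodule_search_locations)[0])
    return package_dir / "models" / "blip2_models" / "blip2_qformer.py"


def active_no_grad_lines(source: str) -> List[int]:
    """1-based numbers of lines consisting only of an uncommented @torch.no_grad()."""
    return [
        number
        for number, line in enumerate(source.splitlines(), start=1)
        if line.strip() == "@torch.no_grad()"
    ]


if __name__ == "__main__":
    path = qformer_source_path()
    if path is None or not path.is_file():
        print("blip2_qformer.py not found; is salesforce-lavis installed?")
    else:
        active = active_no_grad_lines(path.read_text())
        # Only line 367 needs to be commented out; other decorators may remain.
        state = "still active" if 367 in active else "not active"
        print(f"{path}: decorator at line 367 {state}; active decorators: {active}")
```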

To fine-tune the CLIP encoders, modify the CLIP package as follows so that the RN50x4 features can interact with the Q-Former.

# In CLIP/clip/model.py (ModifiedResNet.forward), replace lines 152-154:
        x = self.attnpool(x)

        return x

# with the following, which also returns the feature map before attention pooling:
        y = x
        x = self.attnpool(x)

        return x, y
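The same CLIP edit can also be applied to source text programmatically, which avoids depending on exact line numbers. This is a hypothetical helper (`patch_attnpool_return` is not part of CaLa or CLIP). It assumes the target code consists of `x = self.attnpool(x)`, one blank line, and `return x` at the same indentation, and it refuses to patch when it finds zero or several matches (for example, when the file has already been patched).

```python
import re

# Matches `x = self.attnpool(x)`, one blank line, then `return x`, at a shared indentation.
_ATTNPOOL_RETURN = re.compile(
    r"^(?P<indent>[ \t]*)x = self\.attnpool\(x\)\n"
    r"(?P<blank>[ \t]*\n)"
    r"(?P=indent)return x[ \t]*$",
    re.MULTILINE,
)


def patch_attnpool_return(source: str) -> str:
    """Return `source` rewritten so the pre-pooling feature map is returned as well."""

    def _replacement(match: "re.Match[str]") -> str:
        indent = match.group("indent")
        return (
            f"{indent}y = x\n"
            f"{indent}x = self.attnpool(x)\n"
            f"{match.group('blank')}"
            f"{indent}return x, y"
        )

    patched, count = _ATTNPOOL_RETURN.subn(_replacement, source)
    if count != 1:
        # Zero matches: already patched or a different CLIP version; several: ambiguous.
        raise ValueError(f"expected exactly one match, found {count}")
    return patched
```

One possible workflow is to read `clip/model.py` from the installed `clip` package, keep a backup copy, pass its text through `patch_attnpool_return`, and write the result back.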

Training

# cala finetune 
CUDA_VISIBLE_DEVICES='GPU_IDs' python src/blip_fine_tune.py --dataset {'CIRR' or 'FashionIQ'} \
    --num-epochs 30 --batch-size 64 \
    --max-epoch 15 --min-lr 0 \
    --learning-rate 5e-6 \
    --transform targetpad --target-ratio 1.25 \
    --save-training --save-best --validation-frequency 1 \
    --encoder {'both' or 'text' or 'multi'} \
    --encoder-arch {clip or blip2} \
    --cir-frame {sum or artemis} \
    --tac-weight 0.45 \
    --hca-weight 0.1 \
    --embeds-dim {640 for clip and 768 for blip2} \
    --model-name {RN50x4 for clip and None for blip2} \
    --api-key {Comet-api-key} \
    --workspace {Comet-workspace} \
    --experiment-name {Comet-experiment-name}
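The `--tac-weight` and `--hca-weight` flags suggest that two auxiliary terms are weighted against a main retrieval objective. As an illustration only, the sketch below assumes that the retrieval term is a batch-based contrastive (cross-entropy) loss over query/target similarities, as used in CLIP4Cir (on which this implementation is based), and that the two weights simply scale auxiliary losses added to it. `batch_contrastive_loss`, `weighted_objective` and the logit scale of 100 are assumptions, not CaLa's actual code; plain Python lists stand in for tensors.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def batch_contrastive_loss(queries: Sequence[Vector], targets: Sequence[Vector],
                           scale: float = 100.0) -> float:
    """Mean cross-entropy where query i should score highest against target i.

    Inputs are assumed to be L2-normalized, so dot products are cosine similarities.
    """
    total = 0.0
    for i, query in enumerate(queries):
        logits = [scale * sum(q * t for q, t in zip(query, target)) for target in targets]
        peak = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
        total += log_sum_exp - logits[i]
    return total / len(queries)


def weighted_objective(retrieval_loss: float, tac_loss: float, hca_loss: float,
                       tac_weight: float = 0.45, hca_weight: float = 0.1) -> float:
    """HYPOTHETICAL: one way --tac-weight and --hca-weight could scale two auxiliary losses."""
    return retrieval_loss + tac_weight * tac_loss + hca_weight * hca_loss
```

Matching query/target pairs on the diagonal drive the loss toward zero, while swapping the targets makes it large, which is the behaviour a batch-based retrieval loss is meant to have.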

CIRR Testing

To generate a prediction file for the CIRR test set:

CUDA_VISIBLE_DEVICES='GPU_IDs' python src/cirr_test_submission_blip2.py --submission-name {cirr_submission} \
    --combining-function {sum or artemis} \
    --blip2-textual-path {saved_blip2_textual.pt} \
    --blip2-multimodal-path {saved_blip2_multimodal.pt} \
    --blip2-visual-path {saved_blip2_visual.pt}

To evaluate a trained model on the CIRR or FashionIQ validation set:

python src/validate.py \
    --dataset {'CIRR' or 'FashionIQ'} \
    --combining-function {'combiner' or 'sum'} \
    --combiner-path {path to trained Combiner} \
    --projection-dim 2560 \
    --hidden-dim 5120 \
    --clip-model-name RN50x4 \
    --clip-model-path {path-to-fine-tuned-CLIP} \
    --target-ratio 1.25 \
    --transform targetpad
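Both commands take a combining function that fuses reference-image and text features into a query, which is then ranked against the gallery. Below is a minimal sketch of the `sum` case together with the Recall@K metric commonly reported on CIRR and FashionIQ. It assumes `sum` means an L2-normalized element-wise sum of the two feature vectors, as in CLIP4Cir; all function names are hypothetical, and plain lists stand in for tensors.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def l2_normalize(vector: Vector) -> List[float]:
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm > 0 else list(vector)


def combine_sum(image_features: Vector, text_features: Vector) -> List[float]:
    """Assumed `sum` fusion: L2-normalize the element-wise sum of both features."""
    return l2_normalize([i + t for i, t in zip(image_features, text_features)])


def rank_gallery(query: Vector, gallery: Sequence[Vector]) -> List[int]:
    """Gallery indices sorted by cosine similarity to the query, best first."""
    q = l2_normalize(query)
    scores = [sum(a * b for a, b in zip(q, l2_normalize(g))) for g in gallery]
    return sorted(range(len(gallery)), key=lambda idx: scores[idx], reverse=True)


def recall_at_k(rankings: Sequence[Sequence[int]], targets: Sequence[int], k: int) -> float:
    """Percentage of queries whose target index appears in the top-k of its ranking."""
    hits = sum(1 for ranking, target in zip(rankings, targets) if target in ranking[:k])
    return 100.0 * hits / len(targets)
```

For example, fusing an image feature `[1, 0]` with a text feature `[0, 1]` gives a query pointing along the diagonal, so a gallery item `[0.7, 0.7]` is ranked first.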

Reference

If you use CaLa in your research, please cite it by the following BibTeX entry:

@article{jiang2024cala,
  title={CaLa: Complementary Association Learning for Augmenting Composed Image Retrieval},
  author={Jiang, Xintong and Wang, Yaxiong and Li, Mengjian and Wu, Yujiao and Hu, Bingwen and Qian, Xueming},
  journal={arXiv preprint arXiv:2405.19149},
  year={2024}
}

Acknowledgement

Our implementation is based on CLIP4Cir and LAVIS.