RWKV-CLIP: A Robust Vision-Language Representation Learner
Tiancheng Gu, Kaicheng Yang, Xiang An, Ziyong Feng, Dongnan Liu, Weidong Cai, Jiankang Deng
We introduce a diverse description generation framework that leverages Large Language Models (LLMs) to synthesize and refine content from web-based texts, synthetic captions, and detection tags. The detection tags introduce additional semantic information from the images, which in turn constrains the LLMs and mitigates hallucinations.
We propose RWKV-CLIP, the first RWKV-driven vision-language representation learning model that combines the effective parallel training of transformers with the efficient inference of RNNs.
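The property that lets RWKV train like a transformer and run like an RNN is that its time-mixing output can be computed either for a whole sequence at once or token by token with a constant-size state. The minimal NumPy sketch below illustrates this, assuming the simplified causal WKV operator from the original RWKV-4 formulation (positive per-channel decay `w`, current-token bonus `u`). The function names are hypothetical, and the exact blocks used in RWKV-CLIP are not reproduced here and may differ.

```python
import numpy as np


def wkv_parallel(k, v, w, u):
    """Causal WKV for all T positions at once (no sequential dependency).

    k, v: (T, C) keys and values; w: (C,) positive per-channel decay;
    u: (C,) per-channel bonus applied to the current token.
    """
    T = k.shape[0]
    t = np.arange(T)[:, None]  # output position
    i = np.arange(T)[None, :]  # input position
    # Past token i contributes exp(-(t - 1 - i) * w + k_i); future tokens are masked out.
    logits = -((t - 1 - i)[..., None]) * w + k[None, :, :]  # (T, T, C)
    weights = np.exp(np.where((i < t)[..., None], logits, -np.inf))
    cur = np.exp(u + k)  # (T, C): weight of the current token itself
    num = np.einsum("tic,ic->tc", weights, v) + cur * v
    den = weights.sum(axis=1) + cur
    return num / den


def wkv_recurrent(k, v, w, u):
    """Same operator computed token by token with an O(C) state, as in RNN inference."""
    T, C = k.shape
    a = np.zeros(C)  # decayed running sum of exp(k_i) * v_i over past tokens
    b = np.zeros(C)  # decayed running sum of exp(k_i) over past tokens
    decay = np.exp(-w)
    out = np.empty((T, C))
    for t in range(T):
        cur = np.exp(u + k[t])
        out[t] = (a + cur * v[t]) / (b + cur)
        a = decay * a + np.exp(k[t]) * v[t]
        b = decay * b + np.exp(k[t])
    return out


rng = np.random.default_rng(0)
k, v = rng.normal(size=(6, 4)), rng.normal(size=(6, 4))
w, u = np.exp(rng.normal(size=4)), rng.normal(size=4)
print(np.allclose(wkv_parallel(k, v, w, u), wkv_recurrent(k, v, w, u)))  # True
```

Both functions compute the same values. The parallel form suits GPU training over full sequences, while the recurrent form keeps only two C-dimensional accumulators per step. This naive version omits the numerical stabilization (tracking a running maximum exponent) that practical kernels use.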
```shell
git clone https://github.com/deepglint/RWKV-CLIP.git
cd RWKV-CLIP
conda create -n rwkv_clip python=3.10 -y
conda activate rwkv_clip
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118
pip install -U openmim
mim install mmcv-full==1.7.2
pip install -r requirements.txt
```
```python
import os
import clip
import json
import torch
import warnings
from PIL import Image
from torch.nn import functional as F
from open_clip.transform import image_transform
from model_config.utils_notebook import load_model_configs
warnings.filterwarnings('ignore')

# Load the model config (use model_config/RWKV_CLIP_B16.json for the B/16 model)
args = load_model_configs('model_config/RWKV_CLIP_B32.json')
from model.utils import create_RWKV_Model
transform = image_transform(args.input_size, False)  # evaluation-mode preprocessing
device = "cuda" if torch.cuda.is_available() else "cpu"

# Preprocess the image and tokenize the candidate captions
image = transform(Image.open("figure/Diverse_description_generation_00.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

# Load the pretrained model (replace the path with your downloaded checkpoint)
RWKV_CLIP_model = create_RWKV_Model(args, model_weight_path="Model_pretrained_weight.pt").to(device)
RWKV_CLIP_model.eval()

# Compute image-text similarity scores
with torch.no_grad():
    image_features, text_features, logit_scale = RWKV_CLIP_model(image, text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs: ", text_probs.tolist())  # Label probs: [[1., 0., 0.]]
```
We use the YFCC15M-DeCLIP version of YFCC15M, downloaded from its repository; in total we successfully downloaded 15,061,515 image-text pairs.
The YFCC15M dataset we used to train RWKV-CLIP has been released on Hugging Face.
Model | Dataset | Checkpoint | Config |
---|---|---|---|
RWKV-CLIP-B/32 | YFCC15M | 🤗ckpt | cfg |
RWKV-CLIP-B/32 | LAION10M | 🤗ckpt | cfg |
RWKV-CLIP-B/16 | LAION10M | 🤗ckpt | cfg |
RWKV-CLIP-B/32 | LAION30M | 🤗ckpt | cfg |
Train RWKV-CLIP-B/32 on YFCC15M
```shell
bash shell/train_RWKV_CLIP_B32_YFCC15M.sh
```
Evaluate zero-shot cross-modal retrieval
```shell
bash shell/test_zero_shot_retrieval.sh
```
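The retrieval results are reported as Recall@K (R@1, R@5, R@10). As a rough sketch of how that metric can be computed from an image-text similarity matrix, the hypothetical `recall_at_k` below assumes every caption describes exactly one image and every image has at least one caption (MSCOCO and Flickr30k pair each image with several captions). It is an illustration, not the repository's evaluation script.

```python
import numpy as np


def recall_at_k(sim, text_to_image, ks=(1, 5, 10)):
    """Return {K: (image->text R@K, text->image R@K)} in percent.

    sim: (N_images, N_texts) similarity matrix, e.g. cosine similarities of
    normalized embeddings; text_to_image[j] is the index of the image that
    caption j describes. Every image is assumed to have at least one caption.
    """
    text_to_image = np.asarray(text_to_image)
    n_images = sim.shape[0]
    # Image -> text: hit if ANY caption of the image is ranked within the top K.
    ranked_texts = np.argsort(-sim, axis=1)
    is_match = text_to_image[ranked_texts] == np.arange(n_images)[:, None]
    i2t_rank = is_match.argmax(axis=1)  # position of the best-ranked correct caption
    # Text -> image: each caption has exactly one correct image.
    ranked_images = np.argsort(-sim.T, axis=1)
    t2i_rank = (ranked_images == text_to_image[:, None]).argmax(axis=1)
    return {
        k: (float(100 * (i2t_rank < k).mean()), float(100 * (t2i_rank < k).mean()))
        for k in ks
    }


# Toy example: 2 images, 4 captions (captions 0-1 describe image 0, captions 2-3 image 1).
sim = np.array([[0.9, 0.1, 0.2, 0.3],
                [0.1, 0.2, 0.8, 0.7]])
print(recall_at_k(sim, [0, 0, 1, 1], ks=(1, 5)))  # {1: (100.0, 75.0), 5: (100.0, 100.0)}
```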
Evaluate zero-shot classification
```shell
bash shell/test_zero_shot_classificaiton.sh
```
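Zero-shot classification uses the same scoring idea as the inference example above: each class name is encoded as a text embedding, and each image is assigned to the class with the highest cosine similarity. A minimal NumPy sketch on precomputed embeddings follows; the function `zero_shot_accuracy` and the prompt wording in its docstring are illustrative assumptions, not the repository's evaluation code.

```python
import numpy as np


def zero_shot_accuracy(image_emb, class_emb, labels):
    """Top-1 zero-shot accuracy (percent) from precomputed embeddings.

    image_emb: (N, D) image embeddings; class_emb: (num_classes, D), one text
    embedding per class name (e.g. encoded from a prompt such as
    "a photo of a {class name}", an illustrative assumption);
    labels: (N,) ground-truth class indices.
    """
    img = image_emb / np.linalg.norm(image_emb, axis=-1, keepdims=True)
    txt = class_emb / np.linalg.norm(class_emb, axis=-1, keepdims=True)
    logits = img @ txt.T  # cosine similarity of every image to every class
    # Scaling by a temperature and applying softmax would not change the argmax.
    preds = logits.argmax(axis=-1)
    return float(100 * (preds == np.asarray(labels)).mean())


image_emb = np.array([[1.0, 0.1], [0.2, 1.0], [1.0, 0.0]])
class_emb = np.array([[1.0, 0.0], [0.0, 1.0]])
print(zero_shot_accuracy(image_emb, class_emb, [0, 1, 1]))  # about 66.7 (2 of 3 correct)
```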
Method | Model | MSCOCO R@1 | MSCOCO R@5 | MSCOCO R@10 | Flickr30k R@1 | Flickr30k R@5 | Flickr30k R@10 |
---|---|---|---|---|---|---|---|
CLIP | B/32 | 20.8/13.0 | 43.9/31.7 | 55.7/42.7 | 34.9/23.4 | 63.9/47.2 | 75.9/58.9 |
SLIP | B/32 | 27.7/18.2 | 52.6/39.2 | 63.9/51.0 | 47.8/32.3 | 76.5/58.7 | 85.9/68.8 |
DeCLIP | B/32 | 28.3/18.4 | 53.2/39.6 | 64.5/51.4 | 51.4/34.3 | 80.2/60.3 | 88.9/70.7 |
UniCLIP | B/32 | 32.0/20.2 | 57.7/43.2 | 69.2/54.4 | 52.3/34.8 | 81.6/62.0 | 89.0/72.0 |
HiCLIP | B/32 | 34.2/20.6 | 60.3/43.8 | 70.9/55.3 | - | - | - |
ALIP | B/32 | 46.8/29.3 | 72.4/54.4 | 81.8/65.4 | 70.5/48.9 | 91.9/75.1 | 95.7/82.9 |
Ours | B/32 | 50.3/34.0 | 76.2/60.9 | 85.2/71.7 | 76.0/57.6 | 94.7/82.3 | 97.6/88.7 |
Method | Model | CIFAR10 | CIFAR100 | Food101 | Pets | Flowers | SUN397 | Cars | DTD | Caltech101 | Aircraft | ImageNet | Average |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
CLIP | B/32 | 63.7 | 33.2 | 34.6 | 20.1 | 50.1 | 35.7 | 2.6 | 15.5 | 59.9 | 1.2 | 32.8 | 31.8 |
SLIP | B/32 | 50.7 | 25.5 | 33.3 | 23.5 | 49.0 | 34.7 | 2.8 | 14.4 | 59.9 | 1.7 | 34.3 | 30.0 |
FILIP | B/32 | 65.5 | 33.5 | 43.1 | 24.1 | 52.7 | 50.7 | 3.3 | 24.3 | 68.8 | 3.2 | 39.5 | 37.2 |
DeCLIP | B/32 | 66.7 | 38.7 | 52.5 | 33.8 | 60.8 | 50.3 | 3.8 | 27.7 | 74.7 | 2.1 | 43.2 | 41.3 |
HiCLIP | B/32 | 74.1 | 46.0 | 51.2 | 37.8 | 60.9 | 50.6 | 4.5 | 23.1 | 67.4 | 3.6 | 40.5 | 41.8 |
ALIP | B/32 | 83.8 | 51.9 | 45.4 | 30.7 | 54.8 | 47.8 | 3.4 | 23.2 | 74.1 | 2.7 | 40.3 | 41.7 |
Ours | B/32 | 79.8 | 55.1 | 50.6 | 37.6 | 57.1 | 54.0 | 4.1 | 24.6 | 77.1 | 4.0 | 44.3 | 44.4 |
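The Average column is consistent with the unweighted mean of the 11 per-dataset scores. The short check below recomputes it from the rounded values in the table; every row agrees to within 0.1. The largest gap is ALIP (41.65 recomputed vs. 41.7 reported), presumably because that average was computed before the per-dataset numbers were rounded. `ROWS` and `average_gaps` are names introduced here for illustration.

```python
# Per-dataset accuracies copied from the zero-shot classification table
# (11 datasets, in column order) together with the reported Average.
ROWS = {
    "CLIP":   ([63.7, 33.2, 34.6, 20.1, 50.1, 35.7, 2.6, 15.5, 59.9, 1.2, 32.8], 31.8),
    "SLIP":   ([50.7, 25.5, 33.3, 23.5, 49.0, 34.7, 2.8, 14.4, 59.9, 1.7, 34.3], 30.0),
    "FILIP":  ([65.5, 33.5, 43.1, 24.1, 52.7, 50.7, 3.3, 24.3, 68.8, 3.2, 39.5], 37.2),
    "DeCLIP": ([66.7, 38.7, 52.5, 33.8, 60.8, 50.3, 3.8, 27.7, 74.7, 2.1, 43.2], 41.3),
    "HiCLIP": ([74.1, 46.0, 51.2, 37.8, 60.9, 50.6, 4.5, 23.1, 67.4, 3.6, 40.5], 41.8),
    "ALIP":   ([83.8, 51.9, 45.4, 30.7, 54.8, 47.8, 3.4, 23.2, 74.1, 2.7, 40.3], 41.7),
    "Ours":   ([79.8, 55.1, 50.6, 37.6, 57.1, 54.0, 4.1, 24.6, 77.1, 4.0, 44.3], 44.4),
}


def average_gaps(rows):
    """Return |unweighted mean of the per-dataset scores - reported Average| per method."""
    return {name: abs(sum(s) / len(s) - avg) for name, (s, avg) in rows.items()}


for name, gap in average_gaps(ROWS).items():
    print(f"{name:7s} gap = {gap:.3f}")
```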
This project is based on RWKV, VisionRWKV, RAM++, LLaMA-Factory, vllm, OFA, and open_clip; we thank their authors for their work.
This project is released under the MIT license. Please see the LICENSE file for more information.
This project would not have been possible without the invaluable contributions of the following individuals, who have been instrumental in data scraping and collection:
Thank you to all the contributors for their hard work and dedication!
Contributor | Email |
---|---|
Bin Qin | skyqin@gmail.com |
Lan Wu | bah-wl@hotmail.com |
Haiqiang Jiang | haiqiangjiang@deepglint.com |
Yuling Wu | yulingwu@deepglint.com |
If you find this repository useful, please use the following BibTeX entry for citation.
@misc{gu2024rwkvclip,
title={RWKV-CLIP: A Robust Vision-Language Representation Learner},
author={Tiancheng Gu and Kaicheng Yang and Xiang An and Ziyong Feng and Dongnan Liu and Weidong Cai and Jiankang Deng},
year={2024},
eprint={2406.06973},
archivePrefix={arXiv},
primaryClass={cs.CV}
}