[PromptKD: Unsupervised Prompt Distillation for Vision-Language Models]()
Zheng Li, Xiang Li#, Xinyi Fu, Xin Zhang, Weiqiang Wang, Shuo Chen, Jian Yang#.
Nankai University, Ant Group, RIKEN
CVPR 2024
[Paper] [Project Page] [Poster] [Chinese Interpretation]
In this paper, we introduce an unsupervised domain prompt distillation framework, which aims to transfer the knowledge of a larger teacher model to a lightweight target model through prompt-driven imitation using unlabeled domain images.
To the best of our knowledge, we are the first to (1) perform unsupervised domain-specific prompt-driven knowledge distillation for CLIP, and (2) establish a practical mechanism for pre-storing text features as shared class vectors between teacher and student.
(1). A novel two-stage unsupervised prompt distillation framework for Vision-Language Models.
(2). Reuse high-quality teacher text features instead of training the student's own text encoder.
(3). Distillation on large amounts of unlabeled domain images using soft labels provided by the teacher.
(4). PromptKD outperforms all existing prompt learning methods on 11 diverse recognition datasets.
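To make points (2) and (3) concrete, here is a minimal sketch of the distillation idea in plain Python. It is not the repository's implementation. It assumes that class logits are cosine similarities between an image feature and the pre-stored teacher text features, that the imitation loss is a KL divergence between teacher and student class distributions, and that student features already have the same dimension as the text features. All function names and the `temperature` parameter are hypothetical.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def class_logits(image_feature, text_features):
    """Cosine similarity between one image feature and each class text feature."""
    img = l2_normalize(image_feature)
    return [sum(a * b for a, b in zip(img, l2_normalize(t))) for t in text_features]


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with a temperature (assumed hyper-parameter)."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with strictly positive q."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(teacher_image_feature, student_image_feature,
                      shared_text_features, temperature=1.0):
    """Student imitates the teacher's soft labels; no ground-truth label is used."""
    teacher_probs = softmax(
        class_logits(teacher_image_feature, shared_text_features), temperature)
    student_probs = softmax(
        class_logits(student_image_feature, shared_text_features), temperature)
    return kl_divergence(teacher_probs, student_probs)
```

In this sketch both models score the image against the same `shared_text_features`, which mirrors the idea of reusing the teacher's text features as shared class vectors. The loss is zero when the student reproduces the teacher's distribution and grows as the two diverge.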
Results reported below show accuracy for base and novel classes across 11 recognition datasets, averaged over 3 seeds.
Create the environment and install the Dassl.pytorch library. Please follow the instructions detailed in INSTALL.md.
(1) Pre-train your own large teacher CLIP model (see below) or (2) use our publicly released pre-trained teacher ViT-L/14 CLIP models (highly recommended).
Our pre-trained teacher models are publicly available at [Baidu Yun] [TeraBox] [Google Cloud]
(Note that due to cloud space limitations, we only provide a limited number of models in Google Cloud. Sorry.)
After obtaining the teacher model, unzip the files and place the model in the ./teacher_model folder.
The accuracy of each teacher model is shown in Tables 10 and 11 in the supplementary material of the paper.
Download the original ViT-B/16 and ViT-L/14 CLIP model weights from the official OpenAI website. Then place these models in the ./clip folder.
[ViT-B/16 CLIP] [ViT-L/14 CLIP]
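Before training, it can help to confirm that both weight folders are in place. The following is a small sketch, not part of the repository. The helper name `missing_dirs` is hypothetical, and it only checks that ./teacher_model and ./clip exist and are non-empty, not that the files inside are valid checkpoints.

```python
from pathlib import Path

# Folder names taken from the setup steps above.
REQUIRED_DIRS = ("teacher_model", "clip")


def missing_dirs(repo_root, required=REQUIRED_DIRS):
    """Return the required folders that are absent or contain no entries."""
    root = Path(repo_root)
    missing = []
    for name in required:
        folder = root / name
        if not folder.is_dir() or not any(folder.iterdir()):
            missing.append(name)
    return missing
```

Calling `missing_dirs(".")` from the repository root returns an empty list when both folders are populated.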
Prepare the dataset. Please follow the instructions detailed in DATASETS.md.
In our paper, we use PromptSRC by default to pre-train our ViT-L/14 CLIP teacher model. We have already provided the config file in configs/trainers/PromptSRC/vit_l14_c2_ep20_batch8_4+4ctx.yaml.
If you want to train your own teacher model, first change line 11 of scripts/promptsrc/base2new_train.sh from CFG=vit_b16_c2_ep20_batch4_4+4ctx to CFG=vit_l14_c2_ep20_batch8_4+4ctx.
Then follow the instructions listed in docs/PromptSRC.md and run the script.
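The config switch described above is a one-line edit and can be done by hand. For those who prefer to script it, here is a minimal sketch. It assumes the script contains a line of the form `CFG=...` starting at the beginning of a line. The helper name `set_cfg` is hypothetical.

```python
import re


def set_cfg(script_text, new_cfg):
    """Replace the value of the first line that starts with `CFG=`."""
    pattern = re.compile(r"^CFG=.*$", flags=re.MULTILINE)
    # A callable replacement avoids any special handling of characters in new_cfg.
    updated, count = pattern.subn(lambda _match: "CFG=" + new_cfg, script_text, count=1)
    if count == 0:
        raise ValueError("no line starting with 'CFG=' was found")
    return updated
```

To apply it, read scripts/promptsrc/base2new_train.sh into a string, call `set_cfg(text, "vit_l14_c2_ep20_batch8_4+4ctx")`, and write the result back.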
Important Note:
The accuracy of your own teacher model may vary depending on your computing environment. To ensure that your teacher model is adequate for distillation, please refer to Appendix Table 10 to check whether your model achieves appropriate accuracy.
If your teacher model cannot achieve the corresponding accuracy or cannot be trained due to computational constraints, I highly recommend that you use our publicly available pre-trained models for distillation.
The base-to-novel experimental settings are provided in the config file at configs/trainers/PromptKD/vit_b16_c2_ep20_batch8_4+4ctx.yaml. You can modify the hyper-parameters in this config file according to your needs.
Change the dataset path in scripts/promptkd/base2new_train.sh line 4 to your current path.
Run the commands below to train PromptKD on the specified dataset.
For example:
# dataset=imagenet, seed=1
sh scripts/promptkd/base2new_train.sh imagenet 1
# seed=2
sh scripts/promptkd/base2new_train.sh imagenet 2
# seed=3
sh scripts/promptkd/base2new_train.sh imagenet 3
# dataset=caltech101, seed=1
sh scripts/promptkd/base2new_train.sh caltech101 1
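The commands above follow a fixed pattern: script path, dataset name, seed. As a convenience, the sketch below builds the three per-seed command lines for a dataset. The helper name `training_commands` is hypothetical. It only constructs the command strings and does not launch anything.

```python
import shlex


def training_commands(script, dataset, seeds=(1, 2, 3)):
    """Build one `sh <script> <dataset> <seed>` command line per seed."""
    return [shlex.join(["sh", script, dataset, str(seed)]) for seed in seeds]


commands = training_commands("scripts/promptkd/base2new_train.sh", "imagenet")
for command in commands:
    print(command)
```

Each string can then be run in a terminal, or passed to `subprocess.run(shlex.split(command), check=True)` to run the seeds one after another.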
Training outputs are saved under output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed_${SEED}.
The cross-dataset experimental settings are provided in the config file at configs/trainers/PromptKD/vit_b16_c2_ep20_batch8_4+4ctx_cross_datasets.yaml. You can modify the hyper-parameters in this config file according to your needs.
Change the dataset path in scripts/promptkd/xd_train.sh line 4 to your current path.
Run the commands below to train PromptKD on the specified dataset.
For example:
# dataset=caltech101, seed=1
sh scripts/promptkd/xd_train.sh caltech101 1
# seed=2
sh scripts/promptkd/xd_train.sh caltech101 2
# seed=3
sh scripts/promptkd/xd_train.sh caltech101 3
# dataset=oxford_pets, seed=1
sh scripts/promptkd/xd_train.sh oxford_pets 1
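The output locations in this README are written as shell-style templates with `${VAR}` placeholders. The sketch below expands the cross-dataset template for one run using Python's `string.Template`. The values chosen for `TRAINER`, `CFG`, and `SHOTS` are illustrative assumptions based on the names used in this README. The actual values are set inside the training script.

```python
from string import Template

# Cross-dataset output template as given in this README.
XD_OUTPUT = Template("output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}")


def expand_output_dir(template, **values):
    """Fill every ${VAR} placeholder; raises KeyError if one is missing."""
    return template.substitute(**values)


example_dir = expand_output_dir(
    XD_OUTPUT,
    DATASET="caltech101",
    TRAINER="PromptKD",  # assumed trainer name
    CFG="vit_b16_c2_ep20_batch8_4+4ctx_cross_datasets",  # assumed config name
    SHOTS=0,  # assumed; 0 shots denotes the full dataset in this README
    SEED=1,
)
print(example_dir)
```

Using `substitute` rather than `safe_substitute` makes a forgotten variable fail loudly instead of leaving a literal `${VAR}` in the path.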
Training outputs are saved under output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}.
Here we provide the pretrained student models and complete training logs using 64 shots and 0 shots (i.e., the full dataset) on the ImageNet dataset for your reference. Please refer to [Releases Part].
If you have any questions, you can submit an issue on GitHub, leave a message on the Zhihu article (if you can read Chinese), or contact me by email (zhengli97[at]qq.com).
If you find our paper or repo helpful for your research, please consider citing our paper and giving this repo a star⭐. Thank you!
@inproceedings{li2024promptkd,
title={Promptkd: Unsupervised prompt distillation for vision-language models},
author={Li, Zheng and Li, Xiang and Fu, Xinyi and Zhang, Xin and Wang, Weiqiang and Chen, Shuo and Yang, Jian},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={26617--26626},
year={2024}
}
Our code is based on the PromptSRC, MaPLe, Co-CoOp and CoOp repositories. We thank the authors for releasing their code.