This repository provides the official PyTorch implementation of our NeurIPS 2022 paper:
Test-Time Prompt Tuning for Zero-shot Generalization in Vision-Language Models
Authors: Manli Shu, Weili Nie, De-An Huang, Zhiding Yu, Tom Goldstein, Anima Anandkumar, Chaowei Xiao
For more details, please check out our project page and paper.
This repository contains the implementation of TPT for image classification with a pre-trained CLIP. We consider 3 different initializations for test-time prompt tuning:
This implementation is for the single-GPU configuration.
To evaluate on ImageNet, ImageNet-V2, and ImageNet-Sketch (each of which has 1000 classes), you will need a GPU with strictly more than 16GB of memory. This codebase was tested on a GPU with 24GB of memory. To evaluate the other datasets (each with fewer than a few hundred classes), a GPU with 16GB of memory is sufficient.
The code is tested on PyTorch 1.7.1.
We suggest downloading all datasets to a root directory (${data_root}), and renaming the directory of each dataset as suggested in ${ID_to_DIRNAME} in ./data/datautils.py. This would allow you to evaluate multiple datasets within the same run.
If this is not feasible, you could evaluate different datasets separately, and change the ${data_root} accordingly in the bash script.
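As a quick sanity check of the layout described above, the following minimal sketch reports which dataset directories are missing under a given root. It is not part of this codebase: the helper `missing_dataset_dirs`, the mapping `EXAMPLE_ID_TO_DIRNAME`, its directory names, and the root path are all hypothetical placeholders, and the actual names should be taken from ${ID_to_DIRNAME} in ./data/datautils.py.

```python
from pathlib import Path

# Hypothetical excerpt for illustration only. The real mapping is
# ID_to_DIRNAME in ./data/datautils.py; these directory names are placeholders.
EXAMPLE_ID_TO_DIRNAME = {
    "I": "ImageNet",
    "A": "imagenet-a",
}


def missing_dataset_dirs(data_root, id_to_dirname):
    """Return {set_id: expected_path} for each dataset directory not found under data_root."""
    root = Path(data_root)
    missing = {}
    for set_id, dirname in id_to_dirname.items():
        expected = root / dirname
        if not expected.is_dir():
            missing[set_id] = expected
    return missing


# "/path/to/data_root" is a placeholder; pass the same value you use for ${data_root}.
for set_id, path in missing_dataset_dirs("/path/to/data_root", EXAMPLE_ID_TO_DIRNAME).items():
    print(f"[{set_id}] expected directory not found: {path}")
```

Running a check like this before a multi-dataset evaluation can surface a misnamed directory early rather than partway through the run.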
For out-of-distribution generalization, we consider 5 datasets:
For cross-dataset generalization, we consider 10 datasets:
For cross-dataset generalization, we adopt the same train/val/test splits as CoOp. Please refer to this page, look for the download links of split_zhou_${dataset_name}.json, and put the json files under ./data/data_splits/.
We provide three bash scripts under ./scripts. You can modify the paths and other args in the scripts.
An example to run TPT with CoOp initialization on out-of-distribution datasets:
bash ./scripts/test_coop.sh I/A/V/R/K
The command line arg ${testsets} can be multiple test datasets separated by "/" (all of which are stored under the same root dir ${data_root}).
Note that for simplicity, we use set_id to denote different datasets. A complete list of set_id can be found in ${ID_to_DIRNAME} in ./data/datautils.py.
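To illustrate how a "/"-separated ${testsets} value maps to individual set_id entries, here is a minimal sketch. The function `parse_testsets` and its optional `known_ids` check are hypothetical and written for this example only; the repository's own argument handling may differ.

```python
def parse_testsets(testsets, known_ids=None):
    """Split a "/"-separated string such as "I/A/V/R/K" into a list of set_id values.

    If known_ids is given (for example, the keys of the set_id mapping),
    any set_id outside it raises a ValueError.
    """
    set_ids = [s.strip() for s in testsets.split("/") if s.strip()]
    if known_ids is not None:
        unknown = [s for s in set_ids if s not in known_ids]
        if unknown:
            raise ValueError(f"Unknown set_id(s): {unknown}")
    return set_ids


print(parse_testsets("I/A/V/R/K"))  # ['I', 'A', 'V', 'R', 'K']
```

Validating against the known set_id values makes a typo in the argument fail immediately with a readable message.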
In each matrix $A$, $A_{i, j}$ is the normalized relative improvement on the $j$-th dataset of using the prompt tuned on the $i$-th dataset. The value $A_{i, j}$ stands for how well a method trained on a source dataset $i$ performs on a target dataset $j$, in comparison with a zero-shot CLIP baseline (using a hand-crafted prompt). Thus, the higher, the better. The last row is the performance of TPT, which is not tuned on any source dataset. The last column summarizes the average improvement over 10 datasets, measuring the overall generalization ability across the 10 datasets.
Cross-dataset improvement normalized by the zero-shot baseline performance.
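One plausible reading of "normalized by the zero-shot baseline performance" is (accuracy - baseline) / baseline, computed per target dataset. The sketch below assumes that definition, which is not spelled out here, and uses made-up accuracies rather than results from the paper. The function name `normalized_improvement` is hypothetical.

```python
def normalized_improvement(acc, zero_shot_acc):
    """Build a matrix of relative improvements over the zero-shot baseline.

    acc[i][j]        accuracy on target dataset j of the prompt tuned on source i
    zero_shot_acc[j] accuracy of the zero-shot baseline on target dataset j
    Each returned row holds (acc - baseline) / baseline per target, followed by
    one extra entry: the average of that row over all targets.
    Assumption: this normalization is a guess at the definition, not taken from the code.
    """
    matrix = []
    for row in acc:
        rel = [(a - b) / b for a, b in zip(row, zero_shot_acc)]
        rel.append(sum(rel) / len(rel))
        matrix.append(rel)
    return matrix


# Made-up accuracies (percent) for two sources and two targets, for illustration only.
acc = [[66.0, 80.0], [60.0, 88.0]]
zero_shot = [60.0, 80.0]
A = normalized_improvement(acc, zero_shot)
for row in A:
    print([round(v, 3) for v in row])
```

Under this definition, a positive entry means the tuned prompt beats the hand-crafted zero-shot prompt on that target, and a negative entry means it does worse.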
If you find our code useful or our work relevant, please consider citing:
@inproceedings{shu2022tpt,
author = {Shu, Manli and Nie, Weili and Huang, De-An and Yu, Zhiding and Goldstein, Tom and Anandkumar, Anima and Xiao, Chaowei},
title = {Test-Time Prompt Tuning for Zero-shot Generalization in Vision-Language Models},
booktitle = {NeurIPS},
year = {2022},
}
We thank the authors of CoOp/CoCoOp for their open-source implementation and instructions on data preparation.