azshue / TPT

Test-time Prompt Tuning (TPT) for zero-shot generalization in vision-language models (NeurIPS 2022)
https://azshue.github.io/TPT/
MIT License

Test-Time Prompt Tuning (TPT) for zero-shot generalization in Vision-Language Models

This repository provides the official PyTorch implementation of our NeurIPS 2022 paper:

Test-Time Prompt Tuning for Zero-shot Generalization in Vision-Language Models
Authors: Manli Shu, Weili Nie, De-An Huang, Zhiding Yu, Tom Goldstein, Anima Anandkumar, Chaowei Xiao

For more details, please check out our project page and paper.

Overview

This repository contains the implementation of TPT for image classification with a pre-trained CLIP. We consider 3 different initializations for test-time prompt tuning:

Prerequisites

Hardware

This implementation is for the single-GPU configuration.

To evaluate on ImageNet, ImageNet-V2, and ImageNet-Sketch (each of which has 1000 classes), you will need a GPU with strictly more than 16GB of memory; a 16GB GPU is not sufficient. This codebase is tested on a GPU with 24GB of memory. To evaluate on the other datasets (which have at most a few hundred classes), a GPU with 16GB of memory will work fine.

Environment

The code is tested on PyTorch 1.7.1.

Datasets

We suggest downloading all datasets to a root directory (${data_root}), and renaming the directory of each dataset as suggested in ${ID_to_DIRNAME} in ./data/datautils.py. This would allow you to evaluate multiple datasets within the same run.
If this is not feasible, you could evaluate different datasets separately, and change the ${data_root} accordingly in the bash script.

For out-of-distribution generalization, we consider 5 datasets: ImageNet, ImageNet-A, ImageNet-V2, ImageNet-R, and ImageNet-Sketch.

For cross-dataset generalization, we consider 10 datasets:

For cross-dataset generalization, we adopt the same train/val/test splits as CoOp. Please refer to this page, look for the download links of split_zhou_${dataset_name}.json, and put the JSON files under ./data/data_splits/.
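
Before launching a run, it can help to confirm that the split files are where the code expects them. The following is a minimal sketch, not part of this repository: the helper name `missing_split_files` and the dataset names in the usage line are hypothetical, and only the `split_zhou_${dataset_name}.json` naming and the `./data/data_splits/` location come from the instructions above.

```python
from pathlib import Path


def missing_split_files(dataset_names, split_dir="./data/data_splits"):
    """Return the dataset names whose split_zhou_<name>.json file is absent.

    Hypothetical helper: only the file naming pattern and the default
    directory are taken from the README; everything else is illustrative.
    """
    split_dir = Path(split_dir)
    return [
        name
        for name in dataset_names
        if not (split_dir / f"split_zhou_{name}.json").is_file()
    ]


if __name__ == "__main__":
    # Placeholder dataset names; substitute the names used by the split files
    # you actually downloaded.
    print(missing_split_files(["dataset_a", "dataset_b"]))
```

An empty list means every requested split file was found.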

Run TPT

We provide three bash scripts under ./scripts. You can modify the paths and other args in the scripts.

An example to run TPT with CoOp initialization on out-of-distribution datasets:

bash ./scripts/test_coop.sh I/A/V/R/K

The command line arg ${testsets} can list multiple test datasets separated by "/" (all of which must be stored under the same root directory ${data_root}).
Note that, for simplicity, we use a set_id to denote each dataset. A complete list of set_id values can be found in ${ID_to_DIRNAME} in ./data/datautils.py.
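
To illustrate how a "/"-separated ${testsets} string relates to dataset directories, here is a minimal sketch. It is not the repository's implementation: `resolve_testsets` is a hypothetical helper, and `EXAMPLE_ID_TO_DIRNAME` is a placeholder mapping whose directory names are invented for the example. The real mapping is ${ID_to_DIRNAME} in ./data/datautils.py.

```python
import os

# Placeholder stand-in for ID_to_DIRNAME. The directory names here are
# invented; consult ./data/datautils.py for the real ones.
EXAMPLE_ID_TO_DIRNAME = {
    "I": "dir_for_I",
    "A": "dir_for_A",
    "V": "dir_for_V",
    "R": "dir_for_R",
    "K": "dir_for_K",
}


def resolve_testsets(testsets, data_root, id_to_dirname):
    """Split a '/'-separated set_id string and map each id to a path under data_root."""
    paths = {}
    for set_id in testsets.split("/"):
        if set_id not in id_to_dirname:
            raise KeyError(f"unknown set_id: {set_id!r}")
        paths[set_id] = os.path.join(data_root, id_to_dirname[set_id])
    return paths


resolved = resolve_testsets("I/A/V/R/K", "/path/to/data_root", EXAMPLE_ID_TO_DIRNAME)
for set_id, path in resolved.items():
    print(set_id, "->", path)
```

In this sketch every character between separators is part of the set_id, so a stray character (for example a trailing period) makes the lookup fail rather than silently selecting a dataset.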

Main Results

Out-of-Distribution Generalization

| Method | ImageNet(IN) | IN-A | IN-V2 | IN-R | IN-Sketch | Average | OOD Average |
|------------------|:--------:|:----------:|:-----------:|:----------:|:---------------:|:-------:|:-----------:|
| [CLIP-RN50](https://arxiv.org/abs/2103.00020) | 58.16 | 21.83 | 51.41 | 56.15 | 33.37 | 44.18 | 40.69 |
| [Ensembled prompt](https://arxiv.org/abs/2103.00020) | 59.81 | 23.24 | 52.91 | **60.72** | 35.48 | 46.43 | 43.09 |
| [CoOp](https://arxiv.org/abs/2109.01134) | 63.33 | 23.06 | 55.40 | 56.60 | 34.67 | 46.61 | 42.43 |
| [CoCoOp](https://arxiv.org/abs/2203.05557) | 62.81 | 23.32 | 55.72 | 57.74 | 34.48 | 46.81 | 42.82 |
| TPT (ours) | 60.74 | 26.67 | 54.7 | 59.11 | 35.09 | 47.26 | 43.89 |
| TPT + CoOp | **64.73** | **30.32** | **57.83** | 58.99 | **35.86** | **49.55** | **45.75** |
| TPT + CoCoOp | 62.93 | 27.40 | 56.60 | 59.88 | 35.43 | 48.45 | 44.83 |
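
The two summary columns appear to be plain means: "Average" over all five datasets and "OOD Average" over the four datasets other than ImageNet. This reading is inferred from the numbers rather than stated by the repository, and the `summarize` helper below is a hypothetical sketch that reproduces the CLIP-RN50 and TPT + CoOp rows under that assumption.

```python
def summarize(acc):
    """Return (mean over all datasets, mean over all but ImageNet), rounded to 2 decimals.

    acc maps a dataset name to top-1 accuracy in percent. Treating the two
    summary columns as simple means is an assumption inferred from the table.
    """
    average = sum(acc.values()) / len(acc)
    ood = [value for name, value in acc.items() if name != "ImageNet"]
    ood_average = sum(ood) / len(ood)
    return round(average, 2), round(ood_average, 2)


clip_rn50 = {"ImageNet": 58.16, "IN-A": 21.83, "IN-V2": 51.41, "IN-R": 56.15, "IN-Sketch": 33.37}
tpt_coop = {"ImageNet": 64.73, "IN-A": 30.32, "IN-V2": 57.83, "IN-R": 58.99, "IN-Sketch": 35.86}

print(summarize(clip_rn50))  # compare with the table row: 44.18, 40.69
print(summarize(tpt_coop))   # compare with the table row: 49.55, 45.75
```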


Cross-Dataset Generalization

In each matrix $A$, $A_{i, j}$ is the normalized relative improvement on the $j$-th dataset of using the prompt tuned on the $i$-th dataset. The value $A_{i, j}$ stands for how well a method trained on a source dataset $i$ performs on a target dataset $j$, in comparison with a zero-shot CLIP baseline (using a hand-crafted prompt). Thus, the higher, the better. The last row is the performance of TPT, which is not tuned on any source dataset. The last column summarizes the average improvement over 10 datasets, measuring the overall generalization ability across the 10 datasets.
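
One plausible reading of "normalized relative improvement" is (accuracy - zero-shot accuracy) / zero-shot accuracy. The sketch below builds such a matrix, including a last column with the row average. The exact normalization used for the figures is an assumption here, `improvement_matrix` is a hypothetical helper, and the accuracies in the usage lines are made-up illustrative numbers, not results from the paper.

```python
def improvement_matrix(acc, zero_shot):
    """Build a matrix of improvements relative to a zero-shot baseline.

    acc[i][j]    : accuracy on target dataset j of the prompt tuned on source i
    zero_shot[j] : accuracy of the zero-shot baseline on target dataset j
    Each output row gets one extra final entry: the mean of that row.
    The formula (acc - base) / base is an assumed normalization.
    """
    matrix = []
    for row in acc:
        rel = [(a - base) / base for a, base in zip(row, zero_shot)]
        rel.append(sum(rel) / len(rel))  # average improvement over targets
        matrix.append(rel)
    return matrix


# Illustrative numbers only.
zero_shot = [60.0, 50.0]
acc = [
    [66.0, 45.0],  # helps on target 0, hurts on target 1
    [63.0, 55.0],  # helps on both targets
]
for row in improvement_matrix(acc, zero_shot):
    print([round(value, 3) for value in row])
```

Positive entries mean the method beats the zero-shot baseline on that target dataset, and negative entries mean it falls below it.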

Cross-dataset improvement normalized by the zero-shot baseline performance.

Citation

If you find our code useful or our work relevant, please consider citing:

@inproceedings{shu2022tpt,
  author    = {Shu, Manli and Nie, Weili and Huang, De-An and Yu, Zhiding and Goldstein, Tom and Anandkumar, Anima and Xiao, Chaowei},
  title     = {Test-Time Prompt Tuning for Zero-shot Generalization in Vision-Language Models},
  booktitle = {NeurIPS},
  year      = {2022},
}

Acknowledgements

We thank the authors of CoOp/CoCoOp for their open-source implementation and instructions on data preparation.