deeplearning-wisc / vit-spurious-robustness


Are Vision Transformers Robust to Spurious Correlations?

This codebase provides a PyTorch implementation of the paper: Are Vision Transformers Robust to Spurious Correlations?

Abstract

Deep neural networks may be susceptible to learning spurious correlations that hold on average but not in atypical test samples. With the recent emergence of vision transformer (ViT) models, it remains underexplored how spurious correlations are manifested in such architectures. In this paper, we systematically investigate the robustness of vision transformers to spurious correlations on three challenging benchmark datasets and compare their performance with popular CNNs. Our study reveals that when pre-trained on a sufficiently large dataset, ViT models are more robust to spurious correlations than CNNs. Key to their success is the ability to generalize better from the examples where spurious correlations do not hold. Further, we perform extensive ablations and experiments to understand the role of the self-attention mechanism in providing robustness under spuriously correlated environments. We hope that our work will inspire future research on further understanding the robustness of ViT models.

Required Packages

Our experiments are conducted on Ubuntu Linux 20.04 with Python 3.9 and PyTorch 1.6. In addition, the following packages are required:

Pre-trained Checkpoints

In our experiments, for ViT models we use the pre-trained checkpoints provided with the timm library. Pre-trained checkpoints for BiT models can be downloaded following the instructions in the official repo. Please download the BiT checkpoints and place them in the bit_pretrained_models folder.

Datasets

In our study, we use the following challenging benchmarks:

Quick Start

To run the experiments, you first need to download the pre-trained model checkpoints and datasets and place them in the folders specified in the Pre-trained Checkpoints and Datasets sections. Below we provide the commands and general descriptions of the related files.
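Before launching any script, it can help to check that the folders named in this README exist. The sketch below is a hypothetical helper (`missing_dirs` is not part of this repository); the relative paths are the ones mentioned in this README and assume you run it from the repository root.

```python
from pathlib import Path

# Relative folder names taken from this README; adjust them to your setup.
EXPECTED_DIRS = (
    "bit_pretrained_models",           # BiT checkpoints
    "datasets/waterbird_bg",           # images for the consistency measure
    "datasets/ood_datasets/placesbg",  # output of generate_placebg.py
)


def missing_dirs(root=".", expected=EXPECTED_DIRS):
    """Return the expected sub-directories of `root` that do not exist (hypothetical helper)."""
    root = Path(root)
    return [d for d in expected if not (root / d).is_dir()]


if __name__ == "__main__":
    for d in missing_dirs():
        print(f"missing: {d}")
```

Only the folders needed for the experiment you plan to run actually have to exist; for example, waterbird_bg is used only by the consistency evaluation.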

WaterBirds

You can download a tarball of this dataset from here. The Waterbirds dataset can also be accessed through the WILDS package, which will automatically download the dataset.

To train a ViT model (ViT-B_16) on the Waterbirds dataset, run the following command:

python train.py --name waterbirds_exp --model_arch ViT --model_type ViT-B_16 --dataset waterbirds --warmup_steps 500 --num_steps 700 --learning_rate 0.03 --batch_split 1 --img_size 384

To train a ViT model (ViT-S_16) on the Waterbirds dataset, run the following command:

python train.py --name waterbirds_exp --model_arch ViT --model_type ViT-S_16 --dataset waterbirds --warmup_steps 100 --num_steps 700 --learning_rate 0.03 --batch_split 1 --img_size 384
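The two ViT commands above use different --warmup_steps values (500 and 100) with the same --num_steps 700. Assuming train.py keeps the linear-warmup, cosine-decay learning-rate schedule of ViT-pytorch (from which parts of this codebase are adapted), the learning rate at each step would follow the sketch below. `warmup_cosine_lr` is a hypothetical helper written for illustration, not a function from this repository.

```python
import math


def warmup_cosine_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate at `step` under an assumed warmup + cosine-decay schedule."""
    if step < warmup_steps:
        # Linear warmup from 0 up to base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Cosine decay from base_lr down to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


if __name__ == "__main__":
    for warmup in (500, 100):
        lrs = [warmup_cosine_lr(s, 0.03, warmup, 700) for s in (0, warmup, 400, 700)]
        print(warmup, [round(lr, 4) for lr in lrs])
```

With a 500-step warmup, only 200 of the 700 steps are spent at or near the peak rate of 0.03 before it decays to zero; a 100-step warmup leaves 600 steps for the decay phase.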

Similarly, a sample command to train a BiT model on the Waterbirds dataset:

python train.py --name waterbirds_exp --model_arch BiT --model_type BiT-M-R50x1 --dataset waterbirds --learning_rate 0.003 --batch_split 1 --img_size 384
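All training commands pass --batch_split 1. Assuming --batch_split follows the semantics of big_transfer (each batch is split into batch_split equal micro-batches, each micro-batch loss is divided by batch_split before backpropagation, and the optimizer steps once per full batch), a larger value lowers peak memory without changing the gradient. The pure-Python sketch below checks that equivalence on a toy one-parameter least-squares loss; `grad_mse` and `accumulated_grad` are hypothetical names used only for illustration.

```python
def grad_mse(w, batch):
    """d/dw of mean((w * x - y) ** 2) over a batch of (x, y) pairs."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, batch, batch_split):
    """Accumulate gradients over `batch_split` equal micro-batches."""
    if len(batch) % batch_split:
        raise ValueError("batch size must be divisible by batch_split")
    size = len(batch) // batch_split
    total = 0.0
    for i in range(batch_split):
        micro = batch[i * size:(i + 1) * size]
        # Mirrors (loss / batch_split).backward() on each micro-batch.
        total += grad_mse(w, micro) / batch_split
    return total


if __name__ == "__main__":
    data = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 7.0)]
    print(grad_mse(0.5, data), accumulated_grad(0.5, data, 2))
```

The equivalence relies on equal-sized micro-batches; with a ragged last chunk, dividing by batch_split would weight samples unevenly.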

Notes for some of the arguments:

| Model | model_arch | model_type | #params |
|---|---|---|---|
| ViT-B/16 | ViT | ViT-B_16 | 86.1M |
| ViT-S/16 | ViT | ViT-S_16 | 21.8M |
| ViT-Ti/16 | ViT | ViT-Ti_16 | 5.6M |
| BiT-M-R50x3 | BiT | BiT-M-R50x3 | 211M |
| BiT-M-R101x1 | BiT | BiT-M-R101x1 | 42.5M |
| BiT-M-R50x1 | BiT | BiT-M-R50x1 | 23.5M |
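The ViT parameter counts in the table can be reproduced by hand. The sketch below assumes a standard timm-style ViT (patch-embedding convolution, learnable class token and positional embeddings, pre-norm blocks with MLP ratio 4, biases on every linear layer), a 384x384 input as in the commands above, and a 2-class head for Waterbirds. Under these assumptions it gives 86.1M, 21.8M and 5.6M for embedding widths 768, 384 and 192. `vit_param_count` is a hypothetical helper; the table may not have been computed this way, and the BiT rows are not covered.

```python
def vit_param_count(dim, depth=12, patch=16, img_size=384, in_chans=3,
                    num_classes=2, mlp_ratio=4):
    """Parameter count of an assumed timm-style ViT with the given width."""
    hidden = dim * mlp_ratio
    num_tokens = (img_size // patch) ** 2 + 1          # patches + class token
    patch_embed = in_chans * patch * patch * dim + dim  # conv weight + bias
    cls_and_pos = dim + num_tokens * dim                # class token + pos. embeddings
    per_block = (
        2 * dim                       # LayerNorm before attention
        + dim * 3 * dim + 3 * dim     # qkv projection
        + dim * dim + dim             # attention output projection
        + 2 * dim                     # LayerNorm before MLP
        + dim * hidden + hidden       # MLP fc1
        + hidden * dim + dim          # MLP fc2
    )
    final_norm = 2 * dim
    head = dim * num_classes + num_classes
    return patch_embed + cls_and_pos + depth * per_block + final_norm + head


if __name__ == "__main__":
    for name, dim in (("ViT-B/16", 768), ("ViT-S/16", 384), ("ViT-Ti/16", 192)):
        print(name, f"{vit_param_count(dim) / 1e6:.1f}M")
```

Almost all parameters sit in the transformer blocks, each contributing 12D^2 + 13D for width D, which is why the count scales roughly quadratically with the embedding width.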

To compute accuracy metrics (including worst-group accuracy) for a ViT model (ViT-S_16) on the train and test data, run the following command:

python evaluate.py --name waterbirds_exp --model_arch ViT --model_type ViT-S_16 --dataset waterbirds --batch_size 64 --img_size 384 --checkpoint_dir model_checkpoint
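Worst-group accuracy is the minimum accuracy over the four Waterbirds groups formed by the bird label (landbird/waterbird) and the background (land/water). The sketch below assumes the group indexing used in GDRO, group = label * 2 + place; the function names are hypothetical, and evaluate.py may organize this computation differently.

```python
from collections import defaultdict


def group_accuracies(preds, labels, places, n_places=2):
    """Accuracy per group, with group id = label * n_places + place (assumed convention)."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for pred, label, place in zip(preds, labels, places):
        group = label * n_places + place
        total[group] += 1
        correct[group] += int(pred == label)
    return {g: correct[g] / total[g] for g in sorted(total)}


def worst_group_accuracy(preds, labels, places):
    """Minimum accuracy over all groups present in the data."""
    return min(group_accuracies(preds, labels, places).values())


if __name__ == "__main__":
    preds = [0, 0, 1, 1, 0, 1, 0, 1]
    labels = [0, 0, 1, 1, 0, 0, 1, 1]
    places = [0, 0, 1, 1, 1, 1, 0, 0]
    print(group_accuracies(preds, labels, places))
    print(worst_group_accuracy(preds, labels, places))
```

In this toy example the average accuracy is 0.75, yet the minority groups (landbirds on water, waterbirds on land) reach only 0.5, which is exactly the gap worst-group accuracy is meant to expose.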

Notes for some of the arguments:

To compute the consistency measure, first download the evaluation dataset from here and place the images in the [root_dir]/datasets/waterbird_bg directory. For a ViT model (ViT-S_16), run the following command:

python waterbirds_consistency.py --name waterbirds_exp --model_arch ViT --model_type ViT-S_16 --checkpoint_dir model_checkpoint --batch_size 32
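The README does not define the consistency measure precisely. One plausible formulation, stated here as an assumption rather than the definition implemented in waterbirds_consistency.py, is the fraction of foreground objects whose predicted label stays the same across all the background variants they are placed on. `consistency` is a hypothetical helper.

```python
def consistency(preds_by_object):
    """Fraction of objects predicted identically on every background (assumed definition).

    preds_by_object maps an object id to the list of predicted labels,
    one per background variant of that object.
    """
    if not preds_by_object:
        raise ValueError("no predictions given")
    consistent = sum(1 for preds in preds_by_object.values() if len(set(preds)) == 1)
    return consistent / len(preds_by_object)


if __name__ == "__main__":
    preds = {"bird_a": [1, 1, 1], "bird_b": [0, 1, 0], "bird_c": [0, 0, 0]}
    print(consistency(preds))
```

Note that this formulation rewards stable predictions even when they are wrong; a stricter variant would also require the shared prediction to match the true label.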

Spurious OOD evaluation

To generate the OOD dataset, run datasets/generate_placebg.py, which subsamples background images of specific types to serve as OOD data. You can simply run python generate_placebg.py; the generated dataset will be stored in datasets/ood_datasets/placesbg/. Note: before generating the OOD dataset, users need to download the CUB and Places datasets and update the paths to them accordingly.

To obtain the spurious OOD evaluation for a ViT model (ViT-S_16), run the following command:

python ood_eval.py --name waterbirds_exp --model_arch ViT --model_type ViT-S_16 --id_dataset waterbirds --batch_size 64 --img_size 384 --checkpoint_dir model_checkpoint
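Spurious OOD evaluation is commonly summarized with the false positive rate at 95% true positive rate (FPR95) and AUROC, treating in-distribution test images as positives and the generated placesbg images as negatives. Assuming higher scores mean "more in-distribution" (for example, the maximum softmax probability), both metrics can be sketched in pure Python as below. Whether ood_eval.py reports exactly these metrics or uses this scoring function is an assumption, and the function names are hypothetical.

```python
import math


def fpr_at_tpr(id_scores, ood_scores, tpr=0.95):
    """FPR on OOD data at the threshold keeping at least `tpr` of ID data (0 < tpr <= 1)."""
    ranked = sorted(id_scores, reverse=True)
    k = math.ceil(tpr * len(ranked))
    threshold = ranked[k - 1]
    return sum(s >= threshold for s in ood_scores) / len(ood_scores)


def auroc(id_scores, ood_scores):
    """Probability that a random ID score exceeds a random OOD score (ties count 0.5)."""
    wins = 0.0
    for i in id_scores:
        for o in ood_scores:
            if i > o:
                wins += 1.0
            elif i == o:
                wins += 0.5
    return wins / (len(id_scores) * len(ood_scores))


if __name__ == "__main__":
    id_scores = [0.9, 0.8, 0.7, 0.6]
    ood_scores = [0.65, 0.5, 0.4, 0.3]
    print(fpr_at_tpr(id_scores, ood_scores), auroc(id_scores, ood_scores))
```

The pairwise AUROC loop is quadratic and meant only to make the definition explicit; for full test sets a sort-based implementation such as scikit-learn's roc_auc_score is preferable.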

References

Some parts of the codebase are adapted from GDRO, Spurious_OOD, big_transfer and ViT-pytorch.

BibTeX citation:

@misc{ghosal2022vision,
      title={Are Vision Transformers Robust to Spurious Correlations?}, 
      author={Soumya Suvra Ghosal and Yifei Ming and Yixuan Li},
      year={2022},
      eprint={2203.09125},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}