This codebase provides a PyTorch implementation for the paper: Are Vision Transformers Robust to Spurious Correlations?
Deep neural networks may be susceptible to learning spurious correlations that hold on average but not in atypical test samples. With the recent emergence of vision transformer (ViT) models, it remains underexplored how spurious correlations manifest in such architectures. In this paper, we systematically investigate the robustness of vision transformers to spurious correlations on three challenging benchmark datasets and compare their performance with popular CNNs. Our study reveals that, when pre-trained on a sufficiently large dataset, ViT models are more robust to spurious correlations than CNNs. Key to their success is the ability to generalize better from the examples where spurious correlations do not hold. Further, we perform extensive ablations and experiments to understand the role of the self-attention mechanism in providing robustness in spuriously correlated environments. We hope that our work will inspire future research on further understanding the robustness of ViT models.
Our experiments are conducted on Ubuntu Linux 20.04 with Python 3.9 and PyTorch 1.6. In addition, the following packages must be installed:
In our experiments, for ViT models we use the pre-trained checkpoints provided with the timm library. Pre-trained checkpoints for BiT models can be downloaded following the instructions in the official repo. Please download the BiT checkpoints and place them in the bit_pretrained_models folder.
In our study, we use the following challenging benchmarks:
CelebA: the data split is given in datasets/celebA/celebA_split.csv. After downloading the dataset, please place the images in the folder datasets/celebA/img_align_celeba/. datasets/celebA_dataset.py provides the dataloader for the CelebA and OOD datasets.
To run the experiments, you first need to download the pre-trained model checkpoints and datasets and place them in the folders specified in Pre-trained Checkpoints and Datasets. We provide the following commands and general descriptions for the related files.
datasets/waterbirds_dataset.py: provides the dataloader for the Waterbirds dataset.
The code expects the following files/folders in the [root_dir]/datasets directory:
waterbird_complete95_forest2water2/
You can download a tarball of this dataset from here. The Waterbirds dataset can also be accessed through the WILDS package, which will automatically download the dataset.
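Worst-group results on Waterbirds are reported over four (label, background) groups. The sketch below counts samples per group in one split. It assumes the metadata.csv layout of the GDRO release of waterbird_complete95_forest2water2/ (columns y, place and split, with split codes 0/1/2 for train/val/test); these column names and codes are assumptions, and datasets/waterbirds_dataset.py may organize things differently.

```python
import csv
import io
from collections import Counter

# Split codes assumed from the GDRO release of Waterbirds: 0 = train, 1 = val, 2 = test.
SPLITS = {"train": 0, "val": 1, "test": 2}


def group_counts(metadata_file, split="train"):
    """Count samples per (label, background) group in one split of metadata.csv.

    Assumes the columns 'y' (0 = landbird, 1 = waterbird), 'place'
    (0 = land background, 1 = water background) and 'split'.
    """
    counts = Counter()
    for row in csv.DictReader(metadata_file):
        if int(row["split"]) != SPLITS[split]:
            continue
        # Group index 2 * y + place yields four groups, e.g. 3 = waterbird on water.
        counts[2 * int(row["y"]) + int(row["place"])] += 1
    return counts


# Toy metadata with the assumed column layout (not real Waterbirds rows).
sample = io.StringIO(
    "img_id,img_filename,y,split,place,place_filename\n"
    "1,a.jpg,0,0,0,p1.jpg\n"
    "2,b.jpg,1,0,1,p2.jpg\n"
    "3,c.jpg,1,0,0,p3.jpg\n"
    "4,d.jpg,1,2,1,p4.jpg\n"
)
print(sorted(group_counts(sample).items()))  # [(0, 1), (2, 1), (3, 1)]
```

On the real data, the same function could be called on an opened file, e.g. open("datasets/waterbird_complete95_forest2water2/metadata.csv", newline=""), assuming that file exists in the downloaded folder.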
To train a ViT model (ViT-B_16) on the Waterbirds dataset, run the following command:
python train.py --name waterbirds_exp --model_arch ViT --model_type ViT-B_16 --dataset waterbirds --warmup_steps 500 --num_steps 700 --learning_rate 0.03 --batch_split 1 --img_size 384
To train a ViT model (ViT-S_16) on the Waterbirds dataset, run the following command:
python train.py --name waterbirds_exp --model_arch ViT --model_type ViT-S_16 --dataset waterbirds --warmup_steps 100 --num_steps 700 --learning_rate 0.03 --batch_split 1 --img_size 384
Similarly, a sample command to train a BiT model on the Waterbirds dataset:
python train.py --name waterbirds_exp --model_arch BiT --model_type BiT-M-R50x1 --dataset waterbirds --learning_rate 0.003 --batch_split 1 --img_size 384
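The ViT commands combine --warmup_steps, --num_steps and --learning_rate. The sketch below shows how such flags commonly interact: a linear warmup to the base learning rate followed by cosine decay to zero. It is modelled on the WarmupCosineSchedule in ViT-pytorch, from which parts of this codebase are adapted; whether train.py uses exactly this schedule is an assumption, and warmup_cosine_lr is a hypothetical helper name.

```python
import math


def warmup_cosine_lr(step, base_lr, warmup_steps, num_steps):
    """Learning rate at a global step: linear warmup, then cosine decay to 0.

    Sketch modelled on ViT-pytorch's WarmupCosineSchedule (assumed, not
    confirmed, to match train.py).
    """
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr over the warmup phase.
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase that has elapsed, in [0, 1].
    progress = (step - warmup_steps) / max(1, num_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Flags from the ViT-B_16 command: --warmup_steps 500 --num_steps 700 --learning_rate 0.03
for step in (0, 250, 500, 600, 700):
    print(step, round(warmup_cosine_lr(step, 0.03, 500, 700), 5))
```

Under this schedule the learning rate peaks at 0.03 at step 500 and decays to 0 at step 700.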
Notes for some of the arguments:
--name: Name used to identify the checkpoints. Users are welcome to use other names for convenience.
--model_arch: Model architecture to be used for training. Specify ViT for Vision Transformers or BiT for Big Transfer models.
--model_type: Model variant to be used for training. Please check the table below.
--warmup_steps: Number of warmup steps used for training ViT models. This is set to 500 for all ViT models.
--num_steps: Total number of global steps used for training ViT models. This is set to 1000 for ViT-S_16 and ViT-Ti_16, and to 2000 for ViT-B_16.
--batch_split: The default batch size is 512. When GPU memory is insufficient, you can still train by increasing the value of batch_split.
Model | model_arch | model_type | #params
---|---|---|---|
ViT-B/16 | ViT | ViT-B_16 | 86.1 M |
ViT-S/16 | ViT | ViT-S_16 | 21.8 M |
ViT-Ti/16 | ViT | ViT-Ti_16 | 5.6 M |
BiT-M-R50x3 | BiT | BiT-M-R50x3 | 211 M |
BiT-M-R101x1 | BiT | BiT-M-R101x1 | 42.5 M |
BiT-M-R50x1 | BiT | BiT-M-R50x1 | 23.5 M |
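In big_transfer, from which parts of this codebase are adapted, batch_split divides each batch into equally sized micro-batches whose gradients are accumulated before a single weight update. Assuming train.py behaves the same way, the toy sketch below (a hypothetical one-parameter least-squares model, not the actual training loop) shows why this saves memory without changing the update: scaling each micro-batch gradient by 1 / batch_split makes the sum equal to the full-batch gradient.

```python
def micro_batches(batch, batch_split):
    """Split one batch into batch_split equally sized micro-batches."""
    if len(batch) % batch_split != 0:
        raise ValueError("batch size must be divisible by batch_split")
    size = len(batch) // batch_split
    return [batch[i:i + size] for i in range(0, len(batch), size)]


def accumulated_grad(w, batch, batch_split):
    """Gradient of mean((w * x - y) ** 2), accumulated over micro-batches.

    Only one micro-batch is processed at a time; each partial gradient is
    scaled by 1 / batch_split so the total matches the full-batch gradient.
    """
    total = 0.0
    for chunk in micro_batches(batch, batch_split):
        # d/dw of the mean squared error over this chunk
        g = sum(2 * (w * x - y) * x for x, y in chunk) / len(chunk)
        total += g / batch_split
    return total


batch = [(float(x), 3.0 * x) for x in range(512)]  # toy data with true slope 3
full = accumulated_grad(1.0, batch, batch_split=1)
split = accumulated_grad(1.0, batch, batch_split=16)  # 16 micro-batches of 32
print(full, split)  # equal up to floating-point rounding
```

In PyTorch the same idea is usually written by calling (loss / batch_split).backward() on each micro-batch and stepping the optimizer once per full batch.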
To generate accuracy metrics (including worst-group accuracy) for a ViT model (ViT-S_16) on the train and test data, run the following command:
python evaluate.py --name waterbirds_exp --model_arch ViT --model_type ViT-S_16 --dataset waterbirds --batch_size 64 --img_size 384 --checkpoint_dir model_checkpoint
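Worst-group accuracy is the lowest accuracy over all (label, background) groups, so a model that relies on the spurious background scores poorly on the groups where the correlation does not hold. The sketch below computes it from predictions, labels and group indices; group_accuracies is a hypothetical helper, not a function exported by evaluate.py.

```python
from collections import defaultdict


def group_accuracies(preds, labels, groups):
    """Return (per-group accuracy dict, worst-group accuracy)."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    per_group = {g: correct[g] / total[g] for g in total}
    # Worst-group accuracy is the minimum over the groups present.
    return per_group, min(per_group.values())


preds = [1, 1, 0, 0, 1, 0]
labels = [1, 1, 0, 1, 1, 1]
groups = [0, 0, 1, 1, 3, 3]
per_group, worst = group_accuracies(preds, labels, groups)
print(per_group, worst)  # {0: 1.0, 1: 0.5, 3: 0.5} 0.5
```

Here the average accuracy is 4/6, while the worst-group accuracy is only 0.5.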
Notes for some of the arguments:
--checkpoint_dir: Model checkpoint fine-tuned on Waterbirds to be used for inference. If not provided, the script automatically searches for the model checkpoint in the output/[name]/[model_arch]/[model_type] directory.
To generate the consistency measure, users first need to download the evaluation dataset from here and place the images in the [root_dir]/datasets/waterbird_bg directory. For a ViT model (ViT-S_16), run the following command:
python waterbirds_consistency.py --name waterbirds_exp --model_arch ViT --model_type ViT-S_16 --checkpoint_dir model_checkpoint --batch_size 32
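A consistency score of this kind can be computed from the model's predictions on the same foreground bird pasted onto several different backgrounds. The sketch below assumes one plausible definition, the fraction of birds whose predicted label is identical across all of their background variants; the exact definition used by waterbirds_consistency.py may differ, and consistency is a hypothetical helper name.

```python
def consistency(predictions_per_image):
    """Fraction of images whose prediction is unchanged across background variants.

    predictions_per_image: list of lists; each inner list holds the predicted
    labels for one foreground bird placed on different backgrounds
    (assumed input format).
    """
    if not predictions_per_image:
        raise ValueError("no images given")
    consistent = sum(1 for preds in predictions_per_image if len(set(preds)) == 1)
    return consistent / len(predictions_per_image)


# Four birds, each shown on three backgrounds; two keep the same prediction.
print(consistency([[1, 1, 1], [0, 1, 0], [0, 0, 0], [1, 1, 0]]))  # 0.5
```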
To generate the OOD dataset, users need to run datasets/generate_placebg.py, which subsamples background images of specific types as the OOD data. You can simply run python generate_placebg.py to generate the OOD dataset, which will be stored in datasets/ood_datasets/placesbg/. Note: before generating the OOD dataset, users need to download the CUB and Places datasets and update their paths accordingly.
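The subsampling step can be pictured as "keep only some scene categories, then draw a fixed number of images from each with a seeded random generator". The sketch below illustrates that idea only; the category names, per-category count and the subsample_backgrounds helper are hypothetical and are not the values or code used by generate_placebg.py.

```python
import random


def subsample_backgrounds(images, keep_categories, per_category, seed=0):
    """Pick up to per_category image paths from each kept scene category.

    images: iterable of (path, category) pairs (assumed input format).
    Sorting before sampling makes the result reproducible for a given seed.
    """
    rng = random.Random(seed)
    by_category = {}
    for path, category in images:
        if category in keep_categories:
            by_category.setdefault(category, []).append(path)
    selected = []
    for category in sorted(by_category):
        paths = sorted(by_category[category])
        selected.extend(rng.sample(paths, min(per_category, len(paths))))
    return selected


# Hypothetical scene categories, five images each.
images = [(f"{c}/{i}.jpg", c) for c in ("ocean", "forest", "kitchen") for i in range(5)]
print(subsample_backgrounds(images, {"ocean", "forest"}, per_category=2, seed=1))
```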
To obtain the spurious OOD evaluation for a ViT model (ViT-S_16), run the following command:
python ood_eval.py --name waterbirds_exp --model_arch ViT --model_type ViT-S_16 --id_dataset waterbirds --batch_size 64 --img_size 384 --checkpoint_dir model_checkpoint
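Spurious OOD evaluation is commonly summarized with FPR95 (the false positive rate on OOD inputs at the threshold that keeps 95% of in-distribution inputs) and AUROC. The sketch below computes both from score lists, using the convention that a higher score means "more in-distribution"; that ood_eval.py reports exactly these metrics with this convention is an assumption, and for large score arrays a library routine such as scikit-learn's roc_auc_score is the practical choice for AUROC.

```python
def fpr_at_95_tpr(id_scores, ood_scores):
    """Fraction of OOD samples accepted at the threshold that keeps 95% of ID samples."""
    ranked = sorted(id_scores)
    # Index of the threshold so that at least 95% of ID scores are >= it.
    k = (5 * len(ranked)) // 100
    threshold = ranked[k]
    return sum(s >= threshold for s in ood_scores) / len(ood_scores)


def auroc(id_scores, ood_scores):
    """Probability that a random ID sample outscores a random OOD sample (ties count 1/2)."""
    wins = 0.0
    for a in id_scores:
        for b in ood_scores:
            if a > b:
                wins += 1.0
            elif a == b:
                wins += 0.5
    return wins / (len(id_scores) * len(ood_scores))


id_scores = [float(i) for i in range(1, 21)]  # toy ID scores 1..20
ood_scores = [1.5, 2.5, 30.0]                 # toy OOD scores
print(fpr_at_95_tpr(id_scores, ood_scores), auroc(id_scores, ood_scores))
```

Lower FPR95 and higher AUROC indicate that the model separates in-distribution inputs from spurious OOD inputs more reliably.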
Some parts of the codebase are adapted from GDRO, Spurious_OOD, big_transfer and ViT-pytorch.
@misc{ghosal2022vision,
title={Are Vision Transformers Robust to Spurious Correlations?},
author={Soumya Suvra Ghosal and Yifei Ming and Yixuan Li},
year={2022},
eprint={2203.09125},
archivePrefix={arXiv},
primaryClass={cs.CV}
}