This repository contains the official implementation of the paper "Interpretable Image Classification with Adaptive Prototype-based Vision Transformers" (NeurIPS 2024).
ProtoViT is a novel approach that combines Vision Transformers with prototype-based learning to create interpretable image classification models. Our implementation provides both high accuracy and explainability through learned prototypes.
The packages listed below should be sufficient to reproduce our results. For reference, we also provide a `requirements.txt` exported from our conda environment.
Recommended GPU configurations:
```bash
git clone https://github.com/Henrymachiyu/ProtoViT.git
cd ProtoViT
pip install -r requirements.txt
```
```bash
# Download the dataset CUB_200_2011.tgz from http://www.vision.caltech.edu/visipedia/CUB-200-2011.html
tar -xzf CUB_200_2011.tgz
```
Process the dataset:
To crop the images and create the train/test split, carefully follow the instructions that ship with the dataset. Sample Jupyter Notebook code for cropping and splitting is provided in the preprocessing sample code; a minimal sketch of the same step also follows the commands below.
```bash
# Create the directory structure
mkdir -p ./datasets/cub200_cropped/{train_cropped,test_cropped}

# Crop and split the images with your own scripts
python your_own_scripts/crop_images.py    # uses bounding_boxes.txt
python your_own_scripts/split_dataset.py  # uses train_test_split.txt

# Put the cropped training images in ./datasets/cub200_cropped/train_cropped/
# Put the cropped test images in ./datasets/cub200_cropped/test_cropped/
```
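The cropping and splitting scripts are left to the user; as referenced above, here is a minimal sketch of what they might look like, based on the CUB-200-2011 metadata files (`images.txt`, `bounding_boxes.txt`, `train_test_split.txt`):

```python
# Minimal sketch (not part of the repo): crop CUB images to their bounding
# boxes and split them into train/test folders using the dataset's metadata.
import os
from PIL import Image

cub_root = "CUB_200_2011"
out_root = "./datasets/cub200_cropped"

with open(os.path.join(cub_root, "images.txt")) as f:
    paths = {l.split()[0]: l.split()[1] for l in f}
with open(os.path.join(cub_root, "bounding_boxes.txt")) as f:
    boxes = {l.split()[0]: [float(v) for v in l.split()[1:]] for l in f}
with open(os.path.join(cub_root, "train_test_split.txt")) as f:
    is_train = {l.split()[0]: l.split()[1] == "1" for l in f}

for img_id, rel_path in paths.items():
    x, y, w, h = boxes[img_id]
    img = Image.open(os.path.join(cub_root, "images", rel_path)).convert("RGB")
    crop = img.crop((x, y, x + w, y + h))
    split = "train_cropped" if is_train[img_id] else "test_cropped"
    dst = os.path.join(out_root, split, rel_path)
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    crop.save(dst)
```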
Augment the training data:

```bash
python img_aug.py
# this creates an augmented training set in:
# ./datasets/cub200_cropped/train_cropped_augmented/
```
The official website for the dataset is http://www.vision.caltech.edu/visipedia/CUB-200-2011.html; an alternative download source is also available.
Before training, set the dataset paths in `settings.py`:

```python
# Dataset paths
data_path = "./datasets/cub200_cropped/"
train_dir = data_path + "train_cropped_augmented/"
test_dir = data_path + "test_cropped/"
train_push_dir = data_path + "train_cropped/"
```
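As a quick sanity check (not part of the repo), you can verify that the configured directories exist before launching training:

```python
# Sanity check, assuming settings.py exposes the module-level variables
# shown above (as in the snippet in this README).
import os
from settings import train_dir, test_dir, train_push_dir

for d in (train_dir, test_dir, train_push_dir):
    assert os.path.isdir(d), f"missing directory: {d}"
```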
Then train the model:

```bash
python main.py
```
The corresponding parameter settings for global and local analysis are stored in `analysis_settings.py`:
```python
load_model_dir = 'saved model path'    # e.g. './saved_models/vgg19/003/'
load_model_name = 'model_name'         # e.g. '14finetuned0.9230.pth'
save_analysis_path = 'saved_dir_rt'    # directory where the analysis results are saved
img_name = 'prototype_vis_file'        # e.g. 'img/'
test_data = "test_dir"
check_test_acc = False
check_list = ['list of test images']   # e.g. "163_Mercedes-Benz SL-Class Coupe 2009/03123.jpg"; can be a list of images
```
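For reference, the analysis scripts can be expected to load the checkpoint from these settings roughly as follows; this is a minimal sketch assuming ProtoPNet-style checkpoints that store the whole model object:

```python
# Minimal sketch; the example values mirror the comments in analysis_settings.py.
import os
import torch

load_model_dir = './saved_models/vgg19/003/'
load_model_name = '14finetuned0.9230.pth'

model = torch.load(os.path.join(load_model_dir, load_model_name), map_location='cpu')
model.eval()
```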
To produce the reasoning plots:
The local analysis script finds the nearest prototypes for specific test images and retrieves the model's reasoning process for its predictions:

```bash
# this script produces the model's reasoning results and local analysis
python local_analysis.py -gpuid 0
```
To produce the global analysis plots:
The following script finds the nearest patches to each prototype, to verify that the prototypes are semantically consistent across samples in the train and test data:

```bash
python global_analysis.py -gpuid 0
```
To run this experiment, you will also need `cleverhans`:

```bash
pip install cleverhans
```
All the parameters used for reproducing our results on location misalignment are stored in `adv_settings.py`:

```python
load_model_path = "."
test_dir = "./cub200_cropped/test_cropped"
model_output_dir = "."  # directory for saving all the results
```
To run the adversarial attack and retrieve the results:

```bash
cd ./spatial_alignment_test
python run_adv_test.py  # by default, the experiment runs over the entire test set
```
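For intuition, here is a minimal PGD sketch using cleverhans; the attack hyperparameters (`eps`, step size, iterations) and the checkpoint path are illustrative assumptions, not the settings used in the paper:

```python
# Minimal PGD sketch (assumed usage); the actual experiment is driven by
# run_adv_test.py with the parameters in adv_settings.py.
import numpy as np
import torch
from cleverhans.torch.attacks.projected_gradient_descent import (
    projected_gradient_descent,
)

model = torch.load("path/to/checkpoint.pth", map_location="cpu")  # hypothetical path
model.eval()

def model_fn(x):
    out = model(x)
    # ProtoPNet-style models often return (logits, distances); keep the logits.
    return out[0] if isinstance(out, tuple) else out

x = torch.rand(1, 3, 224, 224)  # stand-in for a preprocessed test image
x_adv = projected_gradient_descent(
    model_fn, x,
    eps=8 / 255,       # assumed L_inf perturbation budget
    eps_iter=2 / 255,  # assumed step size
    nb_iter=10,        # assumed number of PGD iterations
    norm=np.inf,
)
```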
We provide checkpoints taken after prototype projection and last-layer fine-tuning on the CUB-200-2011 dataset.

| Model Version | Backbone | Resolution | Accuracy | Checkpoint |
|---|---|---|---|---|
| ProtoViT-T | DeiT-Tiny | 224×224 | 83.36% | Download |
| ProtoViT-S | DeiT-Small | 224×224 | 85.30% | Download |
| ProtoViT-CaiT | CaiT_xxs24 | 224×224 | 86.02% | Download |
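A downloaded checkpoint can be evaluated roughly as follows; this sketch assumes ImageNet normalization at 224×224 and a ProtoPNet-style checkpoint storing the full model, which may differ from the repo's exact preprocessing:

```python
# Hedged evaluation sketch; filename and preprocessing are assumptions.
import torch
from torchvision import datasets, transforms

model = torch.load("ProtoViT-T.pth", map_location="cpu")  # hypothetical filename
model.eval()

tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
test_set = datasets.ImageFolder("./datasets/cub200_cropped/test_cropped/", tf)
loader = torch.utils.data.DataLoader(test_set, batch_size=64)

correct = total = 0
with torch.no_grad():
    for x, y in loader:
        out = model(x)
        logits = out[0] if isinstance(out, tuple) else out
        correct += (logits.argmax(1) == y).sum().item()
        total += y.numel()
print(f"accuracy: {correct / total:.4f}")
```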
This implementation is based on timm and on the ProtoPNet repository and its variants. We thank the authors for their valuable work.
If you have any questions regarding the paper or the implementation, please don't hesitate to email us: chiyu.ma.gr@dartmouth.edu
If you find this work useful in your research, please consider citing:
```bibtex
@article{ma2024interpretable,
  title={Interpretable Image Classification with Adaptive Prototype-based Vision Transformers},
  author={Ma, Chiyu and Donnelly, Jon and Liu, Wenjun and Vosoughi, Soroush and Rudin, Cynthia and Chen, Chaofan},
  journal={arXiv preprint arXiv:2410.20722},
  year={2024}
}
```
This project is licensed under the MIT License.