AngelosNal / Vision-DiffMask

Official PyTorch implementation of Vision DiffMask, a post-hoc interpretation method for vision models.
MIT License
computer-vision deep-learning interpretability pytorch vision-transformer

VISION DIFFMASK: Faithful Interpretation of Vision Transformers with Differentiable Patch Masking

:page_with_curl: [Paper]  :rocket: [Demo]   :floppy_disk: [Checkpoints]

This repository contains the official PyTorch implementation of the paper "VISION DIFFMASK: Faithful Interpretation of Vision Transformers with Differentiable Patch Masking" by Angelos Nalmpantis, Apostolos Panagiotopoulos, John Gkountouras, Konstantinos Papakostas and Wilker Aziz (CVPRW XAI4CV 2023).

Overview

Vision DiffMask is a post-hoc interpretation method for vision tasks. Given a pre-trained model, it predicts the minimal subset of the input (i.e., of the image patches) that is required to preserve the model's original output distribution. Currently, only the Vision Transformer (ViT) for image classification is supported.
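To make "preserving the output distribution" concrete, here is a minimal, framework-free sketch. It assumes each patch embedding is scaled by a gate in [0, 1], that dropped patches are replaced by a placeholder vector (zeros here), and that faithfulness is measured as the KL divergence between the class distributions predicted for the original and the masked input. The gates are set by hand rather than learned, and `apply_patch_gates`, `softmax`, `kl_divergence` and the toy linear "model" are hypothetical illustrations, not functions from this repository.

```python
import math
from typing import List, Sequence

Vector = List[float]


def apply_patch_gates(patches: Sequence[Vector], gates: Sequence[float],
                      placeholder: Vector) -> List[Vector]:
    """Blend each patch with a placeholder: gate 1 keeps it, gate 0 replaces it."""
    return [[g * x + (1.0 - g) * b for x, b in zip(patch, placeholder)]
            for patch, g in zip(patches, gates)]


def softmax(logits: Sequence[float]) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) between two discrete distributions over the same classes."""
    total = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return max(0.0, total)  # KL >= 0; guard against tiny negative rounding errors


def toy_logits(patches: Sequence[Vector]) -> Vector:
    # Hypothetical stand-in for a ViT: class logits are per-dimension sums over patches.
    return [sum(col) for col in zip(*patches)]


patches = [[2.0, 0.0], [0.0, 0.1], [0.1, 0.0]]  # 3 patches, 2 classes
placeholder = [0.0, 0.0]
p_orig = softmax(toy_logits(patches))

# Keep a single patch and compare which one preserves the prediction better.
kl_keep_first = kl_divergence(
    p_orig, softmax(toy_logits(apply_patch_gates(patches, [1.0, 0.0, 0.0], placeholder))))
kl_keep_second = kl_divergence(
    p_orig, softmax(toy_logits(apply_patch_gates(patches, [0.0, 1.0, 0.0], placeholder))))
print(f"keep patch 0: KL={kl_keep_first:.4f}; keep patch 1: KL={kl_keep_second:.4f}")
```

In this toy setup, keeping only the first patch leaves the prediction essentially unchanged, whereas keeping only the second does not. A minimal faithful mask is one that finds such small subsets.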


Setup

We provide a conda environment file (environment.yml) that installs the required packages. Create the environment with:

conda env create -f environment.yml
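After creating and activating the environment (`conda activate <env-name>`, where the name is whatever `environment.yml` defines), a quick check can confirm that the core packages are importable. The package list is an assumption: `torch` follows from this being a PyTorch implementation, `transformers` from the use of pretrained Hugging Face models, and `pytorch_lightning` is only a guess based on the Lightning-style launch options. Adjust it to match `environment.yml`.

```python
import importlib.util
from typing import Dict, Iterable

# Assumed package list; environment.yml is the authoritative source.
REQUIRED = ("torch", "transformers", "pytorch_lightning")


def check_packages(names: Iterable[str]) -> Dict[str, bool]:
    """Return, for each top-level module name, whether it can be found for import."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


status = check_packages(REQUIRED)
for name, ok in status.items():
    print(f"{name:20s} {'OK' if ok else 'MISSING'}")
```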

Project Structure

The project is organized in the following way:

.                                                       
├── code                                                             
│   ├── attributions/                                                                           
│   ├── datamodules
│   │   ├── base.py 
│   │   ├── image_classification.py
│   │   ├── transformations.py
│   │   ├── utils.py
│   │   └── visual_qa.py
│   ├── eval_base.py
│   ├── main.py
│   ├── models
│   │   ├── classification.py
│   │   ├── gates.py
│   │   ├── interpretation.py
│   │   └── utils.py
│   ├── train_base.py
│   └── utils
│       ├── distributions.py
│       ├── getters_setters.py
│       ├── metrics.py
│       ├── optimizer.py
│       └── plot.py
├── experiments/

Training

To train Vision DiffMask on CIFAR-10, interpreting a Vision Transformer that has been fine-tuned on CIFAR-10, use the following command:

python code/main.py --enable_progress_bar --num_epochs 20 --base_model ViT --dataset CIFAR10 \
                    --from_pretrained tanlq/vit-base-patch16-224-in21k-finetuned-cifar10
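When running several configurations, for example repeating the command above with different seeds, it can help to build the command line programmatically. The sketch below only uses flags from the launch arguments listed in this README; the helper name `build_command`, the seed values, and launching runs sequentially through `subprocess` are illustrative assumptions, not part of the repository.

```python
import subprocess
import sys
from typing import Dict, List, Union

Option = Union[str, int, float, bool]


def build_command(options: Dict[str, Option], script: str = "code/main.py") -> List[str]:
    """Turn {"num_epochs": 20, "enable_progress_bar": True} into an argv list.

    True becomes a bare flag, False omits the flag, and any other value
    becomes "--name value".
    """
    cmd = [sys.executable, script]
    for name, value in options.items():
        if isinstance(value, bool):
            if value:
                cmd.append(f"--{name}")
        else:
            cmd.extend([f"--{name}", str(value)])
    return cmd


base = {
    "num_epochs": 20,
    "base_model": "ViT",
    "dataset": "CIFAR10",
    "from_pretrained": "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10",
}

# Hypothetical seed sweep; set DRY_RUN = False to actually launch the runs.
DRY_RUN = True
for seed in (0, 1, 2):
    cmd = build_command({**base, "seed": seed})
    print(" ".join(cmd))
    if not DRY_RUN:
        subprocess.run(cmd, check=True)
```

The progress bar flag is left out of the sweep on purpose, since the help text advises against it when logging to a file.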

You can refer to the next section for a full list of launch options.

Launch Arguments

Vision DiffMask

When training Vision DiffMask, the following launch options can be used:

```
Arguments:
  --enable_progress_bar                     Whether to enable the progress bar (NOT recommended when logging to file).
  --num_epochs NUM_EPOCHS                   Number of epochs to train.
  --seed SEED                               Random seed for reproducibility.
  --sample_images SAMPLE_IMAGES             Number of images to sample for the mask callback.
  --log_every_n_steps LOG_EVERY_N_STEPS     Number of steps between logging media & checkpoints.
  --base_model {ViT}                        Base model architecture to train.
  --from_pretrained FROM_PRETRAINED         The name of the pretrained HF model to load.
  --dataset {MNIST,CIFAR10,CIFAR10_QA,toy}  The dataset to use.

Vision DiffMask:
  --alpha ALPHA                             Initial value for the Lagrangian.
  --lr LR                                   Learning rate for DiffMask.
  --eps EPS                                 KL divergence tolerance.
  --no_placeholder                          Whether to not use a placeholder.
  --lr_placeholder LR_PLACEHOLDER           Learning rate for mask vectors.
  --lr_alpha LR_ALPHA                       Learning rate for lagrangian optimizer.
  --mul_activation MUL_ACTIVATION           Value to multiply gate activations.
  --add_activation ADD_ACTIVATION           Value to add to gate activations.
  --weighted_layer_distribution             Whether to use a weighted distribution when picking a layer in DiffMask forward.

Data Modules:
  --data_dir DATA_DIR                       The directory where the data is stored.
  --batch_size BATCH_SIZE                   The batch size to use.
  --add_noise                               Use gaussian noise augmentation.
  --add_rotation                            Use rotation augmentation.
  --add_blur                                Use blur augmentation.
  --num_workers NUM_WORKERS                 Number of workers to use for data loading.

Visual QA:
  --class_idx CLASS_IDX                     The class (index) to count.
  --grid_size GRID_SIZE                     The number of images per row in the grid.
```
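Several of the options above (`--alpha`, `--eps`, `--lr_alpha`) describe a constrained objective. In the original DiffMask formulation, the mask is trained to keep as few inputs as possible while the KL divergence from the original prediction stays below a tolerance, and the constraint is enforced through a Lagrange multiplier that is itself optimized. The following minimal sketch shows that idea in plain Python. It assumes a scalar sparsity term, a scalar KL value, and plain gradient ascent on the multiplier clipped at zero; it is not the repository's training loop, and the function names and numeric values are hypothetical.

```python
def lagrangian(sparsity: float, kl: float, alpha: float, eps: float) -> float:
    """Relaxed objective: sparsity + alpha * (KL - eps).

    The mask parameters minimise this value; alpha is pushed up whenever
    the constraint KL <= eps is violated.
    """
    return sparsity + alpha * (kl - eps)


def update_alpha(alpha: float, kl: float, eps: float, lr_alpha: float) -> float:
    """One gradient-ascent step on alpha, using d(lagrangian)/d(alpha) = KL - eps.

    Clipping at zero keeps the multiplier valid for an inequality constraint.
    """
    return max(0.0, alpha + lr_alpha * (kl - eps))


alpha, eps, lr_alpha = 1.0, 0.1, 0.5   # illustrative values, not the repo defaults
for kl in (0.4, 0.3, 0.05, 0.02):      # pretend KL values from successive steps
    alpha = update_alpha(alpha, kl, eps, lr_alpha)
    print(f"KL={kl:.2f} -> alpha={alpha:.3f}")
```

While the KL exceeds the tolerance, alpha grows and the divergence term dominates the loss; once the constraint is satisfied, alpha shrinks and the sparsity term takes over, pushing the mask to drop more patches.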
Training the base model

When training the base model (usually not needed, as we support pretrained models from HuggingFace), the following launch options can be used:

```
Arguments:
  --checkpoint CHECKPOINT                   Checkpoint to resume the training from.
  --enable_progress_bar                     Whether to show progress bar during training. NOT recommended when logging to files.
  --num_epochs NUM_EPOCHS                   Number of epochs to train.
  --seed SEED                               Random seed for reproducibility.
  --base_model {ViT,ConvNeXt}               Base model architecture to train.
  --from_pretrained FROM_PRETRAINED         The name of the pretrained HF model to fine-tune from.
  --dataset {MNIST,CIFAR10,CIFAR10_QA,toy}  The dataset to use.

Classification Model:
  --optimizer {AdamW,RAdam}                 The optimizer to use to train the model.
  --weight_decay WEIGHT_DECAY               The optimizer's weight decay.
  --lr LR                                   The initial learning rate for the model.

Data Modules:
  --data_dir DATA_DIR                       The directory where the data is stored.
  --batch_size BATCH_SIZE                   The batch size to use.
  --add_noise                               Use gaussian noise augmentation.
  --add_rotation                            Use rotation augmentation.
  --add_blur                                Use blur augmentation.
  --num_workers NUM_WORKERS                 Number of workers to use for data loading.

Visual QA:
  --class_idx CLASS_IDX                     The class (index) to count.
  --grid_size GRID_SIZE                     The number of images per row in the grid.
```
Evaluating the base model

When evaluating the base model, the following launch options can be used:

```
Arguments:
  --checkpoint CHECKPOINT                   Checkpoint to resume the training from.
  --enable_progress_bar                     Whether to show progress bar during training. NOT recommended when logging to files.
  --seed SEED                               Random seed for reproducibility.
  --base_model {ViT,ConvNeXt}               Base model architecture to train.
  --from_pretrained FROM_PRETRAINED         The name of the pretrained HF model to fine-tune from.
  --dataset {MNIST,CIFAR10,CIFAR10_QA,toy}  The dataset to use.

Data Modules:
  --data_dir DATA_DIR                       The directory where the data is stored.
  --batch_size BATCH_SIZE                   The batch size to use.
  --add_noise                               Use gaussian noise augmentation.
  --add_rotation                            Use rotation augmentation.
  --add_blur                                Use blur augmentation.
  --num_workers NUM_WORKERS                 Number of workers to use for data loading.

Visual QA:
  --class_idx CLASS_IDX                     The class (index) to count.
  --grid_size GRID_SIZE                     The number of images per row in the grid.
```
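The Visual QA options (`--class_idx`, `--grid_size`) point to a counting task on the `CIFAR10_QA` dataset. Going only by the help text, a plausible reading is that several images are tiled into a grid with `grid_size` images per row and the target is how many of them belong to class `class_idx`. The sketch below illustrates that reading under two further assumptions: the grid is square (`grid_size * grid_size` cells) and labels are sampled uniformly. `make_counting_sample` is a hypothetical helper, not code taken from `datamodules/visual_qa.py`.

```python
import random
from typing import List, Tuple


def make_counting_sample(labels: List[int], class_idx: int,
                         grid_size: int) -> Tuple[List[List[int]], int]:
    """Arrange grid_size * grid_size labels row by row and count class_idx.

    Returns the grid of per-cell labels and the counting target.
    """
    if len(labels) != grid_size * grid_size:
        raise ValueError("expected exactly grid_size * grid_size labels")
    grid = [labels[r * grid_size:(r + 1) * grid_size] for r in range(grid_size)]
    target = sum(1 for label in labels if label == class_idx)
    return grid, target


rng = random.Random(0)           # fixed seed for reproducibility
grid_size, class_idx = 3, 1      # illustrative values
labels = [rng.randrange(10) for _ in range(grid_size * grid_size)]  # 10 CIFAR-10 classes
grid, target = make_counting_sample(labels, class_idx, grid_size)
for row in grid:
    print(row)
print(f"images of class {class_idx}: {target}")
```

In an actual data module, each cell would presumably hold an image rather than its label, with the cells stitched into a single input image.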


License

This project is licensed under the MIT license.

Acknowledgements

Vision DiffMask is an adaptation of DiffMask to the vision domain. Parts of the code are heavily inspired by its original PyTorch implementation.


Citation

If you use this code or find our work otherwise useful, please consider citing our paper:

@inproceedings{nalmpantis2023vision,
  title={VISION DIFFMASK: Faithful Interpretation of Vision Transformers with Differentiable Patch Masking},
  author={Nalmpantis, Angelos and Panagiotopoulos, Apostolos and Gkountouras, John and Papakostas, Konstantinos and Aziz, Wilker},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={3755--3762},
  year={2023}
}