suinleelab / vit-shapley


ViT-Shapley

Shapley values are a theoretically grounded model explanation approach, but their exponential computational cost makes them difficult to use with large deep learning models. This package implements ViT-Shapley, an approach that makes Shapley values practical for vision transformer (ViT) models. The key idea is to learn an amortized explainer model that generates explanations in a single forward pass.
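The exponential cost is easy to see in a minimal exact computation. The Python sketch below is a generic illustration, not code from this repository; `exact_shapley` and `toy_value` are hypothetical names. Each player's Shapley value averages its marginal contributions over all 2^(d-1) subsets of the other players, so the loops make on the order of d * 2^d value-function calls. A ViT-Base model on a 224x224 image with 16x16 patches has d = 196 patches, which is far beyond what enumeration can handle.

```python
from itertools import combinations
from math import factorial
from typing import Callable, FrozenSet, List


def exact_shapley(value_fn: Callable[[FrozenSet[int]], float], d: int) -> List[float]:
    """Exact Shapley values by brute-force enumeration (about d * 2^d value calls)."""
    phi = [0.0] * d
    for i in range(d):
        others = [j for j in range(d) if j != i]
        for size in range(d):  # |S| ranges over 0..d-1, with S drawn from the other players
            # Shapley weight |S|! (d - |S| - 1)! / d!
            weight = factorial(size) * factorial(d - size - 1) / factorial(d)
            for subset in combinations(others, size):
                s = frozenset(subset)
                phi[i] += weight * (value_fn(s | {i}) - value_fn(s))
    return phi


# Hypothetical toy game: players 0 and 1 earn 2.0 only together, player 2 adds 1.0,
# and player 3 contributes nothing.
def toy_value(s: FrozenSet[int]) -> float:
    return (2.0 if {0, 1} <= s else 0.0) + (1.0 if 2 in s else 0.0)


print(exact_shapley(toy_value, 4))  # approximately [1.0, 1.0, 1.0, 0.0]
```

The joint payoff of players 0 and 1 is split evenly between them, and the values sum to v(all players) - v(no players), as the efficiency property requires.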

The high-level workflow for using ViT-Shapley is the following:

  1. Obtain your initial ViT model
  2. If your model was not trained to accommodate held-out image patches, fine-tune it with random masking
  3. Train an explainer model using ViT-Shapley's custom loss function (often by fine-tuning parameters of the original ViT)
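Step 3 can be sketched roughly as follows. This is a minimal pure-Python illustration of a FastSHAP-style objective as we understand it, not this repository's implementation. `value_fn` stands in for the fine-tuned ViT's output for a target class when only the patches in S are kept, and `phi` stands in for the explainer's per-patch output for one image. Coalitions are drawn from the Shapley kernel distribution, the attributions are shifted to satisfy efficiency (additive efficient normalization), and the squared error between the value function and the summed attributions is averaged. All function names are hypothetical.

```python
import random
from typing import Callable, FrozenSet, List, Sequence


def shapley_kernel_size_probs(d: int) -> List[float]:
    """Probability of each coalition size k = 1..d-1 under the Shapley kernel.

    The per-subset kernel weight is (d-1) / (C(d,k) * k * (d-k)). Summing it over
    the C(d,k) subsets of size k gives a size weight proportional to 1 / (k * (d-k)).
    """
    w = [1.0 / (k * (d - k)) for k in range(1, d)]
    total = sum(w)
    return [x / total for x in w]


def sample_shapley_kernel_subset(d: int, rng: random.Random) -> FrozenSet[int]:
    """Draw S with probability proportional to its Shapley kernel weight (requires d >= 2)."""
    k = rng.choices(range(1, d), weights=shapley_kernel_size_probs(d))[0]
    return frozenset(rng.sample(range(d), k))  # uniform subset of the chosen size


def efficient_normalize(phi: Sequence[float], v_full: float, v_empty: float) -> List[float]:
    """Additive efficient normalization: shift phi so that sum(phi) == v_full - v_empty."""
    gap = (v_full - v_empty - sum(phi)) / len(phi)
    return [p + gap for p in phi]


def amortized_loss(phi: Sequence[float], value_fn: Callable[[FrozenSet[int]], float],
                   d: int, n_samples: int, rng: random.Random) -> float:
    """Monte Carlo estimate of the Shapley-kernel-weighted least squares loss."""
    v_empty = value_fn(frozenset())
    v_full = value_fn(frozenset(range(d)))
    phi = efficient_normalize(phi, v_full, v_empty)
    total = 0.0
    for _ in range(n_samples):
        s = sample_shapley_kernel_subset(d, rng)
        pred = sum(phi[i] for i in s)
        total += (value_fn(s) - v_empty - pred) ** 2
    return total / n_samples


# Usage on a hypothetical additive game, where the exact Shapley values are the weights.
weights = [0.5, -1.0, 2.0, 0.25]


def additive_value(s: FrozenSet[int]) -> float:
    return sum(weights[i] for i in s)


rng = random.Random(0)
print(amortized_loss(weights, additive_value, 4, 200, rng))  # 0.0: exact values fit perfectly
```

In an actual training loop, the same loss would be computed on batches of images in a deep learning framework so that gradients flow into the explainer's parameters. The plain lists here only show the arithmetic.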

Please see our paper for more details, as well as the works that ViT-Shapley builds on (KernelSHAP, FastSHAP).
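For reference, the KernelSHAP result that FastSHAP and ViT-Shapley build on characterizes the Shapley values of a game v over d players (here, image patches, with N the set of all patches) as the solution of a weighted least squares problem with Shapley kernel weights mu(S). FastSHAP amortizes this by predicting phi(x; theta) directly and minimizing the same weighted error in expectation over inputs x. The notation below is our own paraphrase; see the papers for the exact formulation.

```latex
\mu(S) = \frac{d-1}{\binom{d}{|S|}\,|S|\,(d-|S|)}, \qquad 0 < |S| < d

\phi = \arg\min_{\beta \in \mathbb{R}^d} \sum_{0 < |S| < d} \mu(S) \Big( v(S) - v(\varnothing) - \sum_{i \in S} \beta_i \Big)^2
\quad \text{s.t.} \quad \sum_{i=1}^{d} \beta_i = v(N) - v(\varnothing)

\mathcal{L}(\theta) = \mathbb{E}_{x}\, \mathbb{E}_{S \sim p(S) \propto \mu(S)} \Big[ \Big( v_x(S) - v_x(\varnothing) - \sum_{i \in S} \phi_i(x; \theta) \Big)^2 \Big]
```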

Installation

git clone https://github.com/chanwkimlab/vit-shapley.git
cd vit-shapley
pip install -r requirements.txt

Training

Commands for training and testing the models are available in the files under the scripts directory.

Benchmarking

  1. Run notebooks/2_1_benchmarking.ipynb to obtain results.
  2. Run notebooks/2_2_ROAR.ipynb to run retraining-based ROAR benchmarking.
  3. Run notebooks/3_plotting.ipynb to plot the results.
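ROAR (RemOve And Retrain) scores an attribution method by deleting each image's most important features according to that method, retraining the model on the modified data, and measuring how much accuracy drops. The helper below is a hypothetical sketch of the masking step only. It assumes per-patch attributions ranked by signed value (some variants rank by absolute value); the notebook may differ in details such as the removal fractions or how removed patches are filled in.

```python
from typing import Dict, List, Sequence


def roar_keep_mask(attributions: Sequence[float], fraction: float) -> List[bool]:
    """Return a per-patch keep mask that drops the top `fraction` of patches by attribution."""
    d = len(attributions)
    n_remove = round(fraction * d)
    # Rank patches from most to least important (signed attribution, an assumption).
    order = sorted(range(d), key=lambda i: attributions[i], reverse=True)
    removed = set(order[:n_remove])
    return [i not in removed for i in range(d)]


def roar_masks(attributions: Sequence[float],
               fractions: Sequence[float] = (0.1, 0.3, 0.5, 0.7, 0.9)) -> Dict[float, List[bool]]:
    """Build one mask per (hypothetical) removal fraction; a model is retrained for each."""
    return {f: roar_keep_mask(attributions, f) for f in fractions}


print(roar_keep_mask([0.1, 0.9, -0.2, 0.5], 0.5))  # [True, False, True, False]
```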

Datasets

Model weights

Pretrained model weights for vit-base models are available here.

Demo

You can try out ViT-Shapley in Colab.

Citation

If you use any part of this code or the pretrained weights for your own purposes, please cite our paper.

Contact