
Vision Transformer (ViT) Prisma Library


For a full introduction, including Open Problems in vision mechanistic interpretability, see the original LessWrong post here.

ViT Prisma is an open-source mechanistic interpretability library for vision and multimodal models. Currently, the library supports ViTs and CLIP. This library was created by Sonia Joseph. ViT Prisma is largely based on TransformerLens by Neel Nanda.

Contributors: Praneet Suresh, Yash Vadi, Rob Graham [and more coming soon]

We welcome new contributors. Check out our contributing guidelines here and our open Issues.

Installing Repo

For the latest version, install the repo from source. This version includes the latest developments, but they may not be fully tested.

For the tested and stable release, install Prisma as a package.

Install as a package

To install the stable release with pip:

pip install vit_prisma

Install from source

To install as an editable repo from source:

git clone https://github.com/soniajoseph/ViT-Prisma
cd ViT-Prisma
pip install -e .

How do I use this repo?

Check out our guide.

Check out our tutorial notebooks for using the repo, along with this corresponding talk covering some of the same techniques. A minimal code sketch of the core workflow follows the list below.

  1. Main ViT Demo - Overview of the main mechanistic interpretability techniques on a ViT, including direct logit attribution, attention head visualization, and activation patching. The activation patching example switches the network's prediction from tabby cat to Border collie with a minimal ablation.
  2. Emoji Logit Lens - Deeper dive into layer- and patch-level predictions with interactive plots.
  3. Interactive Attention Head Tour - Deeper dive into the various types of attention heads a ViT contains with interactive JavaScript.
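The notebooks are the authoritative reference. As a quick flavor, here is a minimal sketch of the TransformerLens-style workflow Prisma follows; the import path, model identifier, and hook name are illustrative assumptions, so check the guide and the Main ViT Demo for the exact entry points.

```python
# NOTE: the import path, model name, and hook name below are illustrative
# assumptions; see the guide and the Main ViT Demo for the exact entry points.
import torch
from vit_prisma.models.base_vit import HookedViT  # assumed location of the hooked ViT class

model = HookedViT.from_pretrained("vit_base_patch16_224")  # hypothetical model identifier

image = torch.randn(1, 3, 224, 224)  # dummy batch of one 224x224 RGB image

# Run the model while caching every intermediate activation; the cache is the
# starting point for direct logit attribution, attention visualization,
# and activation patching.
logits, cache = model.run_with_cache(image)

print(logits.shape)                                # class logits for the batch
print(cache["blocks.0.attn.hook_pattern"].shape)   # layer-0 attention pattern (assumed hook name)
```

From the returned cache you can read off attention patterns for visualization, or project intermediate activations onto the output classes for logit-lens-style analysis.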

Features

For a full demo of Prisma's features, including interactive versions of the visualizations listed below, check out the demo notebooks above.

Attention head visualization


Activation patching
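The interactive version lives in the Main ViT Demo notebook. As a rough sketch of the pattern, assuming the TransformerLens-style hook API that Prisma inherits (the import path, model identifier, and hook name are illustrative assumptions):

```python
# Sketch of activation patching: overwrite one layer's residual stream in a
# corrupted run with the corresponding activation from a clean run.
# Import path, model name, and hook name are illustrative assumptions.
import torch
from vit_prisma.models.base_vit import HookedViT

model = HookedViT.from_pretrained("vit_base_patch16_224")  # hypothetical model identifier

clean_image = torch.randn(1, 3, 224, 224)      # e.g. the tabby cat image
corrupted_image = torch.randn(1, 3, 224, 224)  # e.g. the Border collie image

_, clean_cache = model.run_with_cache(clean_image)  # cache the clean activations

layer = 5
hook_name = f"blocks.{layer}.hook_resid_post"  # assumed hook name (TransformerLens convention)

def patch_resid(activation, hook):
    # Replace the corrupted activation with the clean one at this hook point.
    return clean_cache[hook.name]

# Re-run the corrupted input with the clean activation patched in and compare
# the logits to see how much of the clean prediction is restored.
patched_logits = model.run_with_hooks(
    corrupted_image,
    fwd_hooks=[(hook_name, patch_resid)],
)
```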

Direct logit attribution

Emoji logit lens

Supported Models

Training Code

Prisma contains training code to train your own custom ViTs. Training small ViTs can be very useful for isolating specific behaviors in the model.

For training your own models, check out our guide.
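The guide documents Prisma's own training utilities and configs. Purely as a library-agnostic illustration (this is not Prisma's training API; the model and data loader are placeholders you supply), a bare-bones PyTorch loop for a small ViT might look like this, with frequent checkpointing so training dynamics can be studied later:

```python
# Library-agnostic PyTorch training loop for a small ViT; this is NOT Prisma's
# training API (see the guide for that). `model` and `train_loader` are
# placeholders supplied by you.
import torch
import torch.nn as nn

def train(model, train_loader, epochs=10, lr=1e-3, device="cuda"):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

        # Save a checkpoint every epoch so training dynamics can be analyzed
        # later, as recommended in the upload guidelines below.
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch}.pt")
```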

Custom Models & Checkpoints

ImageNet-1k classification checkpoints (patch size 32)

This model was trained by Praneet Suresh. All models include training checkpoints, in case you want to analyze training dynamics.

This larger patch size yields fewer patches, so its attention heads are small enough to inspect interactively; the patch-size-16 attention patterns are too large to render easily in JavaScript.

Accuracy [ <Acc> | <Top5 Acc> ]

| Size | NumLayers | Attention+MLP | AttentionOnly | Model Link |
|------|-----------|---------------|---------------|------------|
| tiny | 3 | 0.22 \| 0.42 | N/A | Attention+MLP |

ImageNet-1k classification checkpoints (patch size 16)

The detailed training logs and metrics can be found here. These models were trained by Yash Vadi.

Table of Results

Accuracy [ <Acc> | <Top5 Acc> ]

| Size | NumLayers | Attention+MLP | AttentionOnly | Model Link |
|--------|-----------|---------------|---------------|------------------------------|
| tiny | 1 | 0.16 \| 0.33 | 0.11 \| 0.25 | AttentionOnly, Attention+MLP |
| base | 2 | 0.23 \| 0.44 | 0.16 \| 0.34 | AttentionOnly, Attention+MLP |
| small | 3 | 0.28 \| 0.51 | 0.17 \| 0.35 | AttentionOnly, Attention+MLP |
| medium | 4 | 0.33 \| 0.56 | 0.17 \| 0.36 | AttentionOnly, Attention+MLP |

dSprites Shape Classification training checkpoints

Original dataset is here.

Full results and training setup are here. These models were trained by Yash Vadi.

Table of Results

| Size | NumLayers | Attention+MLP | AttentionOnly | Model Link |
|--------|-----------|---------------|---------------|------------------------------|
| tiny | 1 | 0.535 | 0.459 | AttentionOnly, Attention+MLP |
| base | 2 | 0.996 | 0.685 | AttentionOnly, Attention+MLP |
| small | 3 | 1.000 | 0.774 | AttentionOnly, Attention+MLP |
| medium | 4 | 1.000 | 0.991 | AttentionOnly, Attention+MLP |

Guidelines for training + uploading models

Upload your trained models to Hugging Face. Follow the Hugging Face guidelines and create a model card. Document as much of the training process as possible, including links to loss and accuracy curves on Weights & Biases, the dataset (and the order of the training data), hyperparameters, optimizer, learning rate schedule, hardware, and any other details that may be relevant.

Include frequent checkpoints throughout training, which will help other researchers understand training dynamics.
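As a minimal sketch of the upload step using the huggingface_hub client (the repo id and local folder are placeholders; authenticate first, e.g. with huggingface-cli login):

```python
# Sketch of pushing a trained model (with intermediate checkpoints) to the
# Hugging Face Hub. The repo id and local folder are placeholders.
from huggingface_hub import HfApi

api = HfApi()  # assumes you are already authenticated, e.g. via `huggingface-cli login`

repo_id = "your-username/vit-tiny-imagenet1k"  # hypothetical repo id
api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)

# Upload the training directory, including periodic checkpoints,
# the model card (README.md), and the training config.
api.upload_folder(
    folder_path="./checkpoints/vit-tiny-imagenet1k",
    repo_id=repo_id,
    repo_type="model",
    commit_message="Add model, checkpoints, and training config",
)
```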

Citation

Please cite this repository if you use it in papers or research projects.

@misc{joseph2023vit,
  author = {Sonia Joseph},
  title = {ViT Prisma: A Mechanistic Interpretability Library for Vision Transformers},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/soniajoseph/vit-prisma}}
}