hila-chefer / Transformer-MM-Explainability

[ICCV 2021 - Oral] Official PyTorch implementation of Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers, a novel method to visualize any Transformer-based network. Includes examples for DETR and VQA.
MIT License

Batching for CLIP Explainability #11

Closed: Alacarter closed this issue 2 years ago.

Alacarter commented 2 years ago

Is there a way to batch heatmap creation in CLIP_explainability? Specifically, I would like to pass in a list of batch_size (image, text) pairs and get back batch_size heatmaps, one per pair. (I think this was asked in your other repo, where you referenced batching in the new ViT notebook. However, that notebook looks quite similar to the CLIP notebook in this repo, and I don't see how either one creates heatmaps for multiple images at once. Perhaps I am missing something.)

Batching the forward pass is straightforward, but batching the backward pass in interpret seems difficult, since it starts from the relevant image logit and calls backward on it. Currently, I call one_hot.backward(...) separately for every image in the batch, which is quite time-consuming.
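To make the cost concrete, here is a minimal sketch of the per-pair approach. It is not the repo's interpret code: img_proj and txt_proj are hypothetical toy linear encoders standing in for CLIP's two towers, and the tanh activation img_act stands in for the attention maps whose gradients the explainability method needs. The loop runs one backward pass per (image, text) pair, so the work grows linearly with batch_size:

```python
import torch

torch.manual_seed(0)
batch_size, in_dim, emb_dim = 4, 8, 4

# Hypothetical toy encoders standing in for CLIP's image and text towers.
img_proj = torch.nn.Linear(in_dim, emb_dim)
txt_proj = torch.nn.Linear(in_dim, emb_dim)
images = torch.randn(batch_size, in_dim)
texts = torch.randn(batch_size, in_dim)

# Batched forward pass: easy.
img_act = torch.tanh(img_proj(images))            # activation we want to explain
logits_per_image = img_act @ txt_proj(texts).t()  # [i, j] = score(image_i, text_j)

# Per-pair backward: one autograd call per pair.
per_pair_grads = []
for i in range(batch_size):
    one_hot = torch.zeros_like(logits_per_image)
    one_hot[i, i] = 1.0  # select only the (image_i, text_i) logit
    score = (one_hot * logits_per_image).sum()
    # retain_graph=True keeps the graph alive for the next iteration.
    (grad,) = torch.autograd.grad(score, img_act, retain_graph=True)
    per_pair_grads.append(grad[i])  # only row i is non-zero
per_pair_grads = torch.stack(per_pair_grads)  # shape (batch_size, emb_dim)
```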

Thanks for making your code public, and I appreciate your help in advance.

hila-chefer commented 2 years ago

Hi @Alacarter, thanks for your interest in our work! I updated our colab notebook to demonstrate how batching can be done. In the notebook, I used a single image with multiple texts. To create the "one-hot" tensor, I duplicated the image batch_size times and then set one "one-hot" entry for each (image, text) pair. Similarly, if you have a batch of images and texts, simply create a one-hot tensor of shape batch_size x batch_size and set it to the identity matrix.
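A minimal sketch of the identity one-hot idea, using the same kind of toy setup rather than the notebook's actual code: img_proj, txt_proj and explain_batch are hypothetical names, and the tanh activation stands in for the attention maps. Image i's activation only affects row i of logits_per_image, and the identity mask keeps only the diagonal logits L[i, i]. So one backward pass over their sum puts, in row i of the gradient, exactly the gradient of L[i, i] for pair i. This assumes the forward pass treats samples independently (true for a standard transformer with LayerNorm, not for cross-sample layers such as BatchNorm in training mode).

```python
import torch

torch.manual_seed(0)
batch_size, in_dim, emb_dim = 4, 8, 4

# Hypothetical toy encoders standing in for CLIP's image and text towers.
img_proj = torch.nn.Linear(in_dim, emb_dim)
txt_proj = torch.nn.Linear(in_dim, emb_dim)

def explain_batch(images, texts):
    """Gradients for all (image_i, text_i) pairs from a single backward pass."""
    img_act = torch.tanh(img_proj(images))  # activation we want to explain
    logits_per_image = img_act @ txt_proj(texts).t()
    n = logits_per_image.shape[0]
    # Identity "one-hot": keeps only the matching-pair logits L[i, i].
    one_hot = torch.eye(n, device=logits_per_image.device)
    score = (one_hot * logits_per_image).sum()
    (grad,) = torch.autograd.grad(score, img_act)
    return grad  # row i == d L[i, i] / d img_act[i]

# A batch of (image_i, text_i) pairs.
images = torch.randn(batch_size, in_dim)
texts = torch.randn(batch_size, in_dim)
grads = explain_batch(images, texts)  # shape (batch_size, emb_dim)

# A single image with several texts: duplicate the image once per text.
image = torch.randn(1, in_dim)
grads_single = explain_batch(image.repeat(batch_size, 1), texts)
```

Compared with the per-pair loop, this replaces batch_size backward passes with one, which is where the speed-up comes from.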

I hope this helps; please refer to our colab for all the details :) Hila.

Alacarter commented 2 years ago

Thank you very much for the batch example in the colab! I am able to replicate your batching code on my end, and the heatmaps look similar to what I had before and generate a lot faster. Appreciate the help and response!