jacobgil / pytorch-grad-cam

Advanced AI Explainability for computer vision. Support for CNNs, Vision Transformers, Classification, Object detection, Segmentation, Image similarity and more.
https://jacobgil.github.io/pytorch-gradcam-book
MIT License

Support grad cam for cross attention on encoder-decoder models #454

Open ahmedplateiq opened 1 year ago

ahmedplateiq commented 1 year ago

Currently, there is no support for Grad-CAM (Gradient-weighted Class Activation Mapping) visualization of the cross-attention mechanisms in encoder-decoder models. Grad-CAM is a valuable tool for interpreting model decisions and for understanding which parts of the input contribute most to the output. Extending Grad-CAM support to cross-attention would greatly improve the interpretability and usefulness of these models.
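For reference, the original Grad-CAM (Selvaraju et al., 2017) works on the feature maps A^k of a convolutional layer. It weights each map by the spatially averaged gradient of the class score y^c, where Z is the number of spatial positions, and keeps only the positive part of the weighted sum:

```latex
\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A_{ij}^k},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\Big(\sum_k \alpha_k^c A^k\Big)
```

For cross-attention maps A^h (one per head h), one commonly used adaptation is the per-layer rule of Chefer et al. (2021) for encoder-decoder transformers. It is an assumption here and not something this issue prescribes: weight the attention element-wise by its gradient, keep the positive part, and average over heads:

```latex
\bar{A}^c = \mathbb{E}_h\!\left[\left(\frac{\partial y^c}{\partial A^h} \odot A^h\right)^{+}\right]
```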

Proposal

We propose adding Grad-CAM support tailored to the cross-attention layers of encoder-decoder models. This would let users visualize the gradient-weighted cross-attention between decoder and encoder tokens, shedding light on how information flows from the encoder to the decoder during inference.

Implementation Ideas

Here are some high-level steps to implement Grad-CAM support for cross-attention:

1. Identify the cross-attention layers in the encoder-decoder architecture.
2. Compute the gradients of the target output (for example, the logit of one generated token) with respect to the activations of these cross-attention layers.
3. Aggregate these gradients to create class-specific (here, token-specific) importance scores.
4. Generate the Grad-CAM heatmap for visualization.
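The steps above can be sketched on a toy model: a single-layer, single-head cross-attention written in NumPy. Because the model is so small, the gradient in step 2 can be derived by hand. This is a minimal sketch under stated assumptions, not an implementation from pytorch-grad-cam. `cross_attention_gradcam`, the read-out vector `w_out`, and the random weights are all hypothetical. The aggregation in steps 3 and 4 uses gradient-weighted attention with a ReLU, one of several possible choices; classic Grad-CAM would instead average gradients per channel.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention_gradcam(dec_h, enc_h, Wq, Wk, Wv, w_out, token):
    """Hypothetical helper: Grad-CAM-style relevance of encoder tokens for one decoder token.

    dec_h: (T, d) decoder hidden states (queries)
    enc_h: (S, d) encoder outputs, e.g. image patch tokens (keys and values)
    w_out: (d,) hypothetical read-out vector; the explained score is y = (A @ V)[token] @ w_out
    """
    # Step 1: the (toy) cross-attention layer; A has shape (T, S).
    Q, K, V = dec_h @ Wq, enc_h @ Wk, enc_h @ Wv
    A = softmax(Q @ K.T / np.sqrt(Q.shape[1]), axis=-1)
    # Step 2: y = sum_j A[token, j] * (V[j] @ w_out), so dy/dA[token, j] = V[j] @ w_out
    # (every other row of A has zero gradient).
    grad_row = V @ w_out
    # Steps 3-4: weight the attention by its gradient and keep positive evidence only.
    cam = np.maximum(grad_row * A[token], 0.0)
    return cam, A, grad_row


rng = np.random.default_rng(0)
T, S, d = 3, 16, 8  # 3 decoder tokens, 16 encoder patches on a 4x4 grid
dec_h, enc_h = rng.normal(size=(T, d)), rng.normal(size=(S, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
w_out = rng.normal(size=d)

cam, A, grad_row = cross_attention_gradcam(dec_h, enc_h, Wq, Wk, Wv, w_out, token=1)
heatmap = cam.reshape(4, 4)  # relevance of each image patch for decoder token 1
```

Because the toy score is linear in A, the hand-derived gradient is easy to check with a finite difference. In a real network you would obtain the same quantity with autograd.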

Benefits

- Improved model interpretability: users can gain insight into how the model attends to different parts of the input during decoding.
- Debugging and model refinement: Grad-CAM can help diagnose model behavior and identify areas where the model can be improved.

Example: In the Donut encoder-decoder model, this would make it possible to generate Grad-CAM heatmaps from the cross-attention outputs and identify which parts of the image each predicted text token relies on. See the following discussion:
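Donut's encoder is a Swin Transformer, so the encoder tokens that a cross-attention map ranges over correspond to a 2-D grid of image patches. Turning one text token's relevance vector into an image heatmap is then a reshape followed by an upsample. This is a minimal sketch that assumes the tokens are in row-major order and that the image size is an integer multiple of the grid size; check the actual grid of the model you use. `tokens_to_heatmap` is a hypothetical helper, and nearest-neighbour upsampling stands in for the bilinear resize an image library would normally provide.

```python
import numpy as np


def tokens_to_heatmap(cam, grid_hw, image_hw):
    """Hypothetical helper: per-patch relevance -> image-sized heatmap in [0, 1].

    cam:      (S,) relevance of each encoder patch token for one decoded text token
    grid_hw:  (h, w) encoder patch grid with h * w == S, tokens assumed row-major
    image_hw: (H, W) output size, assumed to be integer multiples of (h, w)
    """
    h, w = grid_hw
    H, W = image_hw
    grid = np.asarray(cam, dtype=float).reshape(h, w)
    # Nearest-neighbour upsampling: repeat each patch score over its pixel block.
    up = grid.repeat(H // h, axis=0).repeat(W // w, axis=1)
    # Min-max normalise so the map can be overlaid on the image.
    lo, hi = up.min(), up.max()
    return (up - lo) / (hi - lo) if hi > lo else np.zeros_like(up)


cam = np.arange(6, dtype=float)  # toy scores for a 2x3 patch grid
heat = tokens_to_heatmap(cam, grid_hw=(2, 3), image_hw=(4, 6))
print(heat.shape)  # (4, 6)
```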

https://github.com/clovaai/donut/issues/45

jcjlin commented 1 month ago

Hello @ahmedplateiq, do you know how to implement Grad-CAM visualization on a decoder? Thanks.
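One possible approach on a decoder works in four steps. First, keep the cross-attention probabilities from the forward pass and call `retain_grad()` on them. Next, backpropagate the logit of the token being explained. Then combine the attention with its gradient, and finally reshape the result onto the patch grid. The PyTorch sketch below is illustrative only. It uses a hypothetical single-head `ToyCrossAttention` module, a hypothetical `lm_head`, and random tensors in place of a real encoder and decoder. In a real model such as Donut you would need to capture the same tensor inside the decoder's cross-attention modules, for example with forward hooks. Note that returned attention weights are not always the tensor used in the forward computation: `torch.nn.MultiheadAttention`, for instance, reshapes them into a new view after they have been multiplied with the values, so their `.grad` can remain `None`.

```python
import torch
import torch.nn as nn


class ToyCrossAttention(nn.Module):
    """Hypothetical single-head cross-attention that keeps its attention map for Grad-CAM."""

    def __init__(self, d_model):
        super().__init__()
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.attn = None  # (batch, tgt_len, src_len), refreshed on every forward

    def forward(self, dec_h, enc_out):
        scores = self.q(dec_h) @ self.k(enc_out).transpose(-2, -1)
        self.attn = (scores / dec_h.shape[-1] ** 0.5).softmax(dim=-1)
        self.attn.retain_grad()  # keep dy/dA on this non-leaf tensor after backward()
        return self.attn @ self.v(enc_out)


torch.manual_seed(0)
d_model, vocab, n_patches, n_tokens = 32, 10, 16, 5
cross_attn = ToyCrossAttention(d_model)
lm_head = nn.Linear(d_model, vocab)           # hypothetical decoder output head

enc_out = torch.randn(1, n_patches, d_model)  # stand-in for encoder patch tokens
dec_h = torch.randn(1, n_tokens, d_model)     # stand-in for decoder hidden states

# Residual connection around cross-attention, then project to vocabulary logits (1, T, vocab).
logits = lm_head(dec_h + cross_attn(dec_h, enc_out))
t = 2                                         # decoder position (text token) to explain
score = logits[0, t, logits[0, t].argmax()]   # logit of the token predicted at position t
score.backward()

A, dA = cross_attn.attn, cross_attn.attn.grad
cam = (dA * A).clamp(min=0)[0, t]             # positive gradient-weighted attention, (n_patches,)
heatmap = cam.detach().reshape(4, 4)          # 16 patches -> 4x4 grid
```

With several heads or layers, the same product would typically be averaged over heads (and possibly accumulated across layers) before reshaping.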