hila-chefer / Transformer-Explainability

[CVPR 2021] Official PyTorch implementation for Transformer Interpretability Beyond Attention Visualization, a novel method to visualize classifications by Transformer-based networks.

Question about Raw Attention and GradCAM for transformer #36

hanwei0912 closed this issue 2 years ago

hanwei0912 commented 2 years ago

Hi Hila,

Thank you for this brilliant work!

I read your paper and I would like to ask for more details about the raw attention mentioned in it. To my understanding, the raw attention is computed as follows:

1. Take the attention map of the last block, A^(1).
2. Average over the heads to get E_h(A^(1)), whose shape is 1 x s x s.
3. Take the row corresponding to the CLS token, a vector of shape 1 x s.
4. Drop the CLS entry and reshape the remaining s - 1 values to a sqrt(s-1) x sqrt(s-1) grid.
5. Upsample back to the size of the input image with bilinear interpolation.
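For reference, here is a minimal sketch of what I am doing, so you can see my exact steps (the function name, the default sizes, and the way I extract the last block's attention are my own choices, not taken from your code):

```python
import torch
import torch.nn.functional as F

def raw_attention_map(attn_last, img_size=224, patch_size=16):
    """Raw attention of the last block, following steps (1)-(5) above.

    attn_last: attention of the last block, shape (heads, s, s),
               where s = 1 + number of patches and the CLS token is first.
    Returns an (img_size, img_size) map normalized to [0, 1].
    """
    attn = attn_last.mean(dim=0)      # (2) average over heads -> (s, s)
    cls_row = attn[0, 1:]             # (3)+(4) CLS row without the CLS entry -> (s - 1,)
    side = img_size // patch_size     # e.g. 224 // 16 = 14, so s - 1 = 196
    grid = cls_row.reshape(1, 1, side, side)
    grid = F.interpolate(grid, size=(img_size, img_size),
                         mode='bilinear', align_corners=False).squeeze()  # (5) upsample
    return (grid - grid.min()) / (grid.max() - grid.min() + 1e-8)  # normalize for display
```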

I want to confirm this with you because I got a different visualization of the raw attention, and I would like to understand whether I generated the same version as you did. Thank you very much if you could confirm it.

Besides, I am sorry but I did not fully follow the part about applying GradCAM to the transformer. Is it the same as https://github.com/jacobgil/pytorch-grad-cam/blob/master/tutorials/vision_transformers.md ? If not, could you please tell me how you implemented it?

Thank you in advance. Looking forward to your reply,
Hanwei

hila-chefer commented 2 years ago

Hi @hanwei0912, thanks for your interest!

If I understand your pseudocode correctly, it seems like you're right. The result can vary according to the preprocessing you perform on the image (center crop or not, image shape, etc.), and of course can vary between different models (ViT, DeiT, etc.), so I'm not sure what causes the difference, but I'd check these two factors. Regarding GradCAM, please refer to this issue for the implementation details of GradCAM in our work.
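In broad strokes, a GradCAM-style map on a ViT can be computed from the last block's attention by weighting it with its gradient with respect to the target class score, averaging over heads, and keeping the positive part of the CLS row. The snippet below is only a rough sketch of that idea (the helper name, shapes, and defaults are illustrative assumptions), so please still check the linked issue for the exact implementation used in our comparisons:

```python
import torch
import torch.nn.functional as F

def gradcam_on_attention(attn, attn_grad, img_size=224, patch_size=16):
    """GradCAM-style relevance computed on the last block's attention.

    attn:      attention of the last block, shape (heads, s, s), saved in the forward pass.
    attn_grad: gradient of the target class logit w.r.t. that attention, same shape
               (e.g. obtained with a backward hook or retain_grad()).
    """
    cam = (attn_grad * attn).mean(dim=0)   # gradient-weighted attention, averaged over heads -> (s, s)
    cam = cam[0, 1:].clamp(min=0)          # CLS row without the CLS entry, ReLU -> (s - 1,)
    side = img_size // patch_size
    cam = cam.reshape(1, 1, side, side)
    cam = F.interpolate(cam, size=(img_size, img_size),
                        mode='bilinear', align_corners=False).squeeze()
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # normalize to [0, 1]
```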