hila-chefer / Transformer-Explainability

[CVPR 2021] Official PyTorch implementation for Transformer Interpretability Beyond Attention Visualization, a novel method to visualize classifications by Transformer-based networks.
MIT License
1.75k stars, 232 forks

Why did you multiply (Hadamard product) R(A) by the gradient of A? #39

Closed: ThisisBillhe closed this issue 2 years ago

ThisisBillhe commented 2 years ago

Thank you for the outstanding work!! However, I wonder why you multiply R(A) element-wise by G(A), the gradient of A, to get the final result. According to my calculation, R(A) is equal to A * G(A) / C, where C is a constant. What would happen if we used R(A) alone? And what's the motivation for multiplying them together?
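Written out, the claim in the question reads as follows, with ⊙ the Hadamard product, ∇A the gradient the question calls G(A), and C the constant from the question. This only restates the asker's calculation; it is not a verified identity for the method:

$$
R(A) = \frac{1}{C}\, A \odot \nabla A
\quad\Longrightarrow\quad
\nabla A \odot R(A) = \frac{1}{C}\, A \odot \nabla A \odot \nabla A
$$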

Looking forward to your reply!!

hila-chefer commented 2 years ago

Hi @ThisisBillhe, thanks for your interest!

The motivation is to average across the different attention heads in a meaningful way. A simple average ignores the different "roles" of each head, so we use gradients to obtain a class-specific signal that determines the "relevance" of each head to the output prediction. Using R(A) alone will result in class-agnostic and noisy relevance maps.
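A minimal NumPy sketch of the head-averaging idea above, assuming the paper's per-block rule E_h[(∇A ⊙ R(A))^+] (the identity term and the multiplication across blocks are omitted). The function names and toy values are hypothetical, and this is not the repository's code. For illustration, R(A) is assumed identical for both target classes, so any difference between the two maps comes only from the gradient weighting:

```python
import numpy as np


def gradient_weighted_head_average(relevance: np.ndarray, grad: np.ndarray) -> np.ndarray:
    """E_h[(grad * R(A))^+]: weight each head's relevance by its gradient.

    Both inputs have shape (heads, tokens, tokens).
    """
    weighted = grad * relevance            # Hadamard product, head by head
    weighted = np.maximum(weighted, 0.0)   # keep positive contributions only
    return weighted.mean(axis=0)           # average over the head axis


def plain_head_average(relevance: np.ndarray) -> np.ndarray:
    """Uniform average of R(A) over heads; carries no class-specific signal."""
    return relevance.mean(axis=0)


# Hypothetical toy values: 2 heads, 2 tokens.
relevance = np.array([
    [[0.6, 0.4], [0.5, 0.5]],   # head 0
    [[0.1, 0.9], [0.2, 0.8]],   # head 1
])
ones = np.ones((2, 2))
grad_class_a = np.stack([ones, -ones])   # class A: head 0 helps, head 1 hurts
grad_class_b = np.stack([-ones, ones])   # class B: the reverse

map_a = gradient_weighted_head_average(relevance, grad_class_a)
map_b = gradient_weighted_head_average(relevance, grad_class_b)
map_plain = plain_head_average(relevance)  # identical whichever class is targeted

print(map_a)
print(map_b)
print(map_plain)
```

Here `map_a` keeps only head 0's relevance and `map_b` only head 1's, while `map_plain` is the same for both classes. The positive-part clamp stops heads whose gradient opposes the target class from cancelling the heads that support it.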

I hope this helps. Best, Hila.