hila-chefer / Transformer-MM-Explainability

[ICCV 2021 - Oral] Official PyTorch implementation for Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers, a novel method to visualize any Transformer-based network. Includes examples for DETR and VQA.
MIT License

attn_grad #3

Closed: betterze closed this issue 3 years ago

betterze commented 3 years ago

Dear Hila,

Thank you for your work, I really like it.

In the CLIP notebook, there is:

```python
image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
```

then

```python
grad = blk.attn_grad
cam = blk.attn_probs
```

If I understand correctly, each `blk` is a CLIP `ResidualAttentionBlock`. But the `ResidualAttentionBlock` class does not define `attn_grad` or `attn_probs`. Are they inherited from `nn.Module`? I tried to google it, but I could not find any related resources.

Similarly, in the ViT notebook, there is:

```python
grad = blk.attn.get_attn_gradients()
cam = blk.attn.get_attention_map()
```

These functions come from here, but I still have trouble understanding how you get the gradients and the attention maps. Sorry, I am new to PyTorch.

Could you help me to understand your implementation?

Thank you for your help.

Best Wishes,

Alex

hila-chefer commented 3 years ago

Hi @betterze, thanks for your kind words and for your interest in our work!

Let me try to clarify, and if my answer isn't good enough, please let me know. Let's take the ViT example, because it's simpler. In the ViT_new.py code, these two lines are responsible for saving the attention maps and their gradients: `self.save_attention_map(attn)` and `attn.register_hook(self.save_attn_gradients)`. The first thing the ViT notebook's `generate_relevance` function does is run a forward pass on the example we wish to explain: `output = model(input, register_hook=True)`. This forward pass executes the line `self.save_attention_map(attn)`, which serves as our "forward hook" (it is simply a method call inside the attention block's forward pass), so every self-attention block saves its attention matrix in `self.attention_map`. After the forward pass, each attention block therefore holds its own attention map.
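To make the forward-pass mechanism concrete, here is a minimal PyTorch sketch. `ToyAttention` is a hypothetical single-head module written only for illustration, not the code in ViT_new.py; it borrows the method names mentioned in this thread (`save_attention_map`, `save_attn_gradients`, `get_attention_map`, `get_attn_gradients`) and the `register_hook` flag, while the layers and shapes are assumptions. It shows that the attention map is stored as soon as the forward pass runs, whereas the gradient slot stays empty until a backward pass fires the tensor hook.

```python
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Hypothetical single-head self-attention that follows the save/hook
    pattern described in this thread. Not the repository's ViT_new.py code."""

    def __init__(self, dim: int):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3)
        self.scale = dim ** -0.5
        self.attention_map = None    # set during the forward pass
        self.attn_gradients = None   # set during the backward pass

    def save_attention_map(self, attn):
        self.attention_map = attn

    def save_attn_gradients(self, grad):
        self.attn_gradients = grad

    def get_attention_map(self):
        return self.attention_map

    def get_attn_gradients(self):
        return self.attn_gradients

    def forward(self, x, register_hook=False):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        # An ordinary method call: the tensor is stored on the module right now.
        self.save_attention_map(attn)
        if register_hook:
            # A tensor hook: PyTorch calls it later, during backward,
            # with the gradient of the backpropagated scalar w.r.t. attn.
            attn.register_hook(self.save_attn_gradients)
        return attn @ v


torch.manual_seed(0)
block = ToyAttention(dim=8)
x = torch.randn(1, 5, 8)                 # (batch, tokens, dim)
out = block(x, register_hook=True)       # forward: the map is stored now
print(block.get_attention_map().shape)   # torch.Size([1, 5, 5])
print(block.get_attn_gradients())        # None: the hook has not fired yet

out[0, 0].sum().backward()               # backward: the hook fires
print(block.get_attn_gradients().shape)  # torch.Size([1, 5, 5])
```

The design point is that the two values arrive at different times: the map during the forward call, the gradient only once autograd reaches `attn` during backward.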

Moving on to the gradients: the line `attn.register_hook(self.save_attn_gradients)` registers a backward hook on the `attn` tensor, so when we backpropagate, the hook is called with the gradient of the attention map and saves it in `self.attn_gradients`. The backward pass is also triggered from the ViT notebook's `generate_relevance` function, in the line `one_hot.backward(retain_graph=True)`.

After this step, every self-attention block holds both its attention map and its gradient. All that remains is to iterate over the blocks and apply the rules to the maps and gradients, as you mentioned: `grad = blk.attn.get_attn_gradients()` and `cam = blk.attn.get_attention_map()`. I really hope this helps, but if I somehow failed to answer your question, please let me know and I'll clarify as needed.
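The whole flow (forward pass with hooks, one-hot backward pass, then a loop over the blocks) can be sketched end to end on a tiny model. Everything here is hypothetical: `ToyViT`, its sizes, and its blocks, which are the attention modules themselves (so the loop reads `blk.attention_map` rather than the notebook's `blk.attn.get_attention_map()`). The per-block rule is assumed to follow the paper's self-attention update: take the positive part of gradient × attention, average it over heads, and accumulate `R = R + cam @ R` starting from the identity. Treat it as an illustration, not a copy of the notebook.

```python
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Hypothetical multi-head self-attention block that stores its attention
    map on every forward pass and, if asked, its gradient on backward."""

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.heads, self.head_dim = heads, dim // heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.attention_map = None
        self.attn_gradients = None

    def save_attn_gradients(self, grad):
        self.attn_gradients = grad

    def forward(self, x, register_hook=False):
        b, n, d = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)         # each: (b, heads, n, head_dim)
        attn = (q @ k.transpose(-2, -1) * self.head_dim ** -0.5).softmax(dim=-1)
        self.attention_map = attn                    # stored during forward
        if register_hook:
            attn.register_hook(self.save_attn_gradients)  # filled during backward
        return x + (attn @ v).transpose(1, 2).reshape(b, n, d)  # residual connection


class ToyViT(nn.Module):
    """Hypothetical stack of attention blocks with a linear head on token 0."""

    def __init__(self, dim=16, heads=2, depth=3, num_classes=4):
        super().__init__()
        self.blocks = nn.ModuleList([ToyAttention(dim, heads) for _ in range(depth)])
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x, register_hook=False):
        for blk in self.blocks:
            x = blk(x, register_hook=register_hook)
        return self.head(x[:, 0])                    # classify from the [CLS]-like token


def generate_relevance(model, x, index=None):
    output = model(x, register_hook=True)            # 1) forward: maps saved, hooks set
    if index is None:
        index = output.argmax(dim=-1).item()
    one_hot = torch.zeros_like(output)
    one_hot[0, index] = 1.0
    score = (one_hot * output).sum()
    model.zero_grad()
    score.backward(retain_graph=True)                # 2) backward: hooks save gradients

    n = model.blocks[0].attention_map.shape[-1]
    R = torch.eye(n)
    with torch.no_grad():
        for blk in model.blocks:                     # 3) combine maps and gradients
            grad, cam = blk.attn_gradients, blk.attention_map   # (1, heads, n, n)
            cam = (grad * cam).clamp(min=0).mean(dim=1)[0]      # head average -> (n, n)
            R = R + cam @ R
    return R[0, 1:]                                  # relevance of tokens 1..n-1 for token 0


torch.manual_seed(0)
model = ToyViT()
x = torch.randn(1, 6, 16)   # (batch, tokens, dim); token 0 plays the [CLS] role
rel = generate_relevance(model, x)
print(rel.shape)            # torch.Size([5])
```

`retain_graph=True` keeps the autograd graph alive after the backward pass, which matters if the same forward output is backpropagated again, for example to explain a second class.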

Thanks :)

betterze commented 3 years ago

Dear Hila,

Thank you very much for your detailed answer. It is really helpful.

After a few experiments, I believe I understand it now.

Thank you again for your help, I really appreciate it.

Best Wishes,

Alex