hila-chefer / Transformer-MM-Explainability

[ICCV 2021 Oral] Official PyTorch implementation of "Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers", a novel method to visualize any Transformer-based network. Includes examples for DETR and VQA.
MIT License

Generate relevance matrix in ViT of Hugging Face #13

Closed: SketchX-QZY closed this issue 2 years ago

SketchX-QZY commented 2 years ago

Hi, thank you for this great work!

I have trained a Transformer model with ViT from Hugging Face. When I tried to visualise the attention maps, I found your work. I am quite interested, but your code and Hugging Face's are different. I tried to modify the Hugging Face source code like this:

class ViTLayer(nn.Module):

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViT, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )

        self_attention_outputs.register_hook(self.save_attn_gradients)
        # ... (rest of forward as in the original Hugging Face code)

I am new to Transformers, so I am not sure whether I registered the hook on the right tensor. Can you help me check it?

Thank you very much!

hila-chefer commented 2 years ago

Hi @SketchX-QZY, thanks for your interest in our work! I can't really see the attention implementation itself here, but you need to register the hook on the attention heads after the softmax operation; see this code for an example. If self_attention_outputs indeed contains the attention heads after the softmax, it should be fine.
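
In case it helps, here is a rough, untested sketch of how you might do this with the stock Hugging Face ViT without editing its source at all. With output_attentions=True the model returns the per-layer post-softmax attention maps, and you can call retain_grad() on them to keep their gradients; the relevance update then follows the rule from our paper (average over heads of the positive part of gradient times attention, then propagate). The checkpoint name and the vit_relevance helper are just placeholders, and whether the returned attention tensors are the exact graph nodes used in the forward pass can depend on your transformers version, so please double-check on your setup:

    import torch
    from transformers import ViTForImageClassification

    model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
    model.eval()

    def vit_relevance(pixel_values, class_index=None):
        # Forward pass; ask the model to return the per-layer attention maps.
        outputs = model(pixel_values, output_attentions=True)
        attentions = outputs.attentions  # tuple of (batch, heads, tokens, tokens)

        # These are non-leaf tensors, so their gradients must be retained explicitly.
        for attn in attentions:
            attn.retain_grad()

        logits = outputs.logits
        if class_index is None:
            class_index = logits[0].argmax().item()
        model.zero_grad()
        logits[0, class_index].backward()

        # Relevance propagation: R starts as the identity and accumulates the
        # head-averaged positive part of (gradient * attention) per layer.
        num_tokens = attentions[0].shape[-1]
        R = torch.eye(num_tokens, device=attentions[0].device)
        for attn in attentions:
            cam = (attn.grad[0] * attn[0].detach()).clamp(min=0).mean(dim=0)
            R = R + torch.matmul(cam, R)

        # Relevance of the [CLS] token with respect to the image patches.
        return R[0, 1:]

For google/vit-base-patch16-224 the returned vector has 196 entries, one per 16x16 patch of the 224x224 input; reshaping it to the 14x14 patch grid and upsampling to the image size gives the heatmap.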

Best, Hila.

hila-chefer commented 2 years ago

@SketchX-QZY closing due to inactivity, please reopen if necessary.