hila-chefer / Transformer-MM-Explainability

[ICCV 2021- Oral] Official PyTorch implementation for Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers, a novel method to visualize any Transformer-based network. Including examples for DETR, VQA.

Is this really using the technique from the publication? #31

Closed: entrity closed this issue 1 year ago

entrity commented 1 year ago

The top of this repo's README links to the paper [ICCV 2021- Oral] PyTorch Implementation of Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers, but I'm looking at the CLIP_explainability.ipynb notebook, and it appears to me that it does not demonstrate the technique introduced in the paper. Have I missed something? Or should this notebook be updated?

Each of the examples in the notebook does no more than produce heatmaps from the output of the helper function interpret, which uses only self-attention, whereas the publication computes self-attention relevancy with contributions from co-attention (cf. equation 11).

Here's a relevant excerpt from interpret, with comments indicating the correspondence between the Python code and the publication. Equation 11 is absent.

    R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device) # eq 1: self-attn Relevancy map
    R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
    for i, blk in enumerate(image_attn_blocks):
        if i < start_layer:
            continue
        grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
        cam = blk.attn_probs.detach() # A (attention map)
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam # A-bar, eq 5
        cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
        cam = cam.clamp(min=0).mean(dim=1)
        R = R + torch.bmm(cam, R) # eq 6. It's not eq 7 b/c 7 starts from an R which is zeros, whereas this starts from an R which is identity.
hila-chefer commented 1 year ago

Hi @entrity, yes, it really does apply the technique from the paper. As mentioned in the paper, we support all attention-based architectures, including pure self-attention models (ViT or CLIP). For pure self-attention models the co-attention rules are not required; however, this is not a simple visualization of the attention maps. We weight the attention by its gradients in order to average across the attention heads, and you can also control the number of layers you wish to propagate back from.

Overall, the answer is yes: this is the method from the paper, applied to a pure self-attention model. We have similar experiments with ViT in the paper, and a notebook for that as well.
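
To make this concrete, here is a minimal, self-contained sketch of the propagation rule, with random tensors and made-up sizes standing in for the attention maps and gradients that the notebook collects from CLIP's blocks (blk.attn_probs and torch.autograd.grad):

    import torch

    torch.manual_seed(0)

    # Toy sizes only; made up for illustration, not the notebook's real dimensions.
    num_layers, num_heads, num_tokens = 4, 8, 50

    # Stand-ins for the hooked attention maps A and their gradients w.r.t. the
    # chosen logit (in the notebook: blk.attn_probs and torch.autograd.grad).
    attn = [torch.rand(num_heads, num_tokens, num_tokens).softmax(dim=-1)
            for _ in range(num_layers)]
    grads = [torch.randn(num_heads, num_tokens, num_tokens)
             for _ in range(num_layers)]

    R = torch.eye(num_tokens)                     # relevancy starts as identity
    for A, G in zip(attn, grads):
        A_bar = (G * A).clamp(min=0).mean(dim=0)  # gradient-weighted head average (eq. 5)
        R = R + A_bar @ R                         # relevancy propagation (eq. 6)

    # Relevance of every patch token w.r.t. the CLS token (first row, CLS excluded).
    patch_relevance = R[0, 1:]
    print(patch_relevance.shape)                  # torch.Size([49])

Plain attention rollout would average the heads directly; weighting the attention by its gradient before averaging is what ties the relevancy map to the chosen output.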