tczhangzhi / VisionTransformer-Pytorch

Apache License 2.0

How to visualize attention map #1

Open piantic opened 3 years ago

piantic commented 3 years ago

Hi,

I want to visualize the attention map. I found https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb

In this repo, I did not find a vis option for the attention map. (If there is one, please let me know and I'd appreciate it.)

So I decided to add this to model.py, like this:

# In VisionTransformer
def forward(self, x):
    feat, attn_weights = self.extract_features(x)

    # classifier
    logits = self.classifier(feat[:, 0])
    return logits, attn_weights
# In Encoder
def forward(self, x):
    attn_weights = []
    out = self.pos_embedding(x)

    for layer in self.encoder_layers:
        out, weights = layer(out)
        attn_weights.append(weights)

    out = self.norm(out)
    return out, attn_weights
# In SelfAttention
def forward(self, x):
    b, n, _ = x.shape

    q = self.query(x, dims=([2], [0]))
    k = self.key(x, dims=([2], [0]))
    v = self.value(x, dims=([2], [0]))

    q = q.permute(0, 2, 1, 3)
    k = k.permute(0, 2, 1, 3)
    v = v.permute(0, 2, 1, 3)

    attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale
    attn_weights = F.softmax(attn_weights, dim=-1)
    out = torch.matmul(attn_weights, v)
    out = out.permute(0, 2, 1, 3)

    out = self.out(out, dims=([2, 3], [0, 1]))

    return out, attn_weights
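
To turn these per-layer weights into a single heat map over the input image, the linked notebook uses attention rollout: average the heads, add the identity for the residual connections, multiply the layer matrices together, and read off the class-token row. A minimal sketch of that step (an illustration, not code from this repo), assuming attn_weights is the list returned above with entries of shape (1, heads, N, N) for a batch of one image, the class token at index 0, and img an H x W x 3 numpy image:

import numpy as np
import torch
import cv2  # one possible choice for resizing/overlaying

att_mat = torch.stack([w.squeeze(0).detach().cpu() for w in attn_weights])  # (layers, heads, N, N)
att_mat = att_mat.mean(dim=1)                                               # average over heads -> (layers, N, N)

# Add the identity to account for residual connections, then re-normalize rows.
aug_att_mat = att_mat + torch.eye(att_mat.size(1))
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)

# Multiply the per-layer matrices to propagate attention from output to input.
joint_att = aug_att_mat[0]
for layer_att in aug_att_mat[1:]:
    joint_att = torch.matmul(layer_att, joint_att)

# Class-token attention over the patch tokens, reshaped to the patch grid and upsampled.
grid = int(np.sqrt(joint_att.size(-1) - 1))
mask = joint_att[0, 1:].reshape(grid, grid).numpy()
mask = cv2.resize(mask / mask.max(), (img.shape[1], img.shape[0]))
overlay = (mask[..., np.newaxis] * img).astype(np.uint8)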

And I got the result.

image

But I don't know whether it is right or not, because the attention map from the link above looks quite different from mine. (I used the pretrained weights from here.)

image

I am not sure if my results are correct. I would be happy to hear your answer.

Thanks.

tczhangzhi commented 3 years ago

Looks good to me, but one thing you should pay attention to is that vit-model-1 is fine-tuned on the cassava-leaf-disease-classification task, so you would want to visualize an image from that dataset. It is quite different from object classification and focuses on the low-level texture of the input leaf. To visualize the attention map of a dog, you can use the pre-trained models here.

Anyway, it is a good first try. I'm still hesitant about the way the attention map is extracted, since I don't want it to affect the inference process, that is, to modify the forward function. Maybe later I will check some best practices around hooks. If you are willing to, you can open a PR with your implementation.
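
For reference, a minimal sketch of the hook-based idea (not an official API of this repo): the hook recomputes the softmaxed attention weights from each block's input, so the existing forward function stays untouched. Here model, x, and the import path are assumptions based on this repo's layout, and the SelfAttention module is assumed to expose query, key, and scale as in model.py.

import torch
import torch.nn.functional as F
from vision_transformer_pytorch.model import SelfAttention  # assumed import path

attn_maps = []

def attention_hook(module, inputs, output):
    h = inputs[0]                                              # (B, N, C) input to SelfAttention
    q = module.query(h, dims=([2], [0])).permute(0, 2, 1, 3)   # (B, heads, N, head_dim)
    k = module.key(h, dims=([2], [0])).permute(0, 2, 1, 3)
    weights = torch.matmul(q, k.transpose(-2, -1)) / module.scale
    attn_maps.append(F.softmax(weights, dim=-1).detach().cpu())

handles = [m.register_forward_hook(attention_hook)
           for m in model.modules() if isinstance(m, SelfAttention)]

with torch.no_grad():
    logits = model(x)          # unmodified forward pass

for handle in handles:
    handle.remove()
# attn_maps[i] now holds the (B, heads, N, N) weights of encoder layer i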

piantic commented 3 years ago

Thanks for the answer. I used the pre-trained models you recommended.

Here is result for a dog.

image

The original attention map from the repo https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb is below:

image

It seems to be capturing something, but I'm not sure. What do you think about this part?

If you think this part is okay, it seems that a simple vis flag could minimize the influence on inference.

If vis is False, the model runs the original forward function.
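
A rough sketch of what that flag could look like in VisionTransformer.forward (just an illustration of the idea, assuming extract_features already returns the collected weights as above):

# In VisionTransformer (sketch only): the default call keeps the original
# return signature; attention weights are returned only when vis=True.
def forward(self, x, vis=False):
    feat, attn_weights = self.extract_features(x)
    logits = self.classifier(feat[:, 0])
    if vis:
        return logits, attn_weights
    return logits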

cifkao commented 2 years ago

For people still looking for a solution, my package NoPdb allows capturing attention weights from pretty much any Transformer implementation without any modifications to the code. See a Colab notebook showing how to do this for ViT (a different implementation).

In this case, it would be something like:

import nopdb

with nopdb.capture_calls(SelfAttention.forward) as calls:
    logits = model(x)

calls[0].locals["attn_weights"]  # attention weights of the first layer

Suryanshg commented 2 years ago

Hi, when I try to implement the changes by @piantic, this is the error I am getting:

Traceback (most recent call last):
  File "C:\Users\Surya\Desktop\Automatic-Pain-Estimation-MQP\scripts\Visualize_Attention_Map.py", line 96, in <module>
    result_img = get_attention_map(viz_image)
  File "C:\Users\Surya\Desktop\Automatic-Pain-Estimation-MQP\scripts\Visualize_Attention_Map.py", line 25, in get_attention_map
    logits, att_mat = model(x.unsqueeze(0))
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\vision_transformer_pytorch\model.py", line 269, in forward
    feat, attn_weights = self.extract_features(x)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\vision_transformer_pytorch\model.py", line 265, in extract_features
    feat = self.transformer(emb)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\vision_transformer_pytorch\model.py", line 177, in forward
    out, weights = layer(out)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\vision_transformer_pytorch\model.py", line 139, in forward
    out = self.dropout(out)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\dropout.py", line 58, in forward
    return F.dropout(input, self.p, self.training, self.inplace)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\functional.py", line 1169, in dropout
    return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
TypeError: dropout(): argument 'input' (position 1) must be Tensor, not tuple

Is there anything else I need to do? I feel that there might be some change that needs to be made in the EncoderBlock part of the model.py file.
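
For what it's worth, that traceback does point at EncoderBlock: after the changes above, self.attn returns a tuple, so EncoderBlock.forward has to unpack it before the dropout and pass the weights through. A sketch, assuming the block structure in this repo's model.py:

# In EncoderBlock (sketch): self.attn now returns (out, weights), so unpack
# it before dropout and return the weights alongside the block output.
def forward(self, x):
    residual = x
    out = self.norm1(x)
    out, weights = self.attn(out)      # was: out = self.attn(out)
    if self.dropout:
        out = self.dropout(out)
    out += residual
    residual = out

    out = self.norm2(out)
    out = self.mlp(out)
    out += residual
    return out, weights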

piantic commented 2 years ago

Hi, @Suryanshg.

This is my example notebook for visualizing the attention map using this repo: https://www.kaggle.com/code/piantic/vision-transformer-vit-visualize-attention-map/notebook

And you can see a visualized version of ViT at the link below: https://www.kaggle.com/datasets/piantic/visiontransformerpytorch121

I hope this helps you. Thanks.