szc19990412 / TransMIL

TransMIL: Transformer based Correlated Multiple Instance Learning for Whole Slide Image Classification

attention weights for heatmaps #2

Closed: gabemarx closed this issue 2 years ago

gabemarx commented 2 years ago

Hello,

Really fantastic work, thank you for the excellent repo.

I want to apply this to my own dataset and would love to produce attention heatmaps like the ones in the paper. I understand that the Nystromformer has a return_attn argument, but I am confused by the dimensions it returns. On a toy dataset with 1000 instances, it returned a 1 x nheads x 1280 x 1280 tensor, and I am not sure how to take that 1280 x 1280 array and tease out the cls_token attention values.

Any advice?

Thanks so much!

gabemarx commented 2 years ago

After some digging I think I may have found my answer, but I would love for you to verify.

There are two separate sources of padding of n (number of instances):

Let's say the input x is a B x n x D tensor.

The first is in your code, which prepares the array for PPEG: add_length = ceil(sqrt(n))^2 - n instances are appended to the end of the stack of vectors.

Then the cls_token is added to the head of the stack.

So h, the input to the attention module, is B x (n + add_length + 1) x D. Let n2 = n + add_length + 1.

The second source of padding is in the Nystromformer, so that the sequence length is evenly divisible by the m (256) landmarks.

If n2 % m > 0, padding = m - (n2 % m); otherwise padding = 0. This padding is added to the front of the stack.
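The two padding steps can be checked against the 1 x nheads x 1280 x 1280 shape reported in the first comment. This is a minimal sketch in plain Python, assuming m = 256 landmarks as stated above; the function names are hypothetical and not part of the TransMIL code.

```python
import math


def ppeg_add_length(n: int) -> int:
    # Pad n up to the next perfect square for PPEG: ceil(sqrt(n))^2 - n.
    side = math.isqrt(n)
    if side * side < n:
        side += 1  # exact integer ceil(sqrt(n))
    return side * side - n


def landmark_padding(n2: int, m: int = 256) -> int:
    # Front padding so the sequence length is divisible by m landmarks.
    return m - (n2 % m) if n2 % m > 0 else 0


def attention_layout(n: int, m: int = 256):
    add_length = ppeg_add_length(n)
    n2 = n + add_length + 1            # instances + PPEG padding + cls_token
    padding = landmark_padding(n2, m)
    total = n2 + padding               # side length of the returned attention matrix
    return add_length, padding, total


print(attention_layout(1000))  # (24, 255, 1280)
```

For n = 1000 this gives add_length = 24, n2 = 1025, padding = 255 and a side length of 1280, which matches the tensor shape from the first comment.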

The attention matrix then has shape B x n_heads x (n + add_length + padding + 1) x (n + add_length + padding + 1). So, to slice out the relevant attention matrix (my feature vectors plus the cls_token) for all attention heads, I would index: attention_matrix[0, :, padding:(padding + n + 1), padding:(padding + n + 1)]

I would love confirmation that this is correct. I would hate to index my attention values incorrectly and produce wrong heatmaps.
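The indexing above can be wrapped in a small helper. This is a sketch using NumPy, assuming the attention tensor has already been converted with attn.detach().cpu().numpy(); the helper name slice_attention is hypothetical. Index 0 of the result is the cls_token, and indices 1..n are the original instances; the PPEG padding that follows them is dropped.

```python
import numpy as np


def slice_attention(attn: np.ndarray, n: int, padding: int) -> np.ndarray:
    """Keep only cls_token + the n real instances for every head of batch item 0.

    attn: array of shape (B, heads, L, L), where L = n + add_length + padding + 1.
    Returns shape (heads, n + 1, n + 1); row/column 0 is the cls_token.
    """
    start, stop = padding, padding + n + 1
    return attn[0, :, start:stop, start:stop]


# Shapes from the first comment: n = 1000 instances, padding = 255, side 1280.
attn = np.zeros((1, 8, 1280, 1280), dtype=np.float32)
print(slice_attention(attn, n=1000, padding=255).shape)  # (8, 1001, 1001)
```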

szc19990412 commented 2 years ago

Yes, take out the attention weight values corresponding to the cls_token.
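One way to read this: in an attention matrix, row i holds query i's weights over all keys, so the cls_token's attention to the instances is the row at index padding, restricted to the columns of the n real instances. A minimal NumPy sketch on the full attention array (converted from the PyTorch tensor), with a hypothetical helper name:

```python
import numpy as np


def cls_attention_scores(attn: np.ndarray, n: int, padding: int) -> np.ndarray:
    """Attention from the cls_token to each real instance, per head.

    Row `padding` is the cls_token query; columns padding+1 .. padding+n are the
    n original instances (landmark padding comes before them, PPEG padding after).
    Returns shape (heads, n).
    """
    cls_idx = padding
    return attn[0, :, cls_idx, cls_idx + 1 : cls_idx + 1 + n]


attn = np.zeros((1, 8, 1280, 1280), dtype=np.float32)
print(cls_attention_scores(attn, n=1000, padding=255).shape)  # (8, 1000)
```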

Ajaz-Ahmad commented 2 years ago

Hi,

I have a very similar question. I am trying to generate the attention map, and the attention matrix I get has shape (1, 8, 5281, 5281), where my slide has 5281 tiles. What does the 8 (the number of heads) indicate, and how should I use it to plot heatmaps? Do I need to merge the scores with the following steps: (1, 5281, 8, 5281), then (1, 5281, 42248)? Thanks in advance.

YunanWu2168 commented 1 year ago

I had the same question. The attention matrix has a dimension for the heads, right? That is, batch size x number of heads x feature map. Does anyone know how we can merge the scores of the 8 heads? Thank you!
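The thread does not settle how to merge the heads. One common choice, assumed here and not confirmed by the TransMIL authors, is to take the per-head cls_token scores of shape (heads, n) and average them (or take the maximum) over the head axis, then min-max scale the result for display. Note that concatenating heads into (1, 5281, 42248) would still leave 42248 values per row rather than one score per tile. The helper name merge_heads is hypothetical.

```python
import numpy as np


def merge_heads(cls_scores: np.ndarray, reduce: str = "mean") -> np.ndarray:
    """Collapse per-head cls_token scores of shape (heads, n) to one score per tile.

    reduce="mean" averages the heads; reduce="max" keeps the strongest head per tile.
    The result is min-max scaled to [0, 1] for use as heatmap values.
    """
    if reduce == "mean":
        merged = cls_scores.mean(axis=0)
    elif reduce == "max":
        merged = cls_scores.max(axis=0)
    else:
        raise ValueError(f"unknown reduce: {reduce!r}")
    lo, hi = merged.min(), merged.max()
    if hi == lo:
        return np.zeros_like(merged)  # constant scores: nothing to contrast
    return (merged - lo) / (hi - lo)
```

Score i belongs to the i-th tile in the order the tiles were fed to the model, so placing it on the slide needs the tile coordinates recorded at patch-extraction time.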

DeVriesMatt commented 1 year ago

@YunanWu2168 did you figure this out?