jacobgil / vit-explain

Explainability for Vision Transformers

normalize (sum to 1) attention score seems not right #16

Open jihwanp opened 2 years ago

jihwanp commented 2 years ago

Hi, thanks for sharing this nice work.

I noticed that you normalize the attention scores (each row sums to 1), as described in the original attention rollout paper:

I = torch.eye(attention_heads_fused.size(-1))
a = (attention_heads_fused + 1.0*I)/2
a = a / a.sum(dim=-1)

But when dividing by the row sums, it seems keepdim=True should be applied to ensure that each row actually sums to 1 after normalization. Without it, a.sum(dim=-1) has shape (N,) and broadcasts along the last dimension, so entry (i, j) gets divided by row j's sum instead of row i's:

a = a / a.sum(dim=-1, keepdim=True)
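
A quick check on a toy 3x3 tensor (hypothetical values, not from the repo) shows the difference:

import torch

# Toy matrix with distinct row sums, standing in for the fused attention.
a = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])

wrong = a / a.sum(dim=-1)                # shape (3,) broadcasts over the last dim
right = a / a.sum(dim=-1, keepdim=True)  # shape (3, 1) divides each row by its own sum

print(wrong.sum(dim=-1))  # tensor([0.4250, 1.2500, 2.0750]) -- rows do not sum to 1
print(right.sum(dim=-1))  # tensor([1., 1., 1.])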

Maybe I'm wrong; please double-check this issue. Thanks!

vivekh2000 commented 6 months ago

@jacobgil, thanks for the code. I think the following line in the code is redundant: https://github.com/jacobgil/vit-explain/blob/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/vit_rollout.py#L31

Reason: I have attached a screenshot of the original paper below, from page 3.

[screenshot of the relevant passage from page 3 of the attention rollout paper]

Here, the authors say that the W_attn matrix is already normalized. The identity matrix I is also normalized (each row sums to one), so the average 0.5 * (W_attn + I) is again a normalized matrix, and the extra division changes nothing.
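
A small sanity check of this point (toy row-stochastic matrix, not from the repo):

import torch

torch.manual_seed(0)
attn = torch.softmax(torch.randn(4, 4), dim=-1)  # row-stochastic stand-in for fused attention

I = torch.eye(4)
a = (attn + 1.0 * I) / 2

print(attn.sum(dim=-1))  # tensor([1., 1., 1., 1.])
print(a.sum(dim=-1))     # tensor([1., 1., 1., 1.]) -- still normalized

(This only holds when the fused matrix really is row-stochastic; mean fusion over heads preserves that, but max/min fusion or a discard step would not.)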

Also, at line 10, result = torch.eye(attentions[0].size(-1)) https://github.com/jacobgil/vit-explain/blob/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/vit_rollout.py#L10 result is initialized as an identity matrix, while at line 33, result = torch.matmul(a, result) https://github.com/jacobgil/vit-explain/blob/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/vit_rollout.py#L33 the a matrix is multiplied with result (an identity matrix), which should always just give a back. Further, the recursive multiplication described in the original paper does not seem to be implemented (see the sketch below). Anyway, thanks for the nice implementation of these techniques.
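
For reference, a minimal sketch of the recursion as the paper describes it (shapes assumed, with the keepdim fix from above folded in):

import torch

def attention_rollout(attentions):
    # `attentions` is assumed to be a list of (tokens, tokens) matrices,
    # one per layer, already fused over heads.
    result = torch.eye(attentions[0].size(-1))
    for attn in attentions:
        a = (attn + torch.eye(attn.size(-1))) / 2  # account for residual connections
        a = a / a.sum(dim=-1, keepdim=True)        # keep rows summing to 1
        result = torch.matmul(a, result)           # identity only on the first step
    return result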

eneserdo commented 4 months ago

@vivekh2000 I did not check the paper, but @jacobgil also implemented the discard_ratio, which may not be in the paper. Discarding values obviously breaks the normalization of the matrix, so it is necessary to re-normalize it. Also, I agree with @jihwanp: there should be a keepdim=True.
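
A toy illustration of why the discard step forces a re-normalization (hypothetical 4x4 matrix; the repo's code additionally protects the class-token entry):

import torch

torch.manual_seed(0)
a = torch.softmax(torch.randn(4, 4), dim=-1)  # row-stochastic toy attention

# Mimic discard_ratio=0.25: zero out the smallest 25% of entries.
flat = a.view(-1)
_, idx = flat.topk(int(flat.numel() * 0.25), largest=False)
flat[idx] = 0  # in-place, so `a` changes too

print(a.sum(dim=-1))                 # some rows no longer sum to 1
a = a / a.sum(dim=-1, keepdim=True)  # hence the re-normalization
print(a.sum(dim=-1))                 # tensor([1., 1., 1., 1.])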

gbZachYin commented 1 month ago

keepdim=True should be correct.