icon-lab / ResViT

Official Implementation of ResViT: Residual Vision Transformers for Multi-modal Medical Image Synthesis

How to generate attention maps with ResViT? #16

Closed: fiy2W closed this issue 1 year ago

fiy2W commented 1 year ago

Could you please provide code for visualizing the attention maps of ResViT? Attention rollout and attention flow are mentioned in the paper, but it is unclear how to implement them.

onat-dalmaz commented 1 year ago

Hello, thanks for your interest. To visualize attention maps, you can use the snippet below. Here, attn_weights is the list of per-layer attention weights produced by the encoder, so you need to set vis=True when initializing the transformer encoder.

# Runs inside the model; requires torch, torch.nn as nn, and numpy as np.
if self.vis:
    # Attention rollout: average each layer's attention over heads (first sample
    # in the batch) and chain-multiply the per-layer matrices.
    a_tilda = torch.mean(attn_weights[0][0], 0)
    for i in range(len(attn_weights) - 1):
        raw_attention = torch.mean(attn_weights[i + 1][0], 0)
        a_tilda = torch.matmul(raw_attention, a_tilda)
    a_tilda = torch.unsqueeze(a_tilda, 0)
    B, n_patch, hidden = a_tilda.size()  # (1, n_patch, n_patch); hidden == n_patch here
    h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
    a_tilda = a_tilda.permute(0, 2, 1)
    a_tilda = a_tilda.contiguous().view(B, hidden, h, w)
    a_tilda = torch.reshape(a_tilda, (1, 256, 256))
    # Average the attention each patch receives over all query patches,
    # giving one score per patch on the 16x16 grid.
    a_tilda = torch.mean(a_tilda, 2)
    a_tilda = torch.reshape(a_tilda, (1, 1, 16, 16))
    # Normalize the scores to [-1, 1].
    a_tilda_max = a_tilda.max()
    a_tilda_min = a_tilda.min()
    a_tilda -= (a_tilda_max + a_tilda_min) / 2
    a_tilda *= 2 / (a_tilda_max - a_tilda_min)
    # Upsample the 16x16 patch grid to the 256x256 image resolution.
    m = nn.Upsample(scale_factor=16, mode='bilinear')
    a_tilda = m(a_tilda)
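
For reference, the same rollout logic can be packaged as a self-contained function that operates directly on the attn_weights list. This is only a minimal sketch: the function name attention_rollout_map and the dummy shapes in the example are illustrative, not part of the ResViT code.

import torch
import torch.nn.functional as F


def attention_rollout_map(attn_weights):
    """Collapse a list of per-layer attention tensors into a single heatmap.

    attn_weights: list of length n_layers, each tensor shaped
                  (B, n_heads, n_patch, n_patch) -- the attention probabilities
                  returned by the encoder when vis=True.
    Returns a (1, 1, 16 * h, 16 * w) tensor normalized to [-1, 1].
    """
    # Attention rollout: average over heads (first sample in the batch) and
    # chain-multiply the per-layer attention matrices.
    rollout = attn_weights[0][0].mean(dim=0)           # (n_patch, n_patch)
    for layer_attn in attn_weights[1:]:
        rollout = layer_attn[0].mean(dim=0) @ rollout  # propagate through the next layer

    # Average the attention each patch receives over all query patches.
    n_patch = rollout.size(0)
    h = w = int(n_patch ** 0.5)
    patch_scores = rollout.mean(dim=0).reshape(1, 1, h, w)

    # Rescale to [-1, 1], as in the snippet above.
    s_min, s_max = patch_scores.min(), patch_scores.max()
    patch_scores = (patch_scores - (s_max + s_min) / 2) * (2 / (s_max - s_min))

    # Upsample the patch grid back to image resolution.
    return F.interpolate(patch_scores, scale_factor=16, mode='bilinear',
                         align_corners=False)


# Example with dummy data: 12 layers, batch of 1, 8 heads, 16x16 = 256 patches.
dummy = [torch.softmax(torch.rand(1, 8, 256, 256), dim=-1) for _ in range(12)]
heatmap = attention_rollout_map(dummy)
print(heatmap.shape)  # torch.Size([1, 1, 256, 256])

If you want the exact rollout formulation of Abnar and Zuidema (2020), you could additionally mix an identity matrix into each layer's attention (to account for residual connections) before multiplying; the snippet above omits that step.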

Cheers,