Hello, thanks for your interest. To visualize attention maps, you can use the snippet below. Here, attn_weights is the per-layer attention weights returned by the transformer encoder, so you need to set vis=True when initializing the encoder.
import numpy as np
import torch
import torch.nn as nn

# Inside the model's forward pass, after the transformer encoder has returned attn_weights.
if self.vis:
    # Attention rollout: multiply the head-averaged attention matrices across layers.
    # attn_weights[i][0] has shape (heads, n_tokens, n_tokens) for layer i.
    a_tilda = torch.mean(attn_weights[0][0], 0)
    for i in range(len(attn_weights) - 1):
        raw_attention = torch.mean(attn_weights[i + 1][0], 0)
        a_tilda = torch.matmul(raw_attention, a_tilda)
    a_tilda = torch.unsqueeze(a_tilda, 0)

    # a_tilda is now the rolled-out (n_tokens x n_tokens) attention matrix with a batch dim of 1.
    B, n_patch, hidden = a_tilda.size()
    h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
    a_tilda = a_tilda.permute(0, 2, 1)
    a_tilda = a_tilda.contiguous().view(B, hidden, h, w)

    # Average over the query dimension to get one score per token,
    # then arrange the 256 tokens on the 16 x 16 patch grid.
    a_tilda = torch.reshape(a_tilda, (1, 256, 256))
    a_tilda = torch.mean(a_tilda, 2)
    a_tilda = torch.reshape(a_tilda, (1, 1, 16, 16))

    # Normalize the map to the range [-1, 1].
    a_tilda_max = a_tilda.max()
    a_tilda_min = a_tilda.min()
    a_tilda -= (a_tilda_max + a_tilda_min) / 2
    a_tilda *= 2 / (a_tilda_max - a_tilda_min)

    # Upsample the 16 x 16 map to the 256 x 256 image resolution.
    m = nn.Upsample(scale_factor=16, mode='bilinear')
    a_tilda = m(a_tilda)
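If you then want to look at the map, something along these lines works (a small sketch, not part of the repository; input_slice is a placeholder name for the 256 x 256 input image fed to the network):

import matplotlib.pyplot as plt

# a_tilda has shape (1, 1, 256, 256) after upsampling; squeeze to a 2-D map
attn_map = a_tilda.squeeze().detach().cpu().numpy()

plt.figure(figsize=(6, 6))
plt.imshow(input_slice, cmap='gray')          # input_slice: placeholder for the source image
plt.imshow(attn_map, cmap='jet', alpha=0.4)   # overlay the attention heatmap
plt.axis('off')
plt.savefig('attention_map.png', bbox_inches='tight')

Because imshow rescales to the data range by default, the [-1, 1] normalization above needs no special handling here.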
Cheers,
Could you please provide code for ResViT to visualize the attention maps? Attention rollout and attention flow are mentioned in the paper, but it is unclear how to realize them.
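For reference, attention rollout in the sense of Abnar & Zuidema (2020) is usually computed by averaging the attention over heads, adding the identity matrix to account for residual connections, re-normalizing the rows, and multiplying the resulting matrices across layers. Below is a minimal, self-contained sketch of that recipe; it assumes attn_weights is a list with one (B, heads, n_tokens, n_tokens) tensor per layer, as returned with vis=True, and is not taken from the ResViT code:

import torch

def attention_rollout(attn_weights):
    # attn_weights: list of per-layer tensors of shape (B, heads, n_tokens, n_tokens)
    rollout = None
    for layer_attn in attn_weights:
        attn = layer_attn.mean(dim=1)                       # average over heads -> (B, n, n)
        eye = torch.eye(attn.size(-1), device=attn.device)  # identity for the residual connection
        attn = 0.5 * attn + 0.5 * eye
        attn = attn / attn.sum(dim=-1, keepdim=True)        # re-normalize rows
        rollout = attn if rollout is None else torch.matmul(attn, rollout)
    return rollout  # rolled-out attention, shape (B, n_tokens, n_tokens)

The rolled-out matrix can then be reduced to a per-token map and upsampled exactly as in the snippet above. Attention flow is more involved (it is formulated as a max-flow problem over the attention graph) and is not sketched here.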