ParnianA opened 3 years ago
I'm trying to figure out the same thing
Using the code below, I was able to visualize the attention maps.
Step 1:
In transformer.py, under `class MultiHeadedSelfAttention(nn.Module)`, replace the forward method with the code below:
```python
def forward(self, x, mask):
    """
    x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
    mask : (B(batch_size) x S(seq_len))
    * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
    """
    # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
    q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
    q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
    # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
    scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
    if mask is not None:
        mask = mask[:, None, None, :].float()
        scores -= 10000.0 * (1.0 - mask)
    scores = self.drop(F.softmax(scores, dim=-1))
    # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
    h = (scores @ v).transpose(1, 2).contiguous()
    # -merge-> (B, S, D)
    h = merge_last(h, 2)
    # save the post-softmax attention weights so they can be read back later
    self.scores = scores
    return h
```
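After this change every attention module keeps its latest post-softmax weights in `self.scores`. Once the remaining steps are applied, a quick sanity check could look like the sketch below (my addition, not from the original post; `input_image` is a preprocessed `(3, H, W)` tensor as in Step 4, and the attribute path `model.transformer.blocks[i].attn` follows this repo's layout):

```python
model.eval()  # dropout is a no-op in eval mode, so each attention row sums to 1
with torch.no_grad():
    model(input_image.unsqueeze(0))

scores = model.transformer.blocks[0].attn.scores  # (B, H, S, S)
row_sums = scores.sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-4)
```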
Step 2:
In transformer.py, under `class Transformer(nn.Module)`, replace the forward method with the code below:
```python
def forward(self, x, mask=None):
    atten_scores = []
    for block in self.blocks:
        x = block(x, mask)
        atten_scores.append(block.attn.scores)
    return x, atten_scores
```
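One caveat worth flagging (my observation, not part of the original post): `Transformer.forward` now returns a tuple, so any other call site written against the old single-return signature must be updated to unpack it:

```python
# Before: x = self.transformer(x)   -- this would now bind a tuple to x
x, atten_scores = self.transformer(x)
```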
Step 3:
In model.py, under `class ViT(nn.Module)`, replace the forward method with the code below:
```python
def forward(self, x):
    b, c, fh, fw = x.shape
    x = self.patch_embedding(x)       # b,d,gh,gw
    x = x.flatten(2).transpose(1, 2)  # b,gh*gw,d
    if hasattr(self, 'class_token'):
        x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1)  # b,gh*gw+1,d
    if hasattr(self, 'positional_embedding'):
        x = self.positional_embedding(x)  # b,gh*gw+1,d
    x, atten_scores = self.transformer(x)  # b,gh*gw+1,d
    # stack the per-layer scores, drop the batch dim (assumes batch size 1),
    # and average over heads: (num_layers, S, S)
    att_mat = torch.stack(atten_scores).squeeze(1)
    att_mat = torch.mean(att_mat, dim=1)
    if hasattr(self, 'pre_logits'):
        x = self.pre_logits(x)
        x = torch.tanh(x)
    if hasattr(self, 'fc'):
        x = self.norm(x)[:, 0]  # b,d
        x = self.fc(x)          # b,num_classes
    return x, att_mat
```
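To make the tensor bookkeeping concrete, here is how the shapes work out for a hypothetical ViT-B/16 at 384×384 input (12 layers, 12 heads, batch size 1; these numbers are assumptions and depend on your checkpoint). Note that `.squeeze(1)` silently assumes a batch size of 1:

```python
patch_size = 16
num_patches = (384 // patch_size) ** 2  # 576 patches
S = num_patches + 1                     # 577 tokens including the class token
# atten_scores:              list of 12 tensors, each (1, 12, 577, 577)
# torch.stack(atten_scores): (12, 1, 12, 577, 577)
# .squeeze(1):               (12, 12, 577, 577)  -- batch dim removed
# .mean(dim=1):              (12, 577, 577)      -- averaged over heads
```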
Step 4:
The forward pass now returns both the classifier output and the per-layer attention maps:

```python
x, atten_weights = model.forward(input_image.unsqueeze(0))
```

Here `atten_weights` holds the attention maps, one S×S matrix per layer with heads averaged.
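For completeness, a minimal sketch of how `input_image` might be prepared (the resolution and normalization below are assumptions; match whatever transform your checkpoint was trained with, and `example.jpg` is a hypothetical path):

```python
import torch
from PIL import Image
from torchvision import transforms

img_pth = 'example.jpg'  # hypothetical path

preprocess = transforms.Compose([
    transforms.Resize((384, 384)),  # assumption: the model's input resolution
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # assumption
])
input_image = preprocess(Image.open(img_pth).convert('RGB'))  # (3, 384, 384)

model.eval()  # disable dropout so the stored attention maps are deterministic
with torch.no_grad():
    x, atten_weights = model(input_image.unsqueeze(0))
```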
Step 5: Iterate through the attention maps and visualize each one:
```python
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image

im = Image.open(img_pth)

# atten_weights: (num_layers, S, S), heads already averaged.
# Add the identity to account for residual connections, then renormalize.
att_mat = atten_weights
residual_att = torch.eye(att_mat.size(1))
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

# Multiply the matrices layer by layer to propagate attention to the input.
joint_attentions = torch.zeros(aug_att_mat.size())
joint_attentions[0] = aug_att_mat[0]
for n in range(1, aug_att_mat.size(0)):
    joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

grid_size = int(np.sqrt(aug_att_mat.size(-1)))
for i, v in enumerate(joint_attentions):
    # Attention from the class token to each patch, reshaped to the patch grid.
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
    result = (mask * im).astype("uint8")

    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
    ax1.set_title('Original')
    ax2.set_title('Attention Map (layer %d)' % (i + 1))
    _ = ax1.imshow(im)
    _ = ax2.imshow(result)
```
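For context, the cumulative matrix product above is the attention rollout method of Abnar & Zuidema ("Quantifying Attention Flow in Transformers"), which the notebook linked below also uses: adding the identity accounts for the residual connections, renormalizing keeps each row a probability distribution, and multiplying the per-layer matrices propagates attention back to the input patches. `joint_attentions[0]` is just the first layer, while `joint_attentions[-1]` shows where the class token effectively attends after the full stack.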
Could you please share the final code or a Colab demo for extracting the attention maps? @gouttham
https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb
Hi. Does anyone know how we can get access to the attention maps?