lukemelas / PyTorch-Pretrained-ViT

Vision Transformer (ViT) in PyTorch

Visualizing attention map #19

Open · ParnianA opened this issue 3 years ago

ParnianA commented 3 years ago

Hi. Does anyone know how we can access the attention maps?

tolaut commented 3 years ago

I'm trying to figure out the same thing.

gouttham commented 3 years ago

Using the code below, I was able to visualize the attention maps.

Step 1: In transformer.py, under class MultiHeadedSelfAttention(nn.Module), replace the forward method with the code below:

def forward(self, x, mask):
    """
    x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
    mask : (B(batch_size) x S(seq_len))
    * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
    """
    # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
    q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
    q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
    # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
    scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
    if mask is not None:
        mask = mask[:, None, None, :].float()
        scores -= 10000.0 * (1.0 - mask)
    scores = self.drop(F.softmax(scores, dim=-1))
    # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
    h = (scores @ v).transpose(1, 2).contiguous()
    # -merge-> (B, S, D)
    h = merge_last(h, 2)
    self.scores = scores  # cache the attention weights so they can be read out after the forward pass
    return h
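With this change each attention module caches its own weights, so they can be read back from any block after a forward pass. A quick sanity check might look like this (a sketch, not from the repo; the attribute path follows the class names used in Steps 1-3):

# assumes `model` is this repo's ViT and a forward pass has already run
scores = model.transformer.blocks[-1].attn.scores  # (B, H, S, S)
print(scores.shape)  # e.g. (1, 12, 197, 197): 12 heads, 196 patches + 1 class token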

Step 2: In transformer.py, under class Transformer(nn.Module), replace the forward method with the code below:

def forward(self, x, mask=None):
    atten_scores = []
    for block in self.blocks:
        x = block(x, mask)
        atten_scores.append(block.attn.scores)  # (B, H, S, S), cached in Step 1
    return x, atten_scores
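If editing the library feels too invasive, the same per-layer scores can also be collected with forward hooks instead of changing Transformer.forward. This is only a sketch, and it still relies on the self.scores attribute cached in Step 1:

atten_scores = []
hooks = [blk.attn.register_forward_hook(lambda module, inp, out: atten_scores.append(module.scores))
         for blk in model.transformer.blocks]
with torch.no_grad():
    model(input_image.unsqueeze(0))  # populates atten_scores, one entry per block
for h in hooks:
    h.remove()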

Step 3: In model.py, under class ViT(nn.Module), replace the forward method with the code below:

def forward(self, x):
    b, c, fh, fw = x.shape
    x = self.patch_embedding(x)  # b,d,gh,gw
    x = x.flatten(2).transpose(1, 2)  # b,gh*gw,d
    if hasattr(self, 'class_token'):
        x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1)  # b,gh*gw+1,d
    if hasattr(self, 'positional_embedding'):
        x = self.positional_embedding(x)  # b,gh*gw+1,d
    x, atten_scores = self.transformer(x)  # b,gh*gw+1,d
    att_mat = torch.stack(atten_scores).squeeze(1)  # (layers, heads, S, S); squeeze assumes batch size 1
    att_mat = torch.mean(att_mat, dim=1)  # average over heads -> (layers, S, S)
    if hasattr(self, 'pre_logits'):
        x = self.pre_logits(x)
        x = torch.tanh(x)
    if hasattr(self, 'fc'):
        x = self.norm(x)[:, 0]  # b,d
        x = self.fc(x)  # b,num_classes
    return x, att_mat
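Note that squeeze(1) only removes the batch dimension when the batch size is 1. For larger batches, a sketch like the following keeps one set of maps per image (b is an illustrative index, not a repo variable):

att_mat = torch.stack(atten_scores)   # (layers, B, heads, S, S)
att_mat = att_mat.mean(dim=2)         # average over heads -> (layers, B, S, S)
att_mat_b = att_mat[:, b]             # (layers, S, S) for the b-th image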

Step 4: The forward pass now returns both the output of the MLP head and the attention maps:

x, atten_weights = model(input_image.unsqueeze(0))

Here atten_weights contains the per-layer attention maps.
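Putting Steps 1-4 together, a minimal end-to-end sketch (the checkpoint name and preprocessing follow this repo's README; img_pth is your own image path):

import torch
from PIL import Image
from torchvision import transforms
from pytorch_pretrained_vit import ViT

model = ViT('B_16_imagenet1k', pretrained=True).eval()
tfms = transforms.Compose([
    transforms.Resize((384, 384)),  # B_16_imagenet1k was fine-tuned at 384x384
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])
input_image = tfms(Image.open(img_pth).convert('RGB'))
with torch.no_grad():
    logits, atten_weights = model(input_image.unsqueeze(0))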

Step 5: Iterate over the attention maps and visualize one per layer. The loop below also accumulates the attention rollout, i.e. the per-layer attention matrices (with residual connections folded in) multiplied together, so later layers show attention propagated all the way back to the input:

import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

im = Image.open(img_pth)

joint_attention = None
for i, att_mat in enumerate(atten_weights):
    # att_mat: (S, S) attention for one layer, already averaged over heads
    # Account for residual connections, then re-normalize the rows
    residual_att = torch.eye(att_mat.size(1))
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)
    # Attention rollout: multiply the per-layer matrices accumulated so far
    if joint_attention is None:
        joint_attention = aug_att_mat
    else:
        joint_attention = torch.matmul(aug_att_mat, joint_attention)
    # Attention from the class token (row 0) to every patch token
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    mask = joint_attention[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
    result = (mask * im).astype("uint8")
    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
    ax1.set_title('Original')
    ax2.set_title('Attention Map (layer %d)' % (i + 1))
    _ = ax1.imshow(im)
    _ = ax2.imshow(result)
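As a variation (plain matplotlib, nothing repo-specific), the mask can also be drawn as a semi-transparent heatmap over the original image instead of being multiplied into it:

_ = ax2.imshow(im)
_ = ax2.imshow(mask.squeeze(), cmap='jet', alpha=0.5)  # attention as a translucent heatmap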
kiashann commented 2 years ago

Could you please share the final code or a Colab demo for extracting the attention maps, @gouttham?

IJS1016 commented 2 years ago

> Could you please share the final code or a Colab demo for extracting the attention maps, @gouttham?

https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb