facebookresearch / mae

PyTorch implementation of MAE https://arxiv.org/abs/2111.06377
Other
6.93k stars 1.17k forks source link

visualization attention map. #187

Open kimsekeun opened 5 months ago

kimsekeun commented 5 months ago

I am wondering how you visualized the attention in the final ViT.

In my opinion,

Given an input x, y = forward_encoder(x), and then y2 = forward_decoder(y). In this step, did you use x1 or x2?

Apply Transformer blocks: `for blk in self.decoder_blocks: x = blk(x)`, then `x1 = self.decoder_norm(x)`

Predictor projection: `x2 = self.decoder_pred(x1)`

---code---

main

def vis_attention(idx, batch, model, num_heads=16, grid_size=(8, 14, 14), head=0):
    """Save and show one heat map per temporal token slice of the decoder output.

    NOTE(review): the tensor obtained from ``model.get_last_selfattention`` is
    treated here as per-token features of shape ``[1, T*H*W, C]``; the maps are
    channel sums of one "head" slice of those features, not softmax attention
    weights. Confirm this is the quantity you want to visualise.

    Args:
        idx: Sample identifier; only used to build the output file name.
        batch: Mapping with an ``"image"`` entry. Assumed to be a single clip
            laid out as (C, T, H, W) -- TODO confirm against the data loader.
        model: Object exposing ``get_last_selfattention(x)``.
        num_heads: Number of equal chunks the channel axis is split into
            (previously hard-coded as 16).
        grid_size: Token grid as ``(T, H, W)``; the default reproduces the
            previously hard-coded 8 * 14 * 14 = 1568 tokens.
        head: Index of the channel chunk to visualise (previously always 0).
    """
    x = batch["image"].unsqueeze(0)  # add batch axis -> ncthw

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        tokens = model.get_last_selfattention(x.cuda())

    t_grid, h_grid, w_grid = grid_size
    num_tokens = t_grid * h_grid * w_grid
    channels = tokens.shape[2]
    dim = channels // num_heads

    # Split the channel axis of every token into heads. Viewing the flat
    # buffer as (num_heads, tokens, dim) would instead mix tokens and channels.
    per_head = tokens[0].reshape(num_tokens, num_heads, dim)

    # Tokens are ordered time-major (T, then H, then W), matching how the
    # decoder views its sequence as [N, T * H * W, C].
    head_feature = per_head[:, head, :].reshape(t_grid, h_grid, w_grid, dim)

    # sum along the last dimension to get one scalar per token
    attention_map = head_feature.sum(dim=-1)

    # normalize to [0, 1]; the clamp avoids 0/0 when the map is constant
    lowest = attention_map.min()
    value_range = torch.clamp(attention_map.max() - lowest, min=1e-12)
    attention_map = (attention_map - lowest) / value_range

    num_frames = x.shape[2]
    out_size = tuple(x.shape[-2:])  # upsample to the input resolution

    # visualize the map for each temporal token slice
    for frame_idx in range(t_grid):
        heat = F.interpolate(
            attention_map[frame_idx][None, None],
            size=out_size,
            mode="bilinear",
        )[0, 0].detach().cpu().numpy()

        # NOTE(review): one temporal token may cover several input frames;
        # show the first frame of the group. When the clip has exactly
        # ``t_grid`` frames this is simply ``frame_idx``.
        src_frame = frame_idx * num_frames // t_grid
        vis_img = torch.einsum('chw->hwc', x[0, :, src_frame, :, :])

        # Fresh figure per frame so earlier images are not drawn over.
        fig = plt.figure()

        plt.subplot(1, 3, 1)
        show_image(vis_img, "original")

        plt.subplot(1, 3, 2)
        plt.imshow(heat, cmap='hot', interpolation='bilinear')
        plt.title(f'Attention Map - Frame {frame_idx + 1}')

        plt.subplot(1, 3, 3)
        plt.imshow(vis_img[:, :, 0].cpu().numpy() * heat, cmap='hot', interpolation='bilinear')

        plt.savefig(os.path.join(dest_attn_dir, str(idx) + "_" + str(frame_idx)))
        plt.show()
        plt.close(fig)  # release the figure; otherwise they accumulate

def prepare_tokens(self, x, mask_ratio=0.75):
    """Run the masked encoder, then the decoder up to its final norm.

    Args:
        x: Input forwarded unchanged to ``self.forward_encoder``.
        mask_ratio: Masking ratio handed to ``self.forward_encoder``. Defaults
            to 0.75, the value that was previously hard-coded.

    Returns:
        ``(pred, mask)``: ``pred`` is whatever
        ``self.forward_decoder_get_last_attn`` returns for the encoded tokens,
        and ``mask`` is the second value returned by ``self.forward_encoder``.
    """
    latent, mask, ids_restore = self.forward_encoder(x, mask_ratio)
    pred = self.forward_decoder_get_last_attn(latent, ids_restore)
    return pred, mask

def get_last_selfattention(self, x):
    """Return the first element of ``self.prepare_tokens(x)``.

    NOTE(review): the name suggests attention weights, but this simply
    forwards whatever ``prepare_tokens`` produces as its first value; the
    accompanying mask is discarded.
    """
    features, _ = self.prepare_tokens(x)
    return features

def forward_decoder_get_last_attn(self, x, ids_restore):
    """Decode visible tokens and return the normalised decoder token features.

    The sequence is padded with mask tokens, unshuffled with ``ids_restore``,
    given positional embeddings, passed through every decoder block and
    ``self.decoder_norm``. No predictor projection is applied, so the last
    axis keeps the width produced by ``self.decoder_embed``.

    NOTE(review): despite the name, this returns token features, not
    attention weights.

    Args:
        x: Encoded visible tokens, indexed as ``[N, L_visible, C_enc]``.
        ids_restore: Indices of shape ``[N, T * H * W]`` used to restore the
            original token order.

    Returns:
        Tensor of shape ``[N, T * H * W, C]`` (class token removed when one
        was prepended).
    """
    N = x.shape[0]
    T = self.patch_embed.t_grid_size
    H = W = self.patch_embed.grid_size
    num_tokens = T * H * W

    # embed tokens
    x = self.decoder_embed(x)
    C = x.shape[-1]

    # append mask tokens to sequence (no cls token at this point)
    mask_tokens = self.mask_token.repeat(N, num_tokens - x.shape[1], 1)
    x_ = torch.cat([x, mask_tokens], dim=1)
    x_ = x_.view([N, num_tokens, C])
    x_ = torch.gather(
        x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x_.shape[2])
    )  # unshuffle
    x = x_.view([N, num_tokens, C])

    # append cls token
    if self.cls_embed:
        decoder_cls_token = self.decoder_cls_token
        decoder_cls_tokens = decoder_cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((decoder_cls_tokens, x), dim=1)

    if self.sep_pos_embed:
        decoder_pos_embed = self.decoder_pos_embed_spatial.repeat(
            1, self.input_size[0], 1
        ) + torch.repeat_interleave(
            self.decoder_pos_embed_temporal,
            self.input_size[1] * self.input_size[2],
            dim=1,
        )
        if self.cls_embed:
            decoder_pos_embed = torch.cat(
                [
                    self.decoder_pos_embed_class.expand(
                        decoder_pos_embed.shape[0], -1, -1
                    ),
                    decoder_pos_embed,
                ],
                1,
            )
    else:
        decoder_pos_embed = self.decoder_pos_embed

    # add pos embed
    x = x + decoder_pos_embed

    attn = self.decoder_blocks[0].attn
    requires_t_shape = getattr(attn, "requires_t_shape", False)
    if requires_t_shape:
        x = x.view([N, T, H * W, C])

    # apply Transformer blocks
    for blk in self.decoder_blocks:
        x = blk(x)
    x = self.decoder_norm(x)

    # Undo the temporal reshape so callers always get a flat token axis;
    # otherwise a later slice on axis 1 would cut along time, not tokens.
    if requires_t_shape:
        x = x.reshape(N, num_tokens, -1)

    # Only strip the leading token when a cls token was actually prepended;
    # dropping it unconditionally would discard a real patch token.
    if self.cls_embed:
        x = x[:, 1:, :]

    return x

Thank you.