Open kimsekeun opened 5 months ago
I wonder how you visualized attention in the final ViT.
In my opinion,
Given x, y = forward_encoder(x), then y2 = forward_decoder(y). In this step, did you use x1 or x2?
# apply Transformer blocks: for blk in self.decoder_blocks: x = blk(x); then x1 = self.decoder_norm(x)
# predictor projection: x2 = self.decoder_pred(x1)
---code---
def _tokens_to_frame_maps(tokens, num_heads, head, t_grid):
    """Convert decoder tokens ``[1, T*H*W, C]`` into per-frame maps ``[T, H, W]`` scaled to [0, 1]."""
    num_tokens, channels = tokens.shape[1], tokens.shape[2]
    if channels % num_heads:
        raise ValueError(f"channel dim {channels} is not divisible by num_heads={num_heads}")
    if num_tokens % t_grid:
        raise ValueError(f"{num_tokens} tokens cannot be split into {t_grid} temporal slices")
    side = int(round((num_tokens // t_grid) ** 0.5))
    if side * side * t_grid != num_tokens:
        raise ValueError(f"{num_tokens} tokens do not form a {t_grid} x S x S grid")

    head_dim = channels // num_heads
    # The head split has to happen on the channel (last) axis. A plain
    # ``view(num_heads, L, head_dim)`` would instead hand head 0 the first
    # L / num_heads tokens with all of their channels.
    per_head = tokens[0].reshape(num_tokens, num_heads, head_dim)[:, head, :]

    # forward_decoder_get_last_attn lays the sequence out as T * H * W
    # (time-major), so the temporal axis comes first when un-flattening.
    maps = per_head.reshape(t_grid, side, side, head_dim).sum(dim=-1)

    # Min-max normalise; the clamp keeps a constant map from dividing by zero.
    lowest = maps.min()
    return (maps - lowest) / (maps.max() - lowest).clamp_min(1e-12)


def vis_attention(idx, batch, model, num_heads=16, head=0, t_grid=8):
    """Plot and save one heat-map figure per temporal token slice of a clip.

    NOTE: as posted, ``model.get_last_selfattention`` returns the decoder's
    normalised output tokens (``decoder_norm`` output), not attention weights,
    so what is drawn here is a feature-magnitude map for one channel group.

    Args:
        idx: sample identifier, used as the prefix of the saved file names.
        batch: mapping with an ``"image"`` tensor; assumed to be
            ``[C, T, H, W]`` (the original comment says "ncthw" after
            unsqueeze) -- TODO confirm against the data loader.
        model: object exposing ``get_last_selfattention(clip)``.
        num_heads: number of channel groups ("heads") to split the feature
            dimension into (previously hard-coded to 16).
        head: which group to visualise (previously hard-coded to 0).
        t_grid: number of temporal token slices (previously hard-coded to 8).

    Side effects:
        Writes ``<dest_attn_dir>/<idx>_<frame_idx>`` for every slice and shows
        each figure.
    """
    clip = batch["image"].unsqueeze(0)  # add the batch axis: [1, C, T, H, W]

    # Run on whatever device the model lives on, and skip autograd bookkeeping.
    device = next(model.parameters()).device
    with torch.no_grad():
        tokens = model.get_last_selfattention(clip.to(device))

    frame_maps = _tokens_to_frame_maps(tokens, num_heads, head, t_grid).cpu()

    # One temporal token may cover several input frames (temporal patching);
    # show the first input frame of each slice. With T == t_grid this is the
    # same frame the original code displayed.
    frames_per_slice = max(clip.shape[2] // t_grid, 1)
    last_frame = clip.shape[2] - 1
    out_hw = tuple(clip.shape[-2:])  # up-sample to the input resolution

    # attention_map = attention_map * torch.int((attention_map > 0.5))
    for frame_idx in range(t_grid):
        heat = F.interpolate(
            frame_maps[frame_idx][None, None],
            size=out_hw,
            mode="bilinear",
            align_corners=False,
        )[0, 0].numpy()

        src_frame = min(frame_idx * frames_per_slice, last_frame)
        vis_img = clip[0, :, src_frame].permute(1, 2, 0).cpu()  # chw -> hwc

        fig = plt.figure()  # fresh figure so panels do not pile up across frames

        plt.subplot(1, 3, 1)
        show_image(vis_img, "original")

        plt.subplot(1, 3, 2)
        plt.imshow(heat, cmap='hot', interpolation='bilinear')
        plt.title(f'Attention Map - Frame {frame_idx + 1}')

        plt.subplot(1, 3, 3)
        # Modulate the first image channel by the map (both as NumPy arrays).
        plt.imshow(vis_img[:, :, 0].numpy() * heat, cmap='hot', interpolation='bilinear')

        plt.savefig(os.path.join(dest_attn_dir, str(idx) + "_" + str(frame_idx)))
        plt.show()
        plt.close(fig)
# -- model methods (they take ``self``) --
def prepare_tokens(self, x):
    """Encode ``x`` with 75% masking and decode it without the pixel head.

    Args:
        x: input passed straight to ``self.forward_encoder``.

    Returns:
        ``(pred, mask)`` where ``pred`` is whatever
        ``self.forward_decoder_get_last_attn`` returns for the encoded clip
        and ``mask`` is the mask produced by ``self.forward_encoder``.
    """
    latent, mask, ids_restore = self.forward_encoder(x, 0.75)
    # Decoder tokens, not the [N, L, p*p*3] pixel prediction: the helper
    # below never applies ``decoder_pred``.
    pred = self.forward_decoder_get_last_attn(latent, ids_restore)
    return pred, mask
def get_last_selfattention(self, x):
    """Return the first element of ``self.prepare_tokens(x)``; the mask is discarded.

    NOTE(review): despite the name, nothing here reads attention weights --
    the value is whatever ``prepare_tokens`` returns as its first item.
    """
    tokens, _ = self.prepare_tokens(x)
    return tokens
def forward_decoder_get_last_attn(self, x, ids_restore):
    """Run the decoder up to ``decoder_norm`` and return the patch tokens.

    Unlike a full decoder pass, ``decoder_pred`` is never applied, so the
    result holds normalised decoder features rather than pixel predictions.
    Despite the method name, no attention weights are extracted.

    Args:
        x: encoder output for the visible tokens, ``[N, L_visible, C_enc]``.
        ids_restore: ``[N, T*H*W]`` indices that undo the encoder's token
            shuffle (used as the ``torch.gather`` index).

    Returns:
        Tensor ``[N, T*H*W, C]`` of decoder tokens, with the cls token
        removed when ``self.cls_embed`` is set.
    """
    N = x.shape[0]
    T = self.patch_embed.t_grid_size
    H = W = self.patch_embed.grid_size
    num_patches = T * H * W

    # embed tokens
    x = self.decoder_embed(x)
    C = x.shape[-1]

    # append mask tokens to sequence (no cls token at this point)
    mask_tokens = self.mask_token.repeat(N, num_patches - x.shape[1], 1)
    x_ = torch.cat([x, mask_tokens], dim=1)
    x_ = x_.view([N, num_patches, C])
    # unshuffle back to the original T * H * W token order
    x_ = torch.gather(
        x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x_.shape[2])
    )
    x = x_.view([N, num_patches, C])

    # append cls token
    if self.cls_embed:
        decoder_cls_tokens = self.decoder_cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((decoder_cls_tokens, x), dim=1)

    if self.sep_pos_embed:
        # separable embedding: spatial part tiled over time + temporal part
        # repeated over space
        decoder_pos_embed = self.decoder_pos_embed_spatial.repeat(
            1, self.input_size[0], 1
        ) + torch.repeat_interleave(
            self.decoder_pos_embed_temporal,
            self.input_size[1] * self.input_size[2],
            dim=1,
        )
        if self.cls_embed:
            decoder_pos_embed = torch.cat(
                [
                    self.decoder_pos_embed_class.expand(
                        decoder_pos_embed.shape[0], -1, -1
                    ),
                    decoder_pos_embed,
                ],
                1,
            )
    else:
        decoder_pos_embed = self.decoder_pos_embed[:, :, :]

    # add pos embed
    x = x + decoder_pos_embed

    attn = self.decoder_blocks[0].attn
    requires_t_shape = getattr(attn, "requires_t_shape", False)
    if requires_t_shape:
        x = x.view([N, T, H * W, C])

    # apply Transformer blocks
    for blk in self.decoder_blocks:
        x = blk(x)
    x = self.decoder_norm(x)

    # Flatten back to a token sequence first; slicing the 4-D tensor on
    # dim 1 would drop a whole frame instead of one token.
    if requires_t_shape:
        x = x.view([N, num_patches, -1])

    # Strip the cls token only if one was actually prepended; otherwise
    # index 0 is a real patch token and must be kept.
    if self.cls_embed:
        x = x[:, 1:, :]

    return x
Thank you.
I wonder how you visualized attention in the final ViT.
In my opinion,
Given x, y = forward_encoder(x), then y2 = forward_decoder(y). In this step, did you use x1 or x2?
# apply Transformer blocks: for blk in self.decoder_blocks: x = blk(x); then x1 = self.decoder_norm(x)
# predictor projection: x2 = self.decoder_pred(x1)
---code---
main
def vis_attention(idx, batch, model):
x = batch["image"]
-- def prepare_tokens(self, x):
def get_last_selfattention(self, x):
    """Return the first element of ``self.prepare_tokens(x)``; the mask is discarded.

    NOTE(review): despite the name, nothing here reads attention weights --
    the value is whatever ``prepare_tokens`` returns as its first item.
    """
    tokens, _ = self.prepare_tokens(x)
    return tokens
def forward_decoder_get_last_attn(self, x, ids_restore): N = x.shape[0] T = self.patch_embed.t_grid_size H = W = self.patch_embed.grid_size
Thank you.