onvungocminh closed this issue 1 year ago.
We use a ViT implementation based on timm, and only a small code change is needed to extract the 'embedding' features:
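```python
def forward_head(self, x, pre_logits: bool = False):
    # Pool the token sequence: mean over the patch tokens ('avg') or take the class token.
    if self.global_pool:
        x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
    x = self.fc_norm(x)
    # Original timm line: return x if pre_logits else self.head(x)
    # Instead, keep the normalized pre-logits feature as the embedding
    # and return it together with the classifier logits.
    t_emb = x
    return t_emb, self.head(x)

def forward(self, x):
    x = self.forward_features(x)
    # forward now returns (embedding, logits) instead of logits only.
    t_emb, x = self.forward_head(x)
    return t_emb, x
```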
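The pooling line in `forward_head` is compact, so here is a minimal pure-Python sketch of what it computes for a single sample, assuming pooling is enabled (`global_pool` is non-empty). `pool_tokens` is a hypothetical helper for illustration only, not part of timm. `tokens` is one sample's token sequence, with the prefix (class) tokens first.

```python
from typing import List, Sequence


def pool_tokens(tokens: Sequence[Sequence[float]],
                global_pool: str = "token",
                num_prefix_tokens: int = 1) -> List[float]:
    """Reduce one sample's tokens (num_tokens x dim) to a single vector.

    Hypothetical helper mirroring the pooling branch of forward_head:
    - 'avg': mean over the patch tokens, skipping the prefix tokens
    - any other non-empty value: take the first token (the class token)
    """
    if global_pool == "avg":
        patch_tokens = tokens[num_prefix_tokens:]
        n = len(patch_tokens)
        dim = len(patch_tokens[0])
        # Average each feature dimension over the patch tokens.
        return [sum(tok[d] for tok in patch_tokens) / n for d in range(dim)]
    # Class-token pooling: x[:, 0] in the batched PyTorch version.
    return list(tokens[0])


# One class token followed by two patch tokens, each 2-dimensional.
tokens = [[10.0, 10.0], [1.0, 2.0], [3.0, 4.0]]
print(pool_tokens(tokens, "avg", num_prefix_tokens=1))  # [2.0, 3.0]
print(pool_tokens(tokens, "token"))                     # [10.0, 10.0]
```

In the real model the same reduction runs over a batched tensor. The resulting vector then passes through `fc_norm` and becomes `t_emb`.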
Alternatively, you can use forward hooks to capture the embedding instead of modifying the model. You can then run emb_fea_distribution.py to produce the center_emb_train.json file.
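We have not checked the exact output format of emb_fea_distribution.py. The sketch below assumes that center_emb_train.json maps each class label to the mean of that class's training embeddings (the `t_emb` vectors). `compute_class_centers` and `save_centers` are hypothetical names used only for illustration.

```python
import json
from typing import Dict, List, Sequence


def compute_class_centers(embeddings: Sequence[Sequence[float]],
                          labels: Sequence[int]) -> Dict[str, List[float]]:
    """Hypothetical: average the embeddings of each class to get its 'center'."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = {}
    for emb, label in zip(embeddings, labels):
        if label not in sums:
            sums[label] = [0.0] * len(emb)
            counts[label] = 0
        for d, value in enumerate(emb):
            sums[label][d] += value
        counts[label] += 1
    # JSON object keys must be strings, so class labels are stored as str.
    return {str(label): [s / counts[label] for s in vec]
            for label, vec in sorted(sums.items())}


def save_centers(centers: Dict[str, List[float]],
                 path: str = "center_emb_train.json") -> None:
    """Hypothetical: write the per-class centers to a JSON file."""
    with open(path, "w") as f:
        json.dump(centers, f)


# Toy embeddings standing in for t_emb outputs on the training set.
centers = compute_class_centers([[1.0, 2.0], [3.0, 4.0], [10.0, 0.0]], [0, 0, 1])
print(centers)  # {'0': [2.0, 3.0], '1': [10.0, 0.0]}
```

In practice, the embeddings would be collected by running the modified model over the training set and keeping the first element of each returned `(t_emb, logits)` pair.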
Hi authors, could you please tell me how to reproduce the center_emb_train.json file for the ViT model? Thank you in advance.