Closed: wml666666 closed this issue 4 months ago.
Hello, you can refer to the following code to generate attention heatmaps:
fpd/ffa.py
class PrototypesDistillation(BaseModule):
    # NOTE(review): excerpt from fpd/ffa.py showing only the parts of forward()
    # that matter for heatmap visualization. `attn`, `v`, `weight` and
    # `support_feats_mp` are produced by code elided at the "# ..." marker, and
    # `BaseModule`, `torch` and `F` are assumed to be file-level imports.

    def forward(self, support_feats, support_gt_labels=None, forward_novel=False,
                forward_novel_test=False, img_metas=None):
        """Distill prototypes and, when possible, dump attention heatmaps.

        Args:
            support_feats: support feature tensor; when its first dimension is
                empty (and labels were given) the process exits -- a debug hack
                kept from the original snippet.
            support_gt_labels: per-support-image label tensors. Heatmaps are
                only considered when this is a non-empty sequence.
            forward_novel: selects the ``novel`` / ``base`` filename prefix.
            forward_novel_test: not referenced in this excerpt (used by the
                elided code, if at all).
            img_metas: per-image meta dicts (each with a ``'filename'`` key)
                followed by the image batch tensor, which the caller appends as
                the LAST element. Heatmaps are skipped when this is None.

        Returns:
            ``(weight, prototypes)``.
        """
        # ...
        prototypes = torch.matmul(attn.softmax(-1), v)

        # Side path: attention heatmaps of the feature queries upon the support
        # images. Guard against both None and an empty label sequence.
        if support_gt_labels is not None and len(support_gt_labels) > 0:
            if not support_feats.size(0):
                # Debug hack kept from the original: stop once no support
                # features remain. `sys` was previously never imported.
                import sys
                sys.exit()
            # The original dereferenced img_metas unconditionally and crashed
            # when it was None; only visualize when images are available.
            if img_metas is not None:
                self._save_attention_heatmaps(
                    attn, support_feats_mp, support_gt_labels, img_metas,
                    prefix='novel' if forward_novel else 'base')
        return weight, prototypes

    @staticmethod
    def _save_attention_heatmaps(attn, support_feats_mp, support_gt_labels,
                                 img_metas, prefix, out_dir='attention_heatmap',
                                 eps=1e-12):
        """Overlay each query's attention map on its support image and save a PNG.

        Args:
            attn: attention logits; after ``squeeze(1)`` dim 0 is indexed per
                support image, dim 1 per query, and dim 2 is flattened spatial
                positions (the original comment gives e.g. ``(16, 5, 49)``).
            support_feats_mp: tensor whose last two dims give the spatial grid
                that dim 2 of ``attn`` is reshaped into.
            support_gt_labels: per-image label tensors (one per support image).
            img_metas: meta dicts with ``'filename'``; last element is the image
                batch tensor.
            prefix: filename prefix (``'base'`` or ``'novel'``).
            out_dir: output directory, created if missing.
            eps: floor on the normalization range so constant maps don't NaN.
        """
        import os

        import cv2

        os.makedirs(out_dir, exist_ok=True)
        imgs = img_metas[-1]  # image batch appended by the caller
        grid_h, grid_w = support_feats_mp.size(-2), support_feats_mp.size(-1)
        # Upsample heatmaps to the actual image size instead of a fixed 224x224.
        out_size = tuple(imgs.shape[-2:])

        # Softmax over spatial positions, then min-max normalize every
        # (image, query) map to [0, 1]. Computed once, not per image; the clamp
        # avoids 0/0 when a map is constant.
        attn2 = attn.squeeze(1).softmax(-1)
        attn_min = attn2.min(dim=2, keepdim=True)[0]
        attn_max = attn2.max(dim=2, keepdim=True)[0]
        attn2 = (attn2 - attn_min) / (attn_max - attn_min).clamp(min=eps)

        # Per-channel mean added back to undo preprocessing. NOTE(review): these
        # look like the mmdet caffe-style BGR means -- confirm against the data
        # pipeline config. Created on the image's device rather than `.cuda()`.
        mean = torch.tensor([103.530, 116.280, 123.675],
                            device=imgs.device)[:, None, None]

        for img_id, gt_label in enumerate(support_gt_labels):
            label = gt_label.cpu().numpy()
            # basename + splitext keeps dots inside the stem (split('.') didn't).
            stem = os.path.splitext(
                os.path.basename(img_metas[img_id]['filename']))[0]
            # De-normalized image, clipped so the uint8 cast cannot wrap around.
            raw_im = (imgs[img_id, :3] + mean).permute(1, 2, 0)
            raw_im = raw_im.detach().cpu().numpy().clip(0, 255).astype('uint8')

            for q_id in range(attn2.size(1)):
                attn_hm = attn2[img_id, q_id].reshape(1, 1, grid_h, grid_w)
                attn_hm = F.interpolate(attn_hm, size=out_size, mode='bilinear',
                                        align_corners=True)
                attn_hm = attn_hm[0].permute(1, 2, 0).detach().cpu().numpy()
                heatmap = cv2.applyColorMap((attn_hm * 255).astype('uint8'),
                                            cv2.COLORMAP_JET)
                result = cv2.addWeighted(raw_im, 0.6, heatmap, 0.4, 0)
                cv2.imwrite(
                    os.path.join(out_dir,
                                 f'{prefix}_class{label}_{stem}_query{q_id}.png'),
                    result)
fpd/fpd_detector.py
def forward_model_init():
    # NOTE(review): excerpt from fpd/fpd_detector.py. The real method's
    # parameters (including `self`) are abbreviated here; `img`, `img_metas`,
    # `base_support_feats`, `novel_support_feats`, `r_b_gts` and `r_n_gts` come
    # from code not shown in this excerpt.
    #
    # Prototypes distillation with attention-heatmap dumping enabled. Compared
    # with the plain calls, the only change is the extra `img_metas=` argument.
    # PrototypesDistillation.forward reads the image batch from the LAST element
    # of img_metas. Pass a copy with `img` appended rather than calling
    # `img_metas.append(img)`, which would leave a stray tensor in the caller's
    # meta list for any code that runs afterwards.
    vis_metas = [*img_metas, img]
    weight_base, prototypes_base = self.roi_head.prototypes_distillation(
        base_support_feats, support_gt_labels=r_b_gts, img_metas=vis_metas)
    weight_novel, prototypes_novel = self.roi_head.prototypes_distillation(
        novel_support_feats, support_gt_labels=r_n_gts, forward_novel=True,
        forward_novel_test=True, img_metas=vis_metas)
Okay, thank you for your prompt reply!
Sorry to bother you, but I would like to ask how the attention heatmap below was generated. Can you provide the relevant code? Thank you very much!