wangchen1801 / FPD

Official code of the paper "Fine-Grained Prototypes Distillation for Few-Shot Object Detection (AAAI 2024)"
https://arxiv.org/pdf/2401.07629.pdf

Generation of attention heatmaps #10

Closed · wml666666 closed this issue 4 months ago

wml666666 commented 5 months ago

Sorry to bother you, but I would like to ask how the attention heatmap below was generated. Can you provide the relevant code? Thank you very much!

[image: attention heatmap]

wangchen1801 commented 5 months ago

Hello, you can refer to the following code to generate attention heatmaps:

fpd/ffa.py

class PrototypesDistillation(BaseModule):

    def forward(self, support_feats, support_gt_labels=None, forward_novel=False,
                forward_novel_test=False, img_metas=None):
        # ...
        prototypes = torch.matmul(attn.softmax(-1), v)

        # attention heatmaps of the feature queries upon the support images
        if support_gt_labels is not None and len(support_gt_labels) > 0:
            import os
            import sys
            import cv2
            import torch.nn.functional as F

            os.makedirs('attention_heatmap', exist_ok=True)
            if not support_feats.size(0):
                sys.exit()
            if img_metas is not None:
                # raw support image tensor appended to img_metas in forward_model_init() (see below)
                img = img_metas[-1]
            bs = len(support_gt_labels)
            prefix = 'novel' if forward_novel else 'base'
            for img_id in range(bs):
                gt_label = support_gt_labels[img_id].cpu().numpy()
                file_name = img_metas[img_id]['filename'].split('/')[-1]

                # attention map of shape (batch, num_queries, h*w), e.g. (16, 5, 49)
                attn2 = attn.squeeze(1).softmax(-1)
                # min-max normalize each query's attention over the spatial positions
                attn2 = (attn2 - attn2.min(dim=2, keepdim=True)[0]) / (
                    attn2.max(dim=2, keepdim=True)[0] - attn2.min(dim=2, keepdim=True)[0])
                for q_id in range(attn2.size(1)):
                    # reshape one query's attention to the spatial grid of the pooled
                    # support features (support_feats_mp, defined earlier in forward)
                    attn_hm = attn2[img_id, q_id:q_id + 1, :].reshape(
                        1, 1, support_feats_mp.size(-2), support_feats_mp.size(-1))
                    # upsample to the support image resolution
                    attn_hm = F.interpolate(attn_hm, size=(224, 224), mode='bilinear',
                                            align_corners=True)[0].permute(1, 2, 0).cpu().numpy()
                    # undo the BGR pixel-mean subtraction applied during preprocessing
                    mean = torch.ones((3, 224, 224)) * torch.tensor(
                        [103.530, 116.280, 123.675])[:, None, None]
                    raw_im = img[img_id, :3].add(mean.cuda()).permute(1, 2, 0).detach().cpu().numpy()

                    # colorize the attention map and blend it with the raw support image
                    heatmap = cv2.applyColorMap((attn_hm * 255).astype('uint8'), cv2.COLORMAP_JET)
                    result = cv2.addWeighted(raw_im.astype('uint8'), 0.6, heatmap, 0.4, 0)
                    cv2.imwrite(f'attention_heatmap/{prefix}_class{gt_label}_'
                                f'{file_name.split(".")[0]}_query{q_id}.png', result)

        return weight, prototypes
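
For reference, the overlay step can also be reproduced outside the model with a small standalone helper. The sketch below is not from the repository: the function save_attention_overlay and its arguments are hypothetical, and the 1e-8 epsilon is an extra safeguard I added; it simply mirrors the steps above (min-max normalization over spatial positions, reshaping to the feature-map grid, bilinear upsampling, colorizing, and blending).

# Standalone sketch of the overlay technique; names are illustrative, not from fpd/ffa.py.
import cv2
import torch.nn.functional as F

def save_attention_overlay(attn_row, grid_hw, raw_bgr_image, out_path, alpha=0.4):
    # attn_row: 1D float tensor of length h*w, one query's attention over spatial positions
    # grid_hw: (h, w) of the pooled support feature map, e.g. (7, 7) for 49 positions
    # raw_bgr_image: uint8 BGR image of shape (H, W, 3) to draw on
    h, w = grid_hw
    H, W = raw_bgr_image.shape[:2]
    # min-max normalize the attention weights to [0, 1]
    attn = (attn_row - attn_row.min()) / (attn_row.max() - attn_row.min() + 1e-8)
    # reshape to the feature-map grid and upsample to the image resolution
    attn = attn.reshape(1, 1, h, w)
    attn = F.interpolate(attn, size=(H, W), mode='bilinear', align_corners=True)[0, 0]
    # colorize the attention map and blend it with the raw image
    heatmap = cv2.applyColorMap((attn.cpu().numpy() * 255).astype('uint8'), cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(raw_bgr_image, 1 - alpha, heatmap, alpha, 0)
    cv2.imwrite(out_path, overlay)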

fpd/fpd_detector.py

def forward_model_init(self, img, img_metas, **kwargs):
    # ... (earlier code that extracts base_support_feats, novel_support_feats, r_b_gts, r_n_gts is elided)

    # prototypes distillation
    # weight_base, prototypes_base = self.roi_head.prototypes_distillation(
    #     base_support_feats, support_gt_labels=r_b_gts)
    # weight_novel, prototypes_novel = self.roi_head.prototypes_distillation(
    #     novel_support_feats, support_gt_labels=r_n_gts, forward_novel=True, forward_novel_test=True)

    # pass the raw support image tensor along so prototypes_distillation() can draw the heatmaps on it
    img_metas.append(img)
    weight_base, prototypes_base = self.roi_head.prototypes_distillation(
        base_support_feats, support_gt_labels=r_b_gts, img_metas=img_metas)
    weight_novel, prototypes_novel = self.roi_head.prototypes_distillation(
        novel_support_feats, support_gt_labels=r_n_gts, forward_novel=True,
        forward_novel_test=True, img_metas=img_metas)
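
Once prototypes_distillation() runs with this img_metas change (e.g. when forward_model_init is invoked during testing), the heatmaps are written to ./attention_heatmap/ with the {prefix}_class{label}_{image}_query{q}.png naming from ffa.py above. A quick, purely illustrative way to list them or stitch the per-query maps side by side:

# Illustrative only; reads whatever PNGs the hook above produced.
import glob
import cv2
import numpy as np

paths = sorted(glob.glob('attention_heatmap/novel_*_query*.png'))
print('\n'.join(paths))
if len(paths) >= 5:
    # stitch 5 query heatmaps into one row (num_queries is 5 in the (16, 5, 49) example above)
    row = np.hstack([cv2.imread(p) for p in paths[:5]])
    cv2.imwrite('attention_heatmap/novel_queries_row.png', row)
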
wml666666 commented 5 months ago

Okay, thank you for your prompt reply!