wangchen1801 / FPD

Official code of the paper "Fine-Grained Prototypes Distillation for Few-Shot Object Detection (AAAI 2024)"

Feature map #8

Open ZhaoPo opened 6 months ago

ZhaoPo commented 6 months ago

Hello, this is very meaningful work. I see that your paper shows attention heatmaps for query images. How can I reproduce them? Is there a relevant script? I am very much looking forward to your reply.

wangchen1801 commented 6 months ago

Thank you for your interest in this work! Given a query feature map of size (B, C, h, w) and the input image im of size H × W, we can get the heatmap with the following script:

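import cv2
import torch.nn.functional as F

# query_feature: (B, C, h, w) feature map; im: the input image as an (H, W, 3) BGR array
q0 = query_feature[0].sum(dim=0, keepdim=True)  # sum over channels -> (1, h, w)
q0 = (q0 - q0.min()) / (q0.max() - q0.min())  # min-max normalize to [0, 1]
# upsample to the image size -> (H, W, 1) numpy array
q0 = F.interpolate(q0.unsqueeze(0), size=(H, W), mode='bilinear', align_corners=True).squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
heatmap = cv2.applyColorMap((q0 * 255).astype('uint8'), cv2.COLORMAP_JET)
result = cv2.addWeighted(im.astype('uint8'), 0.4, heatmap, 0.6, 0)  # 40% image + 60% heatmap
cv2.imwrite(file_name.split('.')[0] + '_1_query_feature.png', result)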

We can apply this code at line 127 of the linked file. Please note that the parameter 'num_bg' should be set to 0 to reproduce the results; otherwise this visualization method will not work.
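To check the shapes without running the detector, the channel-sum heatmap step can be sketched in plain NumPy. This is a minimal sketch, not the repository's code: the function name `channel_sum_heatmap` is hypothetical, and bilinear upsampling with `align_corners=True` is emulated with two passes of `np.interp` (bilinear interpolation is separable into two linear passes). The colormap and blending would still be done with `cv2.applyColorMap` and `cv2.addWeighted` as in the script above.

```python
import numpy as np


def channel_sum_heatmap(feature: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Hypothetical helper: (C, h, w) feature map of one image -> (out_h, out_w) map in [0, 1]."""
    act = feature.sum(axis=0)  # sum over channels -> (h, w)
    lo, hi = act.min(), act.max()
    # Min-max normalize; a constant map has no contrast, so return zeros instead of dividing by 0
    act = (act - lo) / (hi - lo) if hi > lo else np.zeros_like(act, dtype=float)
    h, w = act.shape
    # align_corners=True: output pixel i samples source coordinate i * (n_src - 1) / (n_out - 1)
    xs = np.linspace(0, w - 1, out_w)
    ys = np.linspace(0, h - 1, out_h)
    # Linear interpolation along the width for every source row -> (h, out_w)
    rows = np.stack([np.interp(xs, np.arange(w), act[r]) for r in range(h)])
    # Then along the height for every output column -> (out_h, out_w)
    return np.stack([np.interp(ys, np.arange(h), rows[:, c]) for c in range(out_w)], axis=1)


if __name__ == "__main__":
    feat = np.zeros((3, 2, 2))
    feat[:, 1, 1] = 1.0  # only the bottom-right cell is active
    hm = channel_sum_heatmap(feat, 3, 3)
    print(hm.round(2))
    # uint8 input suitable for cv2.applyColorMap, as in the script above
    hm_u8 = (hm * 255).astype(np.uint8)
```

The explicit zero-range check also avoids the division by zero that the torch version would hit on a perfectly flat feature map.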

ZhaoPo commented 6 months ago

Thank you very much for replying. I am a beginner at coding and have not managed to implement this over the past few days. Could you give more detailed, step-by-step instructions? I am running the file. Looking forward to your reply.

wangchen1801 commented 6 months ago

Hello, here is the complete code; you can try it out. Feel free to ask if you have any other questions.

# The original line at the end of forward() is kept as a comment; the fused feature is
# stored first so the intermediate maps can be visualized before it is returned.
# out = query_feature + self.gamma * out
fused_feature = query_feature + self.gamma * out

import cv2
import torch
import torch.nn.functional as F

# Recover the network input (requires img_metas[0]['img'] = img in simple_test)
img = query_img_metas[0]['img']
_, _, H, W = img.shape
# Undo normalization by adding back the BGR channel means (std is assumed to be 1, as in the original code)
mean = torch.tensor([103.530, 116.280, 123.675], device=img.device)[:, None, None]
im = (img[0] + mean).permute(1, 2, 0).detach().cpu().numpy().clip(0, 255).astype('uint8')

# Output name: third-to-last path component + image file name
file_name = query_img_metas[0]['filename'].split('/')
file_name = file_name[-3] + file_name[-1]
cv2.imwrite(file_name.split('.')[0] + '_0_raw_image.png', im)

# Heatmap 1: query feature before fusion (channel sum, min-max normalized, upsampled to image size)
q0 = query_feature[0].sum(dim=0, keepdim=True)
q0 = (q0 - q0.min()) / (q0.max() - q0.min())
q0 = F.interpolate(q0.unsqueeze(0), size=(H, W), mode='bilinear', align_corners=True).squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
heatmap = cv2.applyColorMap((q0 * 255).astype('uint8'), cv2.COLORMAP_JET)
result = cv2.addWeighted(im, 0.4, heatmap, 0.6, 0)
cv2.imwrite(file_name.split('.')[0] + '_1_query_feature.png', result)

# Heatmap 2: the assigned prototypes (out), processed the same way
q1 = out[0].sum(dim=0, keepdim=True)
q1 = (q1 - q1.min()) / (q1.max() - q1.min())
q1 = F.interpolate(q1.unsqueeze(0), size=(H, W), mode='bilinear', align_corners=True).squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
heatmap = cv2.applyColorMap((q1 * 255).astype('uint8'), cv2.COLORMAP_JET)
result = cv2.addWeighted(im, 0.4, heatmap, 0.6, 0)
cv2.imwrite(file_name.split('.')[0] + '_2_assigned_prototypes.png', result)

return fused_feature
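The `im` step above undoes the input normalization by adding the per-channel means back. Below is a small NumPy sketch of the general inverse, assuming a Caffe-style setting with BGR mean (103.530, 116.280, 123.675) and std (1, 1, 1), which is what the snippet implies since it only adds the mean. `denormalize_chw` is a hypothetical name; a config with a different std would need its own values.

```python
import numpy as np

# Assumed normalization: BGR mean from the snippet above, std of 1 (the snippet never rescales)
BGR_MEAN = (103.530, 116.280, 123.675)
BGR_STD = (1.0, 1.0, 1.0)


def denormalize_chw(img_chw: np.ndarray, mean=BGR_MEAN, std=BGR_STD) -> np.ndarray:
    """Hypothetical helper: normalized (3, H, W) array -> (H, W, 3) uint8 image for cv2.imwrite."""
    mean = np.asarray(mean, dtype=np.float32)[:, None, None]
    std = np.asarray(std, dtype=np.float32)[:, None, None]
    restored = img_chw * std + mean  # inverse of (x - mean) / std
    hwc = np.transpose(restored, (1, 2, 0))  # CHW -> HWC, the layout OpenCV expects
    # Round and clip before casting so out-of-range values do not wrap around in uint8
    return np.clip(np.rint(hwc), 0, 255).astype(np.uint8)
```

In the forward() code this would be applied to `img[0].detach().cpu().numpy()`; clipping matters because a plain `astype('uint8')` on values outside [0, 255] wraps around instead of saturating.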
ZhaoPo commented 6 months ago

Thank you very much for your reply. I added the code you provided, but running it produces the following error. Could you give me some suggestions? Thanks!

wangchen1801 commented 6 months ago

Sorry for the late reply. Please try testing images from a different dataset: configs/base/datasets/nway_kshot/

      # ann_cfg=[
      #     dict(
      #         type='ann_file',
      #         ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt'),
      # ],  # todo
              ann_file=data_root +
hbc666666 commented 6 months ago

Hello, thank you very much for your reply. I made the change you suggested and tested images from a different dataset, but the problem remains. The KeyError: 'img' probably means that the dict query_img_metas[0] does not contain an 'img' key. Could you give some suggestions? Thanks!

wangchen1801 commented 6 months ago

Hello, I checked the previous code and you are right. We need to add "img_metas[0]['img'] = img" in simple_test, right after the query features are extracted:


from typing import Dict, List, Optional

from torch import Tensor


def simple_test(self,
                img: Tensor,
                img_metas: List[Dict],
                proposals: Optional[List[Tensor]] = None,
                rescale: bool = False):

    assert self.with_bbox, 'Bbox head must be implemented.'
    assert len(img_metas) == 1, 'Only support single image inference.'
    if not self.is_model_init:
        # process the saved support features
        ...  # (unchanged code omitted)

    query_feats = self.extract_feat(img)
    img_metas[0]['img'] = img  # test time visualization: store the input image for the heatmap code
    ...  # (rest of simple_test unchanged)
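Without that line, the visualization code fails with the KeyError reported above. If a clearer message is preferred, the lookup can be guarded. A minimal sketch; `get_vis_image` is a hypothetical helper, not part of the repository:

```python
from typing import Any, Dict, List


def get_vis_image(img_metas: List[Dict[str, Any]]) -> Any:
    """Hypothetical helper: return the image tensor stored by simple_test for visualization."""
    if not img_metas:
        raise IndexError("img_metas is empty; expected one meta dict per image")
    meta = img_metas[0]
    if 'img' not in meta:
        # simple_test must run img_metas[0]['img'] = img before the visualization code
        raise KeyError("'img' not found in img_metas[0]; add img_metas[0]['img'] = img in simple_test")
    return meta['img']
```

In the forward() code, `img = get_vis_image(query_img_metas)` would then replace `img = query_img_metas[0]['img']`.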