Open ZhaoPo opened 6 months ago
Thank you for your interest in this work! Given a feature map of size (B, C, H, W), we can get the heatmap with the following script:
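import cv2
import torch.nn.functional as F
# Assumes H, W (input image size), im (the de-normalized image as an HxWx3 array)
# and file_name are already defined; the complete code later in this thread defines them.
# Sum over channels to get a single-channel map, then min-max normalize to [0, 1].
q0 = query_feature[0].sum(dim=0, keepdim=True)
q0 = (q0 - q0.min()) / (q0.max() - q0.min())
# Upsample to the image size and convert to an HxWx1 numpy array.
q0 = F.interpolate(q0.unsqueeze(0), size=(H, W), mode='bilinear', align_corners=True).squeeze(0).permute(1, 2, 0).cpu().numpy()
# Colorize the map and blend it over the image.
heatmap = cv2.applyColorMap((q0 * 255).astype('uint8'), cv2.COLORMAP_JET)
result = cv2.addWeighted(im.astype('uint8'), 0.4, heatmap, 0.6, 0)
cv2.imwrite(file_name.split('.')[0] + '_1_query_feature.png', result)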
This code can be inserted at line 127 of ffa.py. Please note that the parameter 'num_bg' should be set to 0 in order to reproduce the results; otherwise this visualization method will become ineffective.
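For readers who want the normalization and upsampling steps as a reusable function, here is a minimal sketch. It is not part of the repository: the name `feature_to_gray_map` and the `eps` guard are assumptions added for illustration. The guard avoids a division by zero when the summed map is constant, which the script above does not handle.

```python
import torch
import torch.nn.functional as F


def feature_to_gray_map(feature, size, eps=1e-6):
    """Turn a (C, h, w) feature tensor into an (H, W) uint8 map.

    Hypothetical helper, not part of the repository.
    """
    # Sum over channels to get a single-channel activation map: (1, h, w).
    summed = feature.detach().float().sum(dim=0, keepdim=True)
    # Min-max normalize to [0, 1]; clamp the range so a constant map gives 0, not NaN.
    value_range = (summed.max() - summed.min()).clamp_min(eps)
    normalized = (summed - summed.min()) / value_range
    # Bilinear upsampling to the image size, as in the script above: (1, 1, H, W).
    resized = F.interpolate(normalized.unsqueeze(0), size=size,
                            mode='bilinear', align_corners=True)
    # Scale to 0..255 and return a numpy array on the CPU.
    return (resized[0, 0].clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
```

The returned array can then be passed to cv2.applyColorMap and blended with cv2.addWeighted in the same way as in the script.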
Thank you very much for replying. I am a beginner at coding and have not managed to get this working over the past few days. Are there more detailed, step-by-step instructions? I am running the test.py file. Looking forward to your reply.
Hello, here is the complete code; please try it out. Feel free to ask if you have any other questions.
# out = query_feature + self.gamma * out
fused_feature = query_feature + self.gamma * out
import cv2
import torch.nn.functional as F
img = query_img_metas[0]['img']
_, _, H, W = img.shape
mean = torch.ones((3, H, W)) * torch.tensor([103.530, 116.280, 123.675])[:, None, None]
im = img[0, ...].add(mean.cuda()).permute(1, 2, 0).detach().cpu().numpy()
file_name = query_img_metas[0]['filename'].split('/')
file_name = file_name[-3] + file_name[-1]
cv2.imwrite(file_name.split('.')[0] + '_0_raw_image.png', im)
torch.set_printoptions(profile="full")
q0 = query_feature[0].sum(dim=0, keepdim=True)
q0 = (q0 - q0.min()) / (q0.max() - q0.min())
q0 = F.interpolate(q0.unsqueeze(0), size=(H, W), mode='bilinear', align_corners=True).squeeze(0).permute(1, 2, 0).cpu().numpy()
heatmap = cv2.applyColorMap((q0 * 255).astype('uint8'), cv2.COLORMAP_JET)
result = cv2.addWeighted(im.astype('uint8'), 0.4, heatmap, 0.6, 0)
cv2.imwrite(file_name.split('.')[0] + '_1_query_feature.png', result)
q1 = out[0].sum(dim=0, keepdim=True)
q1 = (q1 - q1.min()) / (q1.max() - q1.min())
q1 = F.interpolate(q1.unsqueeze(0), size=(H, W), mode='bilinear', align_corners=True).squeeze(0).permute(1, 2, 0).cpu().numpy()
heatmap = cv2.applyColorMap((q1 * 255).astype('uint8'), cv2.COLORMAP_JET)
result = cv2.addWeighted(im.astype('uint8'), 0.4, heatmap, 0.6, 0)
cv2.imwrite(file_name.split('.')[0] + '_2_assigned_prototypes.png', result)
return fused_feature
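To make it easier to find the output files, here is a small sketch of how the code above derives their names from the image path. The helper name `output_stem` and the example path are assumptions for illustration; the path simply follows the usual VOC directory layout.

```python
from pathlib import PurePosixPath


def output_stem(filename):
    """Mirror file_name[-3] + file_name[-1] with the extension removed.

    Hypothetical helper, not part of the repository.
    """
    parts = PurePosixPath(filename).parts
    if len(parts) < 3:
        raise ValueError(f"expected at least 3 path components, got: {filename!r}")
    # parts[-3] is the dataset folder, parts[-1] is the image file name.
    return parts[-3] + PurePosixPath(parts[-1]).stem


# Assumed example path in the standard VOC layout.
stem = output_stem('data/VOCdevkit/VOC2007/JPEGImages/000001.jpg')
print(stem + '_1_query_feature.png')  # VOC2007000001_1_query_feature.png
```

The images are written relative to the current working directory, so they appear in the directory from which test.py is launched. Unlike `split('.')[0]`, `stem` removes only the last extension, which differs only for file names that contain more than one dot.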
Thank you very much for your reply. I added the code you provided to ffa.py, but running test.py produces the following problem. Could you offer some suggestions? Thanks!
Sorry for the late reply. Please try testing on a different set of images (here the VOC2007 trainval list instead of the test list) by editing configs/base/datasets/nway_kshot/few_shot_voc_ms.py:
test=dict(
    type='FewShotVOCDataset',
    # ann_cfg=[
    #     dict(
    #         type='ann_file',
    #         ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt'),
    # ],  # todo
    ann_cfg=[
        dict(
            type='ann_file',
            ann_file=data_root +
            'VOC2007/ImageSets/Main/trainval.txt'),
    ],
    img_prefix=data_root,
    pipeline=test_pipeline,
    test_mode=True,
    classes=None),
Hello, thank you very much for your reply. I made the change you suggested in few_shot_voc_ms.py and tested on the other set of images, but the problem remains. The error is KeyError: 'img', which suggests that query_img_metas[0] has no 'img' key, either because the entries in query_img_metas never contain that key or because query_img_metas itself is empty. Could you offer some suggestions? Thanks!
Hello, I checked the earlier code and you are right. We need to add "img_metas[0]['img'] = img" to simple_test in:
fpd/fpd_detector.py
def simple_test(self,
                img: Tensor,
                img_metas: List[Dict],
                proposals: Optional[List[Tensor]] = None,
                rescale: bool = False):
    assert self.with_bbox, 'Bbox head must be implemented.'
    assert len(img_metas) == 1, 'Only support single image inference.'
    if not self.is_model_init:
        # process the saved support features
        self.fpd_model_init()
    query_feats = self.extract_feat(img)
    img_metas[0]['img'] = img  # test time visualization
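If the KeyError shows up again, a defensive lookup along these lines may make it easier to diagnose. This is only a sketch: the helper name `get_meta_image` is hypothetical and not part of the repository. It could replace the direct `query_img_metas[0]['img']` access in the visualization code.

```python
def get_meta_image(img_metas, key='img'):
    """Return img_metas[0][key], with a clearer error if it is missing.

    Hypothetical helper, not part of the repository.
    """
    if not img_metas:
        raise ValueError('img_metas is empty; no image metadata was passed in')
    meta = img_metas[0]
    if key not in meta:
        # List the keys that are present so the missing one is easy to spot.
        raise KeyError(
            f"{key!r} not found in img_metas[0]; available keys: {sorted(meta)}. "
            f"Was img_metas[0][{key!r}] set in simple_test?")
    return meta[key]
```

For example, `img = get_meta_image(query_img_metas)` would report the keys that are actually present instead of a bare KeyError: 'img'.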
Hello, this is very meaningful work. I see that your paper includes an attention heatmap for the query images. How can I reproduce it? Is there a relevant script? I look forward to the author's reply.