Closed yjdzyr closed 5 months ago
感谢您的关注~了解到您的需求。可视化特征图的代码并不复杂,只需要在网络前向传播时把中间输出保存下来,后续读取并可视化即可。
不过特征图可视化和网络的训练、推理无关,每个人想要可视化的网络层、可视化方式很难统一,所以我只能给出一份示例代码,您可以fork本仓库按照自己的需求进行修改。
import cv2
import torch
from matplotlib import pyplot as plt
from configs.salience_detr.salience_detr_resnet50_800_1333 import model
from util.utils import load_state_dict
# load model state dict and set to eval mode
weight = torch.load("salience_detr_resnet50_800_1333_coco_2x.pth", map_location="cpu")
load_state_dict(model, weight)
model = model.eval()
# prepare input image
image_name = "data/coco/val2017/000000000139.jpg"
save_name = image_name.replace(".jpg", ".pth")
image = cv2.imread(image_name)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image_tensor = torch.tensor(image).permute(2, 0, 1)
# define a forward hook, which saves the output tensor after the forward pass
# for image.jpg, the output tensor will be saved to image.pth
def save_output_hook(module, input, output):
torch.save(output, save_name)
# register the forward hook and save the medium feature map
hook = model.backbone.register_forward_hook(save_output_hook)
model([image_tensor])
hook.remove() # remove the hook after the forward pass
# load the saved feature map and perform visualization
multi_level_features = torch.load(save_name)
plt.figure(figsize=(16, 4))
for key, value in multi_level_features.items():
feat_for_show = value[0].mean(0).detach()
plt.subplot(1, 4, int(key[-1]))
plt.imshow(feat_for_show, cmap="jet")
plt.title(key)
plt.axis("off")
非常感谢作者的解答,解决了问题,期待作者能够继续发布优秀的项目
非常感谢作者的解答,解决了问题,期待作者能够继续发布优秀的项目
非常感谢作者贡献这个优秀的项目,我想请问一下作者可否更新一些关于可视化模型特征图的代码~