yizt / Grad-CAM.pytorch

A PyTorch implementation of Grad-CAM and Grad-CAM++ that can visualize the Class Activation Map (CAM) of any classification network, including custom networks; CAM visualization is also implemented for two object detection networks, Faster R-CNN and RetinaNet. Feel free to try it out, follow the repo, and report issues...

Applying Grad-CAM to an anomaly detection setting #55

Open EricLee0224 opened 2 years ago

EricLee0224 commented 2 years ago

Hello, and thank you for sharing this code — it is very concise and the key comments are very clear! Following the README I have already run the examples successfully. I would now like to build on your framework and plug in my own pretrained network to generate Grad-CAM maps for image anomaly detection. My model is based on a pretrained ResNet-18, with a head added in front and a few extra convolutional layers plus an fc layer added after it (the final output is 0: normal, 1: anomalous). It was trained in an unsupervised way on my own normal-only dataset, and the resulting weights are used to detect anomalies in test images; I then want Grad-CAM to highlight where the anomaly is in each anomalous image. My question is: how should I load these custom model weights into your framework? After my own attempt I get the output below — what might be going wrong?
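For reference, the model I am describing looks roughly like the sketch below. The class and attribute names (AnomalyNet, head, extra, and the checkpoint key prefix resnet18.) are placeholders standing in for my actual project code, not part of this repository:

```python
# Minimal sketch of the kind of model I am using (placeholder names, not my exact code):
# a ResNet-18 backbone with a custom head in front and a few extra conv/fc layers behind it,
# ending in a 2-class output (0: normal, 1: anomalous).
import torch
import torch.nn as nn
from torchvision import models


class AnomalyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Conv2d(3, 3, kernel_size=3, padding=1)            # custom pre-processing head
        backbone = models.resnet18(pretrained=False)
        self.resnet18 = nn.Sequential(*list(backbone.children())[:-2])   # conv feature extractor only
        self.extra = nn.Sequential(                                      # extra conv layers on top
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(512, 2)                                      # 0: normal, 1: anomalous

    def forward(self, x):
        x = self.head(x)
        x = self.resnet18(x)          # [N, 512, 7, 7] for 224x224 input
        x = self.extra(x)             # [N, 512, 1, 1]
        return self.fc(torch.flatten(x, 1))
```

The part I am unsure about is how to make GradCAM operate on this full model (and on which of its conv layers), instead of the plain torchvision resnet18 that `get_net` builds.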

```
(gradcam) _____@server3090-X570-AORUS-PRO-WIFI:~/Grad-CAM.pytorch-master$ python main.py
feature shape:torch.Size([1, 512, 7, 7])
/home/____/.conda/envs/gradcam/lib/python3.8/site-packages/torch/nn/modules/module.py:1033: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
  warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
feature shape:torch.Size([1, 512, 7, 7])
```
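About the UserWarning itself: as far as I can tell it only reflects PyTorch 1.8+ deprecating Module.register_backward_hook, and (assuming the repo's GradCAM registers its backward hook that way) it could be silenced by switching to register_full_backward_hook; it does not look like the reason the map ignores my weights. A self-contained illustration of the two calls, using a plain resnet18 rather than the repo's classes:

```python
# Self-contained sketch of what the warning is about (not the repo's actual GradCAM class):
# registering a backward hook on the target conv layer to capture its output gradients.
import torch
from torchvision import models

net = models.resnet18(pretrained=False)
target_layer = dict(net.named_modules())['layer4.1.conv2']   # last conv layer of resnet18

def backward_hook(module, grad_input, grad_output):
    # grad_output[0] is the gradient w.r.t. the layer's output feature map
    print('gradient shape:', grad_output[0].shape)

# Deprecated since PyTorch 1.8 and the source of the UserWarning above:
# handle = target_layer.register_backward_hook(backward_hook)
# Documented replacement, same (module, grad_input, grad_output) callback signature:
handle = target_layer.register_full_backward_hook(backward_hook)

out = net(torch.randn(1, 3, 224, 224))
out[:, 0].sum().backward()    # triggers the hook
handle.remove()
```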

P.S. The script does produce output images in the end, but they are clearly not generated from my own model. Below is my modified main.py; after the listing I have also added a small sanity check I am planning to run on the weight loading.

```python
# -*- coding: utf-8 -*-

import argparse
import os
import re

import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18
import cv2
import numpy as np
import torch
from skimage import io
from torch import nn
from torchvision import models

from interpretability.grad_cam import GradCAM, GradCamPlusPlus
from interpretability.guided_back_propagation import GuidedBackPropagation


def get_net(net_name, weight_path=None):
    """
    Get a model by name.
    :param net_name: network name
    :param weight_path: path to the pretrained weights
    :return:
    """
    pretrain = weight_path is None  # no weight path given: load the default pretrained weights
    if net_name in ['vgg', 'vgg16']:
        net = models.vgg16(pretrained=pretrain)
    elif net_name in ['resnet', 'resnet18']:
        net = models.resnet18(pretrained=pretrain)
    else:
        raise ValueError('invalid network name:{}'.format(net_name))

    # load the weights from the specified path
    if weight_path is not None and net_name.startswith('densenet'):
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = torch.load(weight_path)
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        net.load_state_dict(state_dict)
    elif weight_path is not None:
        net.load_state_dict({k.replace('resnet18.', ''): v for k, v in torch.load(weight_path).items()}, strict=False)
    return net


def get_last_conv_name(net):
    """
    Get the name of the last convolutional layer in the network.
    :param net:
    :return:
    """
    layer_name = None
    for name, m in net.named_modules():
        if isinstance(m, nn.Conv2d):
            layer_name = name
    return layer_name


def prepare_input(image):
    image = image.copy()

    # normalization
    # means = np.array([0.485, 0.456, 0.406])
    # stds = np.array([0.229, 0.224, 0.225])
    # image -= means
    # image /= stds

    image = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))  # channel first
    image = image[np.newaxis, ...]  # add batch dimension

    return torch.tensor(image, requires_grad=True)


def gen_cam(image, mask):
    """
    Generate the CAM image.
    :param image: [H,W,C], original image
    :param mask: [H,W], values in 0~1
    :return: tuple(cam, heatmap)
    """
    # convert mask to heatmap
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    heatmap = heatmap[..., ::-1]  # bgr to rgb

    # overlay the heatmap on the original image
    cam = heatmap + np.float32(image)
    return norm_image(cam), (heatmap * 255).astype(np.uint8)


def norm_image(image):
    """
    Normalize the image.
    :param image: [H,W,C]
    :return:
    """
    image = image.copy()
    image -= np.max(np.min(image), 0)
    image /= np.max(image)
    image *= 255.
    return np.uint8(image)


def gen_gb(grad):
    """
    Generate the guided back propagation gradient of the input image.
    :param grad: tensor, [3,H,W]
    :return:
    """
    # normalize
    grad = grad.data.numpy()
    gb = np.transpose(grad, (1, 2, 0))
    return gb


def save_image(image_dicts, input_image_name, network, output_dir):
    prefix = os.path.splitext(input_image_name)[0]
    for key, image in image_dicts.items():
        io.imsave(os.path.join(output_dir, '{}-{}-{}.jpg'.format(prefix, network, key)), image)


def main(args):
    # input
    img = io.imread(args.image_path)
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    img = np.float32(cv2.resize(img, (224, 224))) / 255
    inputs = prepare_input(img)
    # output images
    image_dict = {}
    # network
    net = get_net(args.network, args.weight_path)
    # Grad-CAM
    layer_name = get_last_conv_name(net) if args.layer_name is None else args.layer_name
    grad_cam = GradCAM(net, layer_name)
    mask = grad_cam(inputs, args.class_id)  # cam mask
    image_dict['cam'], image_dict['heatmap'] = gen_cam(img, mask)
    grad_cam.remove_handlers()
    # Grad-CAM++
    grad_cam_plus_plus = GradCamPlusPlus(net, layer_name)
    mask_plus_plus = grad_cam_plus_plus(inputs, args.class_id)  # cam mask
    image_dict['cam++'], image_dict['heatmap++'] = gen_cam(img, mask_plus_plus)
    grad_cam_plus_plus.remove_handlers()

    # GuidedBackPropagation
    gbp = GuidedBackPropagation(net)
    inputs.grad.zero_()  # zero the gradients
    grad = gbp(inputs)

    gb = gen_gb(grad)
    image_dict['gb'] = norm_image(gb)
    # generate Guided Grad-CAM
    cam_gb = gb * mask[..., np.newaxis]
    image_dict['cam_gb'] = norm_image(cam_gb)

    save_image(image_dict, os.path.basename(args.image_path), args.network, args.output_dir)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--network', type=str, default='resnet18',
                        help='ImageNet classification network')
    parser.add_argument('--image-path', type=str, default='./Cutpaste_examples/icecream/XGQK_test.jpg',
                        help='input image path')
    parser.add_argument('--weight-path', type=str, default='./Cutpaste_examples/icecream/model-icecream-cutpaste-normal.pth',
                        help='weight path of the model')
    parser.add_argument('--layer-name', type=str, default=None,
                        help='last convolutional layer name')
    parser.add_argument('--class-id', type=int, default=None,
                        help='class id')
    parser.add_argument('--output-dir', type=str, default='results',
                        help='output directory to save results')
    arguments = parser.parse_args()

    main(arguments)
```
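My current suspicion, which I would appreciate a sanity check on, is the strict=False load in get_net: every checkpoint key that does not match the plain torchvision resnet18 is silently dropped, so the network that GradCAM sees may never actually receive my weights. The small check below (my own debugging sketch, using my local checkpoint path) should show how many keys really get loaded:

```python
# Quick sanity check (my own debugging sketch, not part of the repo): see how many
# checkpoint keys actually land in the torchvision resnet18 that get_net builds.
import torch
from torchvision import models

net = models.resnet18(pretrained=False)
state_dict = torch.load('./Cutpaste_examples/icecream/model-icecream-cutpaste-normal.pth',
                        map_location='cpu')
state_dict = {k.replace('resnet18.', ''): v for k, v in state_dict.items()}

# With strict=False, load_state_dict reports what it could not match instead of raising
result = net.load_state_dict(state_dict, strict=False)
print('missing keys   :', len(result.missing_keys))     # resnet18 params the checkpoint did not provide
print('unexpected keys:', len(result.unexpected_keys))  # checkpoint params that have no place in resnet18
```

If most keys turn up in unexpected_keys, I assume the proper fix is to instantiate my own model class (as in the sketch above), load the checkpoint into it with strict=True, and pass that model together with an explicit --layer-name to GradCAM, rather than reusing the stock resnet18 branch of get_net. Does that sound right?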