jacobgil / pytorch-grad-cam

Advanced AI Explainability for computer vision. Support for CNNs, Vision Transformers, Classification, Object detection, Segmentation, Image similarity and more.
https://jacobgil.github.io/pytorch-gradcam-book
MIT License
10.46k stars 1.55k forks source link

When I apply CAM to PraNet, it raises an error: An exception occurred in CAM with block: <class 'IndexError'>. Message: index 1 is out of bounds for dimension 0 with size 1 #238

Closed haomayang1126 closed 2 years ago

haomayang1126 commented 2 years ago

# Reproduction script: runs PraNet on a single image and visualises a
# Grad-CAM heat map for the predicted segmentation mask.
#
# Import order follows the original post; the warning filters are installed
# before the remaining imports are executed.
import argparse
import cv2
import matplotlib.pyplot as plt

from lib.PraNet_Res2Net import PraNet
import numpy as np
import torch
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
from torchvision.models.segmentation import deeplabv3_resnet50
import torch.functional as F
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import SemanticSegmentationTarget
import requests
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image

if __name__ == '__main__':

    # Load the PraNet weights and pick the layer whose activations are explained.
    modelpath = 'PraNet-19.pth'
    model = PraNet()
    model.load_state_dict(torch.load(modelpath))
    target_layers = [model.resnet.layer4]

    imagepath = '2.png'

    # Read the image, resize it to 576x576 and scale pixel values to [0, 1].
    img = np.array(Image.open(imagepath))
    img = cv2.resize(img, (576, 576))
    rgb_img = np.float32(img) / 255
    input_tensor = preprocess_image(rgb_img,
                                    mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

    if torch.cuda.is_available():
        model = model.cuda()
        input_tensor = input_tensor.cuda()

    # The model returns four outputs; only the last one (res2) is used below.
    res5, res4, res3, res2 = model(input_tensor)

    res = res2.sigmoid().data.cpu().numpy().squeeze()

    # NOTE(review): the mask keeps only pixels whose sigmoid output equals
    # exactly 1 -- presumably a threshold (e.g. res > 0.5) was intended; confirm.
    plopymaskuint8 = 255 * np.uint8(res == 1)
    plopymaskfloat = np.float32(res == 1)

    # NOTE(review): this is the call discussed in the issue. The reported
    # error is "index 1 is out of bounds for dimension 0 with size 1"; the
    # category index is left at 1 here to preserve the reproduction.
    targets = [SemanticSegmentationTarget(1, plopymaskfloat)]  # targets is a class

    with GradCAM(model=model,
                 target_layers=target_layers,
                 use_cuda=torch.cuda.is_available()) as cam:

        # Take the CAM of the first (and only) image in the batch.
        grayscale_cam = cam(input_tensor=input_tensor,
                            targets=targets,)[0, :]

        cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
        plt.imshow(cam_image)
        plt.show()
ChenHongyu-bo commented 2 years ago

I also encountered this problem. Have you solved it?

wcyjerry commented 2 years ago

Change `targets = [SemanticSegmentationTarget(1, plopymaskfloat)] #targets是一个类` into `targets = [SemanticSegmentationTarget(0, plopymaskfloat)] #targets是一个类`.