pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Using Integrated Gradients with Resnet50 #1168

Open jbug13 opened 11 months ago

jbug13 commented 11 months ago

❓ Questions and Help

Hello! I am learning about Captum and how to use it, and I am trying out some of the Captum model interpretation tools such as Integrated Gradients (and eventually Occlusion). I was able to execute the example provided at the following link:

https://captum.ai/tutorials/Resnet_TorchVision_Interpret

I then tried to swap the ResNet used in that tutorial for the Faster R-CNN ResNet50 FPN v2 detection model: https://pytorch.org/vision/main/models/generated/torchvision.models.detection.fasterrcnn_resnet50_fpn_v2.html

I see that Integrated Gradients expects the model to return a single tensor, but this model returns a List[Dict[Tensor]]. I came across some other posts saying that a custom forward function for the model may be required, but I have been unable to get this to work so far. This is where I am a bit stuck, most likely due to a combination of inexperience and operator error on my part. Currently I am getting an error saying: "Selected k out of index range".

Any advice or direction pointing would be most appreciated.

Below is the code I modified from the original working example for reference.

#!/usr/bin/env python
#coding: utf-8

#Model Interpretation for Pretrained ResNet Model

#This notebook demonstrates how to apply model interpretability algorithms to a
#pretrained ResNet model using a handpicked image, and visualizes the
#attributions for each pixel by overlaying them on the image.
# 
#The interpretation algorithms that we use in this notebook are 
#`Integrated Gradients` (w/ and w/o noise tunnel),  `GradientShap`, and `Occlusion`. 
#A noise tunnel smooths the attributions by adding Gaussian noise to each input sample.
#   
#**Note:** Before running this tutorial, please install the torchvision, PIL, and matplotlib packages.

import torch
import torch.nn.functional as F

from PIL import Image

import os
import json
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

import torchvision
from torchvision import models
from torchvision import transforms

from captum.attr import IntegratedGradients
from captum.attr import GradientShap
from captum.attr import Occlusion
from captum.attr import NoiseTunnel
from captum.attr import Saliency
from captum.attr import visualization as viz

from torchvision.io.image import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

import pandas as pd

################################################################################
##Resnet50 Reference(s):
##https://pytorch.org/vision/main/models/generated/torchvision.models.detection.fasterrcnn_resnet50_fpn.html#torchvision.models.detection.fasterrcnn_resnet50_fpn
##https://pytorch.org/vision/main/models/generated/torchvision.models.detection.fasterrcnn_resnet50_fpn_v2.html
## 
##
##Resnet50 Notes:
##During inference, the model requires only the input tensors, and returns the 
##post-processed predictions as a List[Dict[Tensor]], one for each input image. 
##The fields of the Dict are as follows, where N is the number of detections:
##boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
##labels (Int64Tensor[N]): the predicted labels for each detection
##scores (Tensor[N]): the scores of each detection
################################################################################

################################################################################
##Forward Function:
##Reference: https://github.com/pytorch/captum/issues/642
##Captum first runs a forward pass and then attributes the output channel you select to the input features.
##The current implementation does not readily account for the list-typed output your model computes.
#def single_output_forward(out_ind):
#def forward(x):
#yhat = model(x)
#print (yhat.type)
#print ("yhat: ", yhat)
#return yhat[out_ind]
#return forward
################################################################################

def single_output_forward(out_ind):
    def forward(inp):
        print ('inp: ' , inp.shape)
        test = inp[0]
        print ('test.shape: ', test.shape)
        prediction = model(inp)[out_ind]
        print ('prediction: ', prediction)
        output = prediction["scores"].detach()
        print ('output: ', output)
        output2 = output[None,:]
        print ('output2: ', output2)
        print ('output2.shape: ', output2.shape)
        output3 = torch.topk(output2, 1)
        return output3
    return forward

#1- Test Loading the model and the dataset
#Loads pretrained Resnet model and sets it to eval mode

weights_weAreUsing = models.detection.FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = models.detection.fasterrcnn_resnet50_fpn_v2(weights=weights_weAreUsing, box_score_thresh=0.9, num_classes=91)
model = model.eval()
preprocess = weights_weAreUsing.transforms()

img_path = os.getenv("HOME")+'/dev/img/DEMO1006_19271_18_50_11.jpg'

img_forBoundingBox = read_image(img_path)
img = Image.open(img_path)
#Step 3: Apply inference preprocessing transforms
batch = [preprocess(img_forBoundingBox)]

#Step 4: Use the model and visualize the prediction
prediction = model(batch)[0]

print (prediction)

scores = prediction["scores"].detach().numpy()
labels = [weights_weAreUsing.meta["categories"][i] for i in prediction["labels"]]
boxes = prediction["boxes"].detach().numpy()

tmp = prediction['labels']
print (weights_weAreUsing.meta['categories'][1])

xmins = []
ymins = []
xmaxs = []
ymaxs = []

for row in boxes:
    xmins.append(row[0])
    ymins.append(row[1])
    xmaxs.append(row[2])
    ymaxs.append(row[3])

d = {'name': labels, 'confidence': scores, 'xmin': xmins, 'ymin': ymins, 'xmax': xmaxs, 'ymax': ymaxs}

df = pd.DataFrame(data=d)

print(df)

box_img = draw_bounding_boxes(img_forBoundingBox, boxes=prediction["boxes"],
                              labels=labels,
                              colors="red",
                              width=4)

#im = to_pil_image(box_img.detach())
#im.show()

#Following Captum Example: https://captum.ai/tutorials/Resnet_TorchVision_Interpret
#Defines transformers and normalizing functions for the image.
#It also loads an image from the `img/` folder that will be used for interpretation purposes.

transform = transforms.Compose([
 transforms.Resize([256,256]),
 transforms.ToTensor()
])

img2 = Image.open(img_path)

transformed_img = transform(img2)

input = preprocess(transformed_img)

input = input.unsqueeze(0)

#2- Gradient-based attribution
#Let's compute attributions using Integrated Gradients and visualize them on 
#the image. Integrated gradients computes the integral of the gradients of the 
#output of the model for the predicted class `pred_label_idx` with respect to 
#the input image pixels along the path from the black image to our input image.
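#For reference, the quantity Integrated Gradients approximates for input x and baseline x'
#(the black image here) is, per input feature i:
#   IG_i(x) = (x_i - x'_i) * \int_0^1 [dF(x' + a*(x - x')) / dx_i] da
#Captum estimates this integral numerically from n_steps interpolated points between the
#baseline and the input.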

#Reference: https://github.com/pytorch/captum/issues/642
x = input
fwd_fn = single_output_forward(0)
integrated_gradients = IntegratedGradients(fwd_fn)
attributions_ig = integrated_gradients.attribute(x, target = 0, n_steps=100) 

#Let's visualize the image and corresponding attributions by overlaying the latter on the image.
default_cmap = LinearSegmentedColormap.from_list('custom blue', 
                                                     [(0, '#ffffff'),
                                                      (0.25, '#000000'),
                                                      (1, '#000000')], N=256)

_ = viz.visualize_image_attr(np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1,2,0)),
                                 np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                                 method='heat_map',
                                 cmap=default_cmap,
                                 show_colorbar=True,
                                 sign='positive',
                                 outlier_perc=1)
aobo-y commented 10 months ago

So basically, IntegratedGradients expects your forward_func to output a tensor, which is the target it calculates gradients against. We don't make any assumptions about your model's architecture or purpose, so return whichever tensor you want to attribute, based on what you are trying to do.

Based on your code, it seems you want to explain the most confident detection. I noticed several issues in your code:

First, IntegratedGradients requires gradients, so you should not use output = prediction["scores"].detach(), which removes the target from the computation graph.

Second, output3 = torch.topk(output2, 1) returns a tuple of two tensors: the first holds the values you probably want to attribute, and the second holds the indices.

Third, you should omit the target in integrated_gradients.attribute(x, target = 0, n_steps=100). The target argument is only useful for classification models, to make target selection easier; you have already selected the target manually in your forward function with torch.topk(output2, 1).
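Putting those three points together, a rough sketch of the wrapper (not tested on your setup, and assuming you want to attribute the highest-scoring detection of each image) could look like the following. Note that torch.topk raises the "selected index k out of range" error you mentioned whenever the score list is empty, which can happen for the darker interpolated images attribute() generates internally when box_score_thresh filters out every detection, so the sketch guards against that case.

def top_score_forward(images):
    # The detector returns one dict per image; keep one score per image so the
    # output has exactly one value for every example in the (expanded) batch.
    predictions = model(images)
    top_scores = []
    for pred in predictions:
        scores = pred["scores"]  # no .detach(): IntegratedGradients needs the graph
        if scores.numel() == 0:
            # No detection survived box_score_thresh; topk on an empty tensor
            # would raise "selected index k out of range".
            top_scores.append(scores.new_zeros(1))
        else:
            # topk returns (values, indices); keep only the values.
            top_scores.append(torch.topk(scores, 1).values)
    return torch.cat(top_scores)

integrated_gradients = IntegratedGradients(top_score_forward)
# No target argument: the forward function already selects the score to attribute.
attributions_ig = integrated_gradients.attribute(x, n_steps=100)

Returning one value per image (rather than indexing a single image with out_ind) matters because attribute() expands the batch into n_steps interpolated copies of the input, and each copy needs its own output value for the gradients to be meaningful. If running 100 detector forward passes at once is too heavy, attribute() also accepts internal_batch_size to split the expanded batch.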

jbug13 commented 10 months ago

Thank you very much for your reply and help! I plan on working on this some more today. This is all very new to me so I really appreciate you taking the time to respond.

jbug13 commented 10 months ago

I was able to get some output from the IntegratedGradients and Occlusion attributions as well. Your comments helped me a great deal. I still have a long way to go toward understanding the output data and whether I am employing the functions as intended. I will spend some time in the near future reading the Integrated Gradients paper. Thank you very much for your help!