kazuto1011 / grad-cam-pytorch

PyTorch re-implementation of Grad-CAM (+ vanilla/guided backpropagation, deconvnet, and occlusion sensitivity maps)
MIT License

Use GradCAM with dataloaders #28

Closed Optimox closed 4 years ago

Optimox commented 4 years ago

Hi,

Thank you for this very nice work.

I've been trying to encapsulate GradCAM into a single wrapper that can be used like any other model for prediction from dataloaders.

Here is what I did:

import numpy as np
import torch
from grad_cam import (GradCAM,
                      BackPropagation)
from tqdm import tqdm

class GradCaMExplainer(torch.nn.Module):
    """
    Creates a torch module for grad cam
    """
    def __init__(self, model, target_layer="_conv_head", topk=1):
        super(GradCaMExplainer, self).__init__()
        self.back_propagator = BackPropagation(model=model)
        self.grad_cam = GradCAM(model=model)
        self.topk = topk
        self.target_layer = target_layer

    def forward(self, x):

        probs, ids = self.back_propagator.forward(x) # sorted
        self.back_propagator.remove_hook()

        _ = self.grad_cam.forward(x)
        self.grad_cam.remove_hook()
        for i in range(self.topk):
            # Grad-CAM
            self.grad_cam.backward(ids=ids[:, [i]])
            regions = self.grad_cam.generate(target_layer=self.target_layer)
        return regions

def gradcam_explain(grad_explainer, dataloader, with_target=False, device='cuda'):
    """
    This outputs explanations for an entire dataloader
    """
    res_region = []
    res_probs = []
    res_ids = []

    for batch in tqdm(dataloader):
        if with_target:
            inputs, targets = batch
        else:
            inputs = batch
        inputs = inputs.to(device)
        regions = grad_explainer(inputs)
        res_region.append(regions.to("cpu").numpy())
    return np.vstack(res_region)

But something strange is happening: imagine I have 119 examples in my dataloader and set my batch size to 20. The last batch then has only 19 examples, yet the regions returned by grad_explainer form a tensor with 20 examples.

I suspect this might be due to sizes being set at init time and only once, or something like that, so the output is always the size of the first batch, but I could not find anything by looking carefully at the code. Am I doing something wrong? Could you please help me with this code?

Thanks a lot!

kazuto1011 commented 4 years ago

The size of the last inputs is 19? I suspect the option drop_last=True is enabled in the dataloader.
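
For example (a minimal sketch, assuming a 119-sample dataset and batch size 20):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.rand(119, 3, 8, 8))
for drop_last in (True, False):
    loader = DataLoader(ds, batch_size=20, drop_last=drop_last)
    # drop_last=True  -> [20, 20, 20, 20, 20]      (the last 19 samples are discarded)
    # drop_last=False -> [20, 20, 20, 20, 20, 19]
    print(drop_last, [b[0].shape[0] for b in loader])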

Optimox commented 4 years ago

I want to use the last batch, so drop_last=False in the data loader.

But what happens is that while the number of input examples is 19 for the last batch, the output of GradCAM is still 20, so I have more Grad-CAM heat maps than samples.

I’ll try to create a minimal code example to reproduce what I see.

Optimox commented 4 years ago

@kazuto1011 here is a more detailed description of my problem and a minimal code snippet to reproduce the error:

I've been trying to create a small wrapper to apply GradCAM like any other model through a data loader.

But I have a problem with the output dimensions: GradCAM outputs the wrong dimensions for the last batch, and I don't understand why.

Below is a minimal reproducible code snippet that shows my problem. Could you please help me understand what I'm doing wrong in GradCaMExplainer?

# import some modules

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torch.nn as nn
from torch.nn import functional as F
import numpy as np

################################################################
############ EXACT COPY OF CURRENT REPO ########################
class _BaseWrapper(object):
    def __init__(self, model):
        super(_BaseWrapper, self).__init__()
        self.device = next(model.parameters()).device
        self.model = model
        self.handlers = []  # a set of hook function handlers

    def _encode_one_hot(self, ids):
        one_hot = torch.zeros_like(self.logits).to(self.device)
        one_hot.scatter_(1, ids, 1.0)
        return one_hot

    def forward(self, image):
        self.image_shape = image.shape[2:]
        self.logits = self.model(image)
        self.probs = F.softmax(self.logits, dim=1)
        return self.probs.sort(dim=1, descending=True)  # ordered results

    def backward(self, ids):
        """
        Class-specific backpropagation
        """
        one_hot = self._encode_one_hot(ids)
        self.model.zero_grad()
        self.logits.backward(gradient=one_hot, retain_graph=True)

    def generate(self):
        raise NotImplementedError

    def remove_hook(self):
        """
        Remove all the forward/backward hook functions
        """
        for handle in self.handlers:
            handle.remove()

class BackPropagation(_BaseWrapper):
    def forward(self, image):
        self.image = image.requires_grad_()
        return super(BackPropagation, self).forward(self.image)

    def generate(self):
        gradient = self.image.grad.clone()
        self.image.grad.zero_()
        return gradient

class GradCAM(_BaseWrapper):
    """
    "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
    https://arxiv.org/pdf/1610.02391.pdf
    Look at Figure 2 on page 4
    """

    def __init__(self, model, candidate_layers=None):
        super(GradCAM, self).__init__(model)
        self.fmap_pool = {}
        self.grad_pool = {}
        self.candidate_layers = candidate_layers  # list

        def save_fmaps(key):
            def forward_hook(module, input, output):
                self.fmap_pool[key] = output.detach()

            return forward_hook

        def save_grads(key):
            def backward_hook(module, grad_in, grad_out):
                self.grad_pool[key] = grad_out[0].detach()

            return backward_hook

        # If any candidates are not specified, the hook is registered to all the layers.
        for name, module in self.model.named_modules():
            if self.candidate_layers is None or name in self.candidate_layers:
                self.handlers.append(module.register_forward_hook(save_fmaps(name)))
                self.handlers.append(module.register_backward_hook(save_grads(name)))

    def _find(self, pool, target_layer):
        if target_layer in pool.keys():
            return pool[target_layer]
        else:
            raise ValueError("Invalid layer name: {}".format(target_layer))

    def generate(self, target_layer):
        fmaps = self._find(self.fmap_pool, target_layer)
        grads = self._find(self.grad_pool, target_layer)
        weights = F.adaptive_avg_pool2d(grads, 1)
        gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
        gcam = F.relu(gcam)
        gcam = F.interpolate(
            gcam, self.image_shape, mode="bilinear", align_corners=False
        )

        B, C, H, W = gcam.shape
        gcam = gcam.view(B, -1)
        gcam -= gcam.min(dim=1, keepdim=True)[0]
        gcam /= gcam.max(dim=1, keepdim=True)[0]
        gcam = gcam.view(B, C, H, W)

        return gcam
################################################################

class GradCaMExplainer(torch.nn.Module):
    """
    Creates a torch module for grad cam
    """
    def __init__(self, model, target_layer="layer4.1.conv2", topk=1):
        super(GradCaMExplainer, self).__init__()
        self.back_propagator = BackPropagation(model=model)
        self.grad_cam = GradCAM(model=model)
        self.topk = topk
        self.target_layer = target_layer

    def forward(self, x):
        probs, ids = self.back_propagator.forward(x) # sorted
        self.back_propagator.remove_hook()        

        _ = self.grad_cam.forward(x)
        self.grad_cam.remove_hook()
        for i in range(self.topk):
            # Grad-CAM
            self.grad_cam.backward(ids=ids[:, [i]])
            regions = self.grad_cam.generate(target_layer=self.target_layer)       
        return regions

def gradcam_explain(grad_explainer, dataloader, with_target=False, device='cpu'):
    """
    This outputs explanations for an entire dataloader
    """
    res_region = []
    res_probs = []
    res_ids = []
    c = 0
    c2 = 0
    for batch in dataloader:
        if with_target:
            inputs, targets = batch
        else:
            inputs = batch
        c += inputs.shape[0]
        inputs = inputs.to(device)
        regions = grad_explainer(inputs)
        print("batch regions", regions.shape)
        print("----------")
        c2 += regions.shape[0]
        res_region.append(regions.to("cpu").numpy())
    print("total element ", c)
    print("total regions ", c2)
    return np.vstack(res_region)

# Let's take any pretrained model
model = models.resnet18(pretrained=True)

# Let's create 10 random images
random_images = torch.rand(10, 3, 224, 224)

class BasicDataset(Dataset):
    """
    The simplest dataset you can imagine
    """
    def __init__(self, X):
        self.X = X

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx, :, :, :]

dataset = BasicDataset(random_images)
# set batch size to 8 and drop last to False so that first batch is 8 and second 2
dataloader = DataLoader(dataset, batch_size=8, drop_last=False)

## check sizes
print("Check batch sizes")
for batch in dataloader:
    print(batch.shape)
print("---------------")

grad_explainer = GradCaMExplainer(model)

regions = gradcam_explain(grad_explainer, dataloader, with_target=False, device='cpu')

As you will see if you run this code snippet, it outputs:

Check batch sizes
torch.Size([8, 3, 224, 224])
torch.Size([2, 3, 224, 224])
---------------
batch regions torch.Size([8, 1, 224, 224])
----------
batch regions torch.Size([8, 1, 224, 224])
----------
total element  10
total regions  16

So for a dataloader with 10 elements and a batch size of 8, I end up with 16 Grad-CAM regions instead of 10. The mismatch happens on the last batch, but I can't understand why...

I would really appreciate your help!

kazuto1011 commented 4 years ago

Thank you for the minimal code! The problem is self.grad_cam.remove_hook(), which unregisters save_fmaps() and save_grads() from the model, so the region maps are never updated after the first run. It can be solved by simply removing that line and the redundant self.back_propagator, like this:

class GradCaMExplainer(torch.nn.Module):

    def __init__(self, model, target_layer="layer4.1.conv2", topk=1):
        super(GradCaMExplainer, self).__init__()
        self.grad_cam = GradCAM(model=model)
        self.topk = topk
        self.target_layer = target_layer

    def forward(self, x):
        probs, ids = self.grad_cam.forward(x)
        regions = {}
        for i in range(self.topk):
            self.grad_cam.backward(ids=ids[:, [i]])
            regions[i] = self.grad_cam.generate(target_layer=self.target_layer)
        return regions
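
Note that forward now returns a dict keyed by the class rank, so a caller like your gradcam_explain should stack regions[0] (the top-1 maps) instead of regions; a minimal sketch:

def gradcam_explain(grad_explainer, dataloader, device="cpu"):
    res_region = []
    for batch in dataloader:
        inputs = batch.to(device)
        regions = grad_explainer(inputs)
        # regions[i] holds the (B, 1, H, W) maps for the i-th ranked class
        res_region.append(regions[0].to("cpu").numpy())
    return np.vstack(res_region)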

Please tell me if you are still stuck with the problem.

Optimox commented 4 years ago

Thanks a lot @kazuto1011, it's working.

I thought GradCAM had to be used on top of BackPropagation. Also, could you explain when remove_hook should be used?

kazuto1011 commented 4 years ago

I thought GradCAM had to be used on top of BackPropagation.

Each wrapper module can be used independently. They all share the same pipeline: forward, backward, and generate (briefly explained here).
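
For example, GradCAM alone covers the whole pipeline (a minimal sketch, where images is an assumed (B, 3, 224, 224) batch for the resnet18 above):

gcam = GradCAM(model=model)
probs, ids = gcam.forward(images)        # forward: class probabilities, sorted
gcam.backward(ids=ids[:, [0]])           # backward: w.r.t. the top-1 class
regions = gcam.generate(target_layer="layer4.1.conv2")  # generate: (B, 1, H, W) maps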

Also, could you explain when remove_hook should be used?

Please call the function especially when destructing the GuidedBackPropagation and Deconvnet classes. They modify the given model instance with hook functions that filter the ReLU gradients at runtime, which is not required by the other algorithms. With remove_hook, you can restore the model.

deconv = Deconvnet(model=model) # hook
...
deconv.remove_hook() # otherwise the hook leaks below
gcam = GradCAM(model=model)

Or

import copy
deconv = Deconvnet(model=copy.deepcopy(model))
gcam = GradCAM(model=model)
...

Or

gcam = GradCAM(model=model)
deconv = Deconvnet(model=model)  # hook
...

Optimox commented 4 years ago

Thanks very much!

I've seen the code at the link below; I think what got me confused is that the ids used everywhere are defined only at line 160 by probs, ids = bp.forward(images) # sorted, so I thought you had to use the back propagator.

Maybe you could add some comments to make this clearer.

Anyway, thanks very much again! Your help was much appreciated!

Best