pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

using captum (integrated gradients) for the CIFAR10 in greyscale as the input for training a resnet18 #378

Closed: mrdupadupa closed this issue 4 years ago

mrdupadupa commented 4 years ago

Hi all, I have an issue using Captum with grayscale CIFAR10 and ResNet18. I followed the tutorial "Interpreting vision with CIFAR", but I get errors during execution that I could not solve:

Does this mean that I cannot use the DeepLift method with ResNet because the same ReLU is used more than once?

I am pretty sure the errors are connected with the grayscale format of my CIFAR dataset (with RGB it works fine), but I don't know what to change in the Captum code to adapt it to my data.

Here is the code:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from captum.attr import IntegratedGradients, Saliency, DeepLift, NoiseTunnel
from captum.attr import visualization as viz

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Grayscale transforms: Normalize takes one mean and one std per channel (a 1-tuple here)
transform_trainset = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])
transform_testset = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.CIFAR10(root='/home/andrei/Study/master_thesis/data', train=True, download=True, transform=transform_trainset)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True, num_workers=4)
testset = torchvision.datasets.CIFAR10(root='/home/andrei/Study/master_thesis/data', train=False, download=True, transform=transform_testset)
testloader = torch.utils.data.DataLoader(testset, batch_size=8, shuffle=True, num_workers=4)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#Define a Convolutional Neural Network
class MyResNet(nn.Module):
    def __init__(self, in_channels=1):
        super(MyResNet, self).__init__()
        self.model = torchvision.models.resnet18(num_classes=10)  # 10 CIFAR10 classes
        self.model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
    def forward(self, x):
        return self.model(x)

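my_resnet = MyResNet()

# Sanity check on CPU with a batch of 8 random 1-channel 32x32 images
input = torch.randn((8,1,32,32))
output = my_resnet(input)
print(output.shape)
# Move the model to the same device the training batches are sent to
net = my_resnet.to(device)
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
dataiter = iter(trainloader)
images, labels = next(dataiter)  # DataLoader iterators no longer provide .next()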
#Train the network
for epoch in range(1):

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):

        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')
# Load some images from the test dataset and perform predictions
def imshow(img, one_channel=True):
    img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.cpu().numpy()
    plt.imshow(npimg, cmap="gray")  # "gray" maps low values to black; "Greys" inverts them
dataiter = iter(testloader)
images, labels = next(dataiter)
# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(8)))
# Put BatchNorm in inference mode and feed the batch on the model's device
net.eval()
outputs = net(images.to(next(net.parameters()).device))
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(8)))
# Choose a test image at index ind and apply some of our attribution algorithms to it
ind = 6
input = images[ind].unsqueeze(0).to(next(net.parameters()).device)
input.requires_grad = True
net.eval()
def attribute_image_features(algorithm, input, **kwargs):
    net.zero_grad()
    tensor_attributions = algorithm.attribute(input,
                                              target=labels[ind].item(),
                                              **kwargs
                                             )
    return tensor_attributions

saliency = Saliency(net)
grads = saliency.attribute(input, target=labels[ind].item())
# squeeze(0) drops only the batch axis, so a grayscale map stays [1, 32, 32];
# transposing with (1, 2, 0) then yields the H x W x C array that viz expects
grads = np.transpose(grads.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
ig = IntegratedGradients(net)
attr_ig, delta = attribute_image_features(ig, input, baselines=input * 0, return_convergence_delta=True)
attr_ig = np.transpose(attr_ig.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
print('Approximation delta: ', abs(delta))
# Use integrated gradients and noise tunnel with the smoothgrad square option on the test image
# (Captum 0.4+ names the sample count nt_samples; older releases used n_samples)
nt = NoiseTunnel(ig)
attr_ig_nt = attribute_image_features(nt, input, baselines=input * 0, nt_type='smoothgrad_sq',
                                      nt_samples=100, stdevs=0.2)
attr_ig_nt = np.transpose(attr_ig_nt.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
# Apply DeepLift on the test image
dl = DeepLift(net)
attr_dl = attribute_image_features(dl, input, baselines=input * 0)
attr_dl = np.transpose(attr_dl.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
# Visualize the attributions for Saliency Maps, DeepLift, Integrated Gradients and Integrated Gradients with SmoothGrad
print('Original Image')
print('Predicted:', classes[predicted[ind]],
      ' Probability:', F.softmax(outputs, 1)[ind].max().item())  # probability for image ind only

original_image = np.transpose((images[ind].cpu().detach().numpy() / 2) + 0.5, (1, 2, 0))

_ = viz.visualize_image_attr(None, original_image,
                             method="original_image", title="Original Image")

_ = viz.visualize_image_attr(grads, original_image, method="blended_heat_map", sign="absolute_value",
                             show_colorbar=True, title="Overlayed Gradient Magnitudes")

_ = viz.visualize_image_attr(attr_ig, original_image, method="blended_heat_map", sign="all",
                             show_colorbar=True, title="Overlayed Integrated Gradients")

_ = viz.visualize_image_attr(attr_ig_nt, original_image, method="blended_heat_map", sign="absolute_value",
                             outlier_perc=10, show_colorbar=True,
                             title="Overlayed Integrated Gradients \n with SmoothGrad Squared")

_ = viz.visualize_image_attr(attr_dl, original_image, method="blended_heat_map", sign="all", show_colorbar=True,
                             title="Overlayed DeepLift")
#compute attributions using Integrated Gradients and visualize them on the image.
integrated_gradients = ig
pred_label_idx = predicted[ind]
attributions_ig = integrated_gradients.attribute(input, target = pred_label_idx, n_steps=100)
transformed_img = input


bilalsal commented 4 years ago

Hi @mrdupadupa, thank you for bringing this up! We will fix the visualization module to automatically handle gray-scale images.

As a quick fix, you can manually reshape the visualization input to fit what the methods expect, as illustrated in the example below.

original_image = np.zeros([224, 224, 1])
attr = ig.attribute(torch.zeros([1, 1, 224, 224]), target=1) # returns a [1, 1, 224, 224] tensor
attr = attr.squeeze().unsqueeze(2).cpu().detach().numpy() # returns a [224, 224, 1] numpy array

_ = viz.visualize_image_attr(attr, original_image, method="blended_heat_map",sign="all",
                          show_colorbar=True, title="Overlayed Integrated Gradients")

_ = viz.visualize_image_attr_multiple(attr,
                                      np.squeeze(original_image),
                                      ["original_image", "heat_map"],
                                      ["all", "absolute_value"],
                                      show_colorbar=True)
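
As a variation on the quick fix above: `squeeze().unsqueeze(2)` only works for a single channel. Here is a sketch of a helper (the name `attr_to_hwc` is hypothetical, not part of Captum) that moves the channel axis last with `permute`, so the same call handles grayscale and RGB attributions. Note that `np.transpose` without an `axes` argument reverses all axes, giving W x H x C rather than H x W x C.

```python
import numpy as np
import torch


def attr_to_hwc(attr: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: [1, C, H, W] attribution tensor -> [H, W, C] NumPy array."""
    if attr.dim() != 4 or attr.shape[0] != 1:
        raise ValueError(f"expected shape [1, C, H, W], got {tuple(attr.shape)}")
    # Drop only the batch axis (keeps C even when C == 1), then move channels last.
    return attr.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()


# Stand-in for the output of ig.attribute(...) on one grayscale 32x32 image.
fake_attr = torch.randn(1, 1, 32, 32)
hwc = attr_to_hwc(fake_attr)
print(hwc.shape)  # (32, 32, 1)
```

If you use it, convert `original_image` to the same layout so the heat map and the image line up.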

The DeepLift error seems related to your model and not to the fact that your input is grayscale (I could reproduce it with in_channels = 3). @vivekmig could you have a look?

Hope this helps

vivekmig commented 4 years ago

Hi @mrdupadupa, the issue with DeepLift seems to be due to the repeated use of the same ReLU module in the torchvision ResNet BasicBlock. You can resolve this by copying the ResNet source from here and replacing the existing BasicBlock with the modified version shown below:

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

        # Added another relu here
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        # Modified to use relu2
        out = self.relu2(out)

        return out
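
Because `nn.ReLU` has no parameters or buffers, adding `relu2` should not change the block's `state_dict` keys, so checkpoints saved for the original block ought to load into the modified one. A toy check of that claim, using two small hypothetical blocks (`SharedReluBlock`, `SeparateReluBlock`) that mimic the pattern rather than torchvision's actual classes:

```python
import torch
import torch.nn as nn


class SharedReluBlock(nn.Module):
    """Toy residual block that reuses one ReLU, like torchvision's BasicBlock."""

    def __init__(self, channels: int = 4):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + x)  # second call of the same ReLU module


class SeparateReluBlock(SharedReluBlock):
    """Same block, with a dedicated ReLU after the residual addition."""

    def __init__(self, channels: int = 4):
        super().__init__(channels)
        self.relu2 = nn.ReLU()  # parameter-free, so no new state_dict entries

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu2(out + x)


torch.manual_seed(0)
old, new = SharedReluBlock().eval(), SeparateReluBlock().eval()
new.load_state_dict(old.state_dict())  # strict=True by default: keys must match exactly
x = torch.randn(2, 4, 8, 8)
print(torch.allclose(old(x), new(x)))  # True
```
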

mrdupadupa commented 4 years ago

Hi @bilalsal, @vivekmig, thank you for your answers. They help a lot.

psteinb commented 4 years ago

@vivekmig can you shed some light on why DeepLift has issues with a reused layer in the network?

vivekmig commented 4 years ago

Hi @psteinb, sure. This limitation is specific to the current implementation. Intermediate activations for both the baselines and the inputs need to be stored in the forward pass and then used to override the gradients in the backward pass. Currently, those activations are stored as (temporary) attributes on the corresponding modules themselves, so reusing an activation module overwrites the stored temporary attributes. (A similar issue also indirectly affects the layer and neuron attribution methods in Captum: they always attribute with respect to the last execution of a reused module.)
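
The overwriting can be reproduced with a plain forward hook. This is a simplified illustration of the pattern described (the hook and the `saved_output` attribute are hypothetical), not Captum's actual DeepLift code:

```python
import torch
import torch.nn as nn


def save_on_module(module, inputs, output):
    # Store the activation as an attribute on the module itself:
    # every call replaces whatever the previous call stored.
    module.saved_output = output.detach().clone()


shared_relu = nn.ReLU()
shared_relu.register_forward_hook(save_on_module)

x = torch.tensor([-1.0, 2.0])
first = shared_relu(x)          # first use stores relu(x) = [0, 2]
second = shared_relu(x * 3.0)   # second use overwrites it with [0, 6]

print(shared_relu.saved_output)  # tensor([0., 6.]), the first activation is gone
```
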

In most cases, we may be able to work around this by refactoring the implementation to store multiple activations per module, keyed on the execution count, i.e. separately storing the activations for the 1st, 2nd, 3rd, etc. time the module is executed. We haven't started on this refactor yet, but we will consider prioritizing it for a future release if this turns out to be a common issue.
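
A rough sketch of the execution-count idea, assuming activations are kept in an external dict instead of on the module (`activations` and `save_per_call` are hypothetical names; this is not how Captum currently stores them):

```python
from collections import defaultdict

import torch
import torch.nn as nn

# Hypothetical storage: module -> list of outputs, one entry per execution.
activations = defaultdict(list)


def save_per_call(module, inputs, output):
    # Index 0 holds the 1st execution, index 1 the 2nd, and so on.
    activations[module].append(output.detach().clone())


shared_relu = nn.ReLU()
handle = shared_relu.register_forward_hook(save_per_call)

x = torch.tensor([-1.0, 2.0])
shared_relu(x)
shared_relu(x * 3.0)
handle.remove()  # stop recording once the forward pass is done

print(len(activations[shared_relu]))  # 2
print(activations[shared_relu][0])    # tensor([0., 2.])
```

Matching each stored activation to the right call during the backward pass is the harder part, and this sketch does not attempt it.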

NarineK commented 4 years ago

@psteinb, in a broader context, PyTorch currently does not tell us where exactly in the computation graph an operator is executed; forward and backward hooks do not provide that information. This affects all hooks, including all layer and neuron attributions. In the latter cases, you'll receive the attribution with respect to the last execution of that hook only.

There are some plans to extend PyTorch to give access to that information. JIT gives information about the graph structure, but unfortunately hooks aren't currently supported there; some folks are working on it.

One way of solving this, as Vivek mentioned, is to count how many times a hook gets hit, but that can be hacky (not elegant) to implement. It might be easier to redefine an activation instead of reusing it, which is pretty straightforward to do in PyTorch. I don't think a refactoring is needed at this point; a more elegant way to solve the problem is to align it with the extended / improved functionality of PyTorch.
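
Before redefining activations, it helps to know which ones are reused. Here is a sketch of a hypothetical helper, `find_reused_modules`, that counts forward calls per leaf module using temporary forward hooks:

```python
from collections import Counter

import torch
import torch.nn as nn


def find_reused_modules(model: nn.Module, example_input: torch.Tensor) -> dict:
    """Return {qualified_name: call_count} for leaf modules run more than once."""
    counts = Counter()
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            # Default argument binds the current name for each hook.
            handles.append(module.register_forward_hook(
                lambda m, inp, out, name=name: counts.update([name])))
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for h in handles:
            h.remove()  # leave the model without extra hooks
    return {name: n for name, n in counts.items() if n > 1}


class TinyResidual(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.fc(x))
        return self.relu(out + x)  # same ReLU module used twice


print(find_reused_modules(TinyResidual(), torch.randn(1, 4)))  # {'relu': 2}
```

On torchvision's `resnet18()` this should flag the `relu` inside each of its eight BasicBlocks, since each is called twice per forward pass. `named_modules()` lists a shared instance only once, so a module registered under two names is counted under the first name.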

unnir commented 3 years ago

To have it all in one place:

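import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
}

class ResNet(nn.Module):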

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model

def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

        # Added another relu here
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        # Modified to use relu2
        out = self.relu2(out)

        return out

and then you can:

test_model = resnet18(pretrained=False, progress=True)

cdsarto commented 3 years ago

Hey everybody,

I have a similar issue to the one reported by @mrdupadupa. I tried to use DeepLift with ResNet-50. As recommended, I got the ResNet source code from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py and adjusted the BasicBlock, but I still get the same error:

A Module ReLU(inplace=True) was detected that does not contain some of the input/output attributes that are required for DeepLift computations. This can occur, for example, if your module is being used more than once in the network.Please, ensure that module is being used only once in the network.

I also found the comment from @giangnguyen2412 in #480, where he wrote:

Hi @vivekmig, thanks for your help. I want to add some details to your advice (for those who are new). First, copy the definitions from here and execute them as your source. You should import two functions, load_state_dict_from_url and _get_torch_home, from torch.hub. Then define your model as model=resnet50(pretrained=True).eval(). One more thing: you are required to modify not only the BasicBlock class but also the Bottleneck class to make it work!

But I'm not sure how to modify the Bottleneck class, or why the two mentioned functions need to be imported.

Thank you in advance :)

shikhar-srivastava commented 3 years ago

@vivekmig's comment solves this. @cdsarto: the code posted by @unnir is missing the Bottleneck class. Have a look at https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py for the complete source, and replace the BasicBlock class with the one from @vivekmig's comment here. It's a straightforward change.
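
ResNet-50 is built from Bottleneck blocks rather than BasicBlock, and torchvision's Bottleneck applies its single `self.relu` three times, so that class needs the same treatment. Below is a sketch of a Bottleneck with one ReLU per activation site, modelled on the torchvision source linked above and assuming the same constructor signature (`relu2` and `relu3` are added names). Since ReLU has no parameters, pretrained state dicts should still load.

```python
import torch
import torch.nn as nn


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        # One ReLU module per activation site instead of a single shared self.relu
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.relu3 = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return self.relu3(out)


block = Bottleneck(inplanes=256, planes=64).eval()
print(block(torch.randn(1, 256, 8, 8)).shape)  # torch.Size([1, 256, 8, 8])
```
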

Here is the full ResNet source, modified to have no reused ReLU blocks.

ericotjo001 commented 2 years ago

Hi, although there are already good answers, I'd like to add a repo that implements the change. Feel free to take a look: https://github.com/ericotjo001/pytorch_captum_fix

Basically, each Bottleneck or BasicBlock is replaced with AdjustedBottleneck and AdjustedBasicBlock, and then DeepLIFT will work.

youyinnn commented 1 year ago

@NarineK @vivekmig The solution no longer seems to work with torch 2.1. In particular, I am using a Mac with the MPS device.
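
One thing that may be worth ruling out (an assumption, not confirmed in this thread): PyTorch documents that modifying a module's inputs or outputs in place is not allowed when full backward hooks are attached. If the Captum version in use attaches such hooks, the `nn.ReLU(inplace=True)` modules in the blocks above could trigger errors. Here is a sketch of a hypothetical helper that switches every ReLU in a model to out-of-place mode:

```python
import torch.nn as nn


def disable_inplace_relu(model: nn.Module) -> int:
    """Set inplace=False on every nn.ReLU in `model`; return how many were changed."""
    changed = 0
    for module in model.modules():
        if isinstance(module, nn.ReLU) and module.inplace:
            module.inplace = False  # nn.ReLU.forward reads this flag on every call
            changed += 1
    return changed


toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU(inplace=True), nn.Linear(4, 2), nn.ReLU())
print(disable_inplace_relu(toy))  # 1
```

It would be called on the model before constructing `DeepLift(model)`. It does not address reused modules, so the separate-ReLU changes above are still needed.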