chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

ResNet: Unstable Attributions #194

Open chr5tphr opened 11 months ago

chr5tphr commented 11 months ago

With the introduction of #185, ResNet18 attributions result in negative attribution sums in the input layer, leading to bad attributions. Although #185 increased the stability of the attribution sums for ResNet, the previous instability seems to have inflated the positive parts of the attributions, which masked this problem before #185.

This seems to be related to leaking attributions (#193) combined with skip connections, which can cause negative attributions.
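To illustrate how a skip connection can produce negative attributions, here is a minimal sketch. It assumes that the addition in a residual block is attributed proportionally to each summand's share of the sum (a Norm-style rule); the branch values are made up for illustration:

```python
import torch

# Made-up outputs of the two summands of a residual addition at one position:
# the convolutional branch and the skip (identity) branch.
branch_out = torch.tensor([3.0, -1.0])  # [conv branch, skip branch]
total = branch_out.sum()  # forward output of the addition: 2.0

# Positive relevance arriving at the output of the addition.
relevance_out = torch.tensor(1.0)

# Proportional redistribution: each summand receives relevance in
# proportion to its (signed) share of the sum.
stabilizer = 1e-6
relevance_in = branch_out / (total + stabilizer) * relevance_out

print(relevance_in)        # conv branch gets about +1.5, skip branch about -0.5
print(relevance_in.sum())  # about 1.0: the total is conserved
```

A negative summand thus receives negative relevance even though the total stays conserved, and the positive summand is inflated to compensate.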

A quick fix for EpsilonGammaBox is to use a slightly higher gamma value.

MikiFER commented 8 months ago

Hi @chr5tphr, is there any news on this bug? I noticed that when setting `zero_params='bias'` for the EpsilonPlusFlat composite, the total attribution sums to ~1, but negative attribution sometimes occurs. Sometimes this negative part is small, but sometimes it amounts to ~-1 or even more in magnitude, which is significant; in that case the positive part amounts to almost ~2 to compensate. Here is the code I used to reproduce this issue (multiple runs may be required):

```python
import torch
from torchvision.models import resnet34
from zennit.torchvision import ResNetCanonizer
from zennit.composites import EpsilonPlusFlat
import matplotlib.pyplot as plt

model = resnet34(weights=None)
# evaluation mode: BatchNorm uses its running statistics, which the
# ResNetCanonizer (via MergeBatchNorm) merges into the preceding convolutions
model.eval()

# create a composite, specifying the canonizers
composite = EpsilonPlusFlat(canonizers=[ResNetCanonizer()], zero_params='bias')
target = torch.eye(1000)[[437]]
input_data = torch.rand(1, 3, 224, 224)
input_data.requires_grad = True

with composite.context(model) as modified_model:
    output = modified_model(input_data)
    attribution, = torch.autograd.grad(output, input_data, target)

# sum over the color channels to get one relevance value per pixel
relevance = attribution.cpu().sum(1).squeeze(0)

if torch.any(relevance < 0):
    print('negative:', relevance[relevance < 0].sum())
    print('positive:', relevance[relevance > 0].sum())
    print('total:   ', relevance.sum())
plt.imshow(relevance.numpy())
plt.show()
```

If you don't have time to deal with this issue, could you point me in the right direction so I can try to fix it myself? Fixing it is critical for me to continue my research.

chr5tphr commented 7 months ago

Hey @MikiFER

I am not sure that what you are experiencing is related to this issue, as the relevance still sums to 1. The bug may be a bit elusive, but in general it is okay for some relevance to be negative (for EpsilonPlusFlat, anyway). You should be fine on 0.5.1: the bug referred to in this issue only exists on master and was introduced by #185. Did you use master?
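As a minimal worked example of why negative relevance is not a bug in itself, consider the epsilon rule (which, as far as I know, EpsilonPlusFlat applies to dense layers) on a single linear neuron with made-up weights and the bias omitted, as with `zero_params='bias'`. Inputs with negative contributions receive negative relevance, while the total is still conserved:

```python
import torch

# Made-up single output neuron z = sum_i x_i * w_i (no bias) with mixed-sign contributions.
x = torch.tensor([1.0, 2.0, 0.5])
w = torch.tensor([0.8, -0.3, 1.0])
contributions = x * w      # [0.8, -0.6, 0.5]
z = contributions.sum()    # 0.7

# Epsilon rule: R_i = x_i * w_i / (z + eps * sign(z)) * R_out
eps = 1e-6
relevance_out = torch.tensor(1.0)
relevance_in = contributions / (z + eps * torch.sign(z)) * relevance_out

print(relevance_in)        # roughly [1.14, -0.86, 0.71]: the second input is negative
print(relevance_in.sum())  # roughly 1.0: conserved despite the negative part
```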

Here's a little snippet to check the model relevance in detail. I have added an extra rule to switch off the residual (skip) connection to circumvent the LRP instabilities discussed in #148. The canonizer simply has no effect on vgg11, which does not need one (although, since ResNetCanonizer includes MergeBatchNorm, vgg11_bn would also work):

```python
from itertools import islice

import torch
from torchvision.models import resnet34, vgg11
from zennit.torchvision import ResNetCanonizer
from zennit.composites import EpsilonPlusFlat
from zennit.core import Hook
from zennit.layer import Sum


class SumSingle(Hook):
    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def backward(self, module, grad_input, grad_output):
        elems = [torch.zeros_like(grad_output[0])] * (grad_input[0].shape[-1])
        elems[self.dim] = grad_output[0]
        return (torch.stack(elems, dim=-1),)


def store_hook(module, input, output):
    module.output = output
    output.retain_grad()


models = {'resnet34': resnet34, 'vgg11': vgg11}

for model_name, model_fn in models.items():
    torch.manual_seed(0xdeadbeef + 3)
    model = model_fn(weights=None)
    model.eval()

    # create a composite, specifying the canonizers
    composite = EpsilonPlusFlat(
        layer_map=[(Sum, SumSingle(1))],
        canonizers=[ResNetCanonizer()],
        zero_params='bias'
    )
    target = torch.eye(1000)[[437]]
    input_data = torch.rand(1, 3, 224, 224)
    input_data.requires_grad = True

    with composite.context(model) as modified_model:
        handles = [
            module.register_forward_hook(store_hook)
            for module in model.modules()
            if not list(islice(module.children(), 1))
        ]
        output = modified_model(input_data)
        attribution, = torch.autograd.grad(output, input_data, target)
    relevance = attribution.cpu().sum(1).squeeze(0)

    labels = [('input', 'input', attribution)] + [
        (name, type(module).__name__, module.output.grad)
        for name, module in model.named_modules()
        if hasattr(module, 'output')
    ]
    maxname, maxclsname = [max(len(obj[i]) for obj in labels) for i in (0, 1)]

    print(f'\nModel: {model_name}')
    for name, clsname, grad in labels:
        print(
            f'  {name:<{maxname}s} ({clsname:<{maxclsname}s}): '
            f'min: {grad.min():+.7f}, '
            f'max: {grad.max():+.7f}, '
            f'sum: {grad.sum():+.7f}'
        )
```
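The effect of `SumSingle(1)` can be checked in isolation. The sketch below copies the list-and-stack logic of `SumSingle.backward` from the snippet into a plain function (`sum_single_backward` is a hypothetical name, and no zennit hooks are involved): the whole output gradient is routed to stacked input 1 and the other summand receives zeros. Judging by the output below, where the `downsample` layers receive zero relevance, index 0 corresponds to the skip connection.

```python
import torch


def sum_single_backward(grad_output, n_inputs, dim=1):
    '''Route the whole output gradient to stacked input `dim`, zeroing the others.'''
    elems = [torch.zeros_like(grad_output)] * n_inputs
    elems[dim] = grad_output
    return torch.stack(elems, dim=-1)


# The summands of the canonized Sum are stacked along the last dimension.
grad_output = torch.tensor([[0.2, 0.8]])
grad_input = sum_single_backward(grad_output, n_inputs=2, dim=1)
print(grad_input.shape)    # torch.Size([1, 2, 2])
print(grad_input[..., 0])  # all zeros: summand 0 receives no relevance
print(grad_input[..., 1])  # equal to grad_output: summand 1 receives everything
```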

And this is the output I get on 0.5.1:

```text
Model: resnet34
  input                 (input            ): min: +0.0000003, max: +0.0000542, sum: +0.9999950
  conv1                 (Conv2d           ): min: +0.0000000, max: +0.0001266, sum: +0.9999949
  bn1                   (BatchNorm2d      ): min: +0.0000000, max: +0.0001266, sum: +0.9999949
  relu                  (ReLU             ): min: +0.0000000, max: +0.0001266, sum: +0.9999949
  maxpool               (MaxPool2d        ): min: +0.0000000, max: +0.0000341, sum: +0.9999950
  layer1.0.conv1        (Conv2d           ): min: +0.0000000, max: +0.0000796, sum: +0.9999954
  layer1.0.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0000796, sum: +0.9999954
  layer1.0.relu         (ReLU             ): min: +0.0000000, max: +0.0000468, sum: +0.9999962
  layer1.0.conv2        (Conv2d           ): min: +0.0000000, max: +0.0000468, sum: +0.9999962
  layer1.0.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0000468, sum: +0.9999962
  layer1.1.conv1        (Conv2d           ): min: +0.0000000, max: +0.0000834, sum: +0.9999965
  layer1.1.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0000834, sum: +0.9999965
  layer1.1.relu         (ReLU             ): min: +0.0000000, max: +0.0000536, sum: +0.9999971
  layer1.1.conv2        (Conv2d           ): min: +0.0000000, max: +0.0000536, sum: +0.9999971
  layer1.1.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0000536, sum: +0.9999971
  layer1.2.conv1        (Conv2d           ): min: +0.0000000, max: +0.0000902, sum: +0.9999974
  layer1.2.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0000902, sum: +0.9999974
  layer1.2.relu         (ReLU             ): min: +0.0000000, max: +0.0001007, sum: +0.9999977
  layer1.2.conv2        (Conv2d           ): min: +0.0000000, max: +0.0001007, sum: +0.9999977
  layer1.2.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0001007, sum: +0.9999977
  layer2.0.conv1        (Conv2d           ): min: +0.0000000, max: +0.0001177, sum: +0.9999980
  layer2.0.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0001177, sum: +0.9999980
  layer2.0.relu         (ReLU             ): min: +0.0000000, max: +0.0001770, sum: +0.9999983
  layer2.0.conv2        (Conv2d           ): min: +0.0000000, max: +0.0001770, sum: +0.9999983
  layer2.0.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0001770, sum: +0.9999983
  layer2.0.downsample.0 (Conv2d           ): min: +0.0000000, max: +0.0000000, sum: +0.0000000
  layer2.0.downsample.1 (BatchNorm2d      ): min: +0.0000000, max: +0.0000000, sum: +0.0000000
  layer2.1.conv1        (Conv2d           ): min: +0.0000000, max: +0.0001312, sum: +0.9999984
  layer2.1.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0001312, sum: +0.9999984
  layer2.1.relu         (ReLU             ): min: +0.0000000, max: +0.0001304, sum: +0.9999986
  layer2.1.conv2        (Conv2d           ): min: +0.0000000, max: +0.0001304, sum: +0.9999986
  layer2.1.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0001304, sum: +0.9999986
  layer2.2.conv1        (Conv2d           ): min: +0.0000000, max: +0.0001452, sum: +0.9999987
  layer2.2.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0001452, sum: +0.9999987
  layer2.2.relu         (ReLU             ): min: +0.0000000, max: +0.0001383, sum: +0.9999988
  layer2.2.conv2        (Conv2d           ): min: +0.0000000, max: +0.0001383, sum: +0.9999988
  layer2.2.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0001383, sum: +0.9999988
  layer2.3.conv1        (Conv2d           ): min: +0.0000000, max: +0.0001657, sum: +0.9999989
  layer2.3.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0001657, sum: +0.9999989
  layer2.3.relu         (ReLU             ): min: +0.0000000, max: +0.0001711, sum: +0.9999990
  layer2.3.conv2        (Conv2d           ): min: +0.0000000, max: +0.0001711, sum: +0.9999990
  layer2.3.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0001711, sum: +0.9999990
  layer3.0.conv1        (Conv2d           ): min: +0.0000000, max: +0.0003639, sum: +0.9999990
  layer3.0.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0003639, sum: +0.9999990
  layer3.0.relu         (ReLU             ): min: +0.0000000, max: +0.0002886, sum: +0.9999992
  layer3.0.conv2        (Conv2d           ): min: +0.0000000, max: +0.0002886, sum: +0.9999992
  layer3.0.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0002886, sum: +0.9999992
  layer3.0.downsample.0 (Conv2d           ): min: +0.0000000, max: +0.0000000, sum: +0.0000000
  layer3.0.downsample.1 (BatchNorm2d      ): min: +0.0000000, max: +0.0000000, sum: +0.0000000
  layer3.1.conv1        (Conv2d           ): min: +0.0000000, max: +0.0003569, sum: +0.9999993
  layer3.1.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0003569, sum: +0.9999993
  layer3.1.relu         (ReLU             ): min: +0.0000000, max: +0.0003025, sum: +0.9999993
  layer3.1.conv2        (Conv2d           ): min: +0.0000000, max: +0.0003025, sum: +0.9999993
  layer3.1.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0003025, sum: +0.9999993
  layer3.2.conv1        (Conv2d           ): min: +0.0000000, max: +0.0002958, sum: +0.9999993
  layer3.2.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0002958, sum: +0.9999993
  layer3.2.relu         (ReLU             ): min: +0.0000000, max: +0.0002067, sum: +0.9999994
  layer3.2.conv2        (Conv2d           ): min: +0.0000000, max: +0.0002067, sum: +0.9999994
  layer3.2.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0002067, sum: +0.9999994
  layer3.3.conv1        (Conv2d           ): min: +0.0000000, max: +0.0003350, sum: +0.9999993
  layer3.3.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0003350, sum: +0.9999993
  layer3.3.relu         (ReLU             ): min: +0.0000000, max: +0.0002595, sum: +0.9999993
  layer3.3.conv2        (Conv2d           ): min: +0.0000000, max: +0.0002595, sum: +0.9999993
  layer3.3.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0002595, sum: +0.9999993
  layer3.4.conv1        (Conv2d           ): min: +0.0000000, max: +0.0003304, sum: +0.9999993
  layer3.4.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0003304, sum: +0.9999993
  layer3.4.relu         (ReLU             ): min: +0.0000000, max: +0.0002727, sum: +0.9999993
  layer3.4.conv2        (Conv2d           ): min: +0.0000000, max: +0.0002727, sum: +0.9999993
  layer3.4.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0002727, sum: +0.9999993
  layer3.5.conv1        (Conv2d           ): min: +0.0000000, max: +0.0004377, sum: +0.9999994
  layer3.5.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0004377, sum: +0.9999994
  layer3.5.relu         (ReLU             ): min: +0.0000000, max: +0.0004310, sum: +0.9999993
  layer3.5.conv2        (Conv2d           ): min: +0.0000000, max: +0.0004310, sum: +0.9999993
  layer3.5.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0004310, sum: +0.9999993
  layer4.0.conv1        (Conv2d           ): min: +0.0000000, max: +0.0007261, sum: +0.9999992
  layer4.0.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0007261, sum: +0.9999992
  layer4.0.relu         (ReLU             ): min: +0.0000000, max: +0.0007178, sum: +0.9999994
  layer4.0.conv2        (Conv2d           ): min: +0.0000000, max: +0.0007178, sum: +0.9999994
  layer4.0.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0007178, sum: +0.9999994
  layer4.0.downsample.0 (Conv2d           ): min: +0.0000000, max: +0.0000000, sum: +0.0000000
  layer4.0.downsample.1 (BatchNorm2d      ): min: +0.0000000, max: +0.0000000, sum: +0.0000000
  layer4.1.conv1        (Conv2d           ): min: +0.0000000, max: +0.0008504, sum: +0.9999994
  layer4.1.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0008504, sum: +0.9999994
  layer4.1.relu         (ReLU             ): min: +0.0000000, max: +0.0006153, sum: +0.9999993
  layer4.1.conv2        (Conv2d           ): min: +0.0000000, max: +0.0006153, sum: +0.9999993
  layer4.1.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0006153, sum: +0.9999993
  layer4.2.conv1        (Conv2d           ): min: -0.0005434, max: +0.0020294, sum: +0.9999993
  layer4.2.bn1          (BatchNorm2d      ): min: -0.0005434, max: +0.0020294, sum: +0.9999993
  layer4.2.relu         (ReLU             ): min: -0.0092405, max: +0.0115777, sum: +0.9999994
  layer4.2.conv2        (Conv2d           ): min: -0.0092405, max: +0.0115777, sum: +0.9999994
  layer4.2.bn2          (BatchNorm2d      ): min: -0.0092405, max: +0.0115777, sum: +0.9999994
  avgpool               (AdaptiveAvgPool2d): min: -0.2474068, max: +0.3167669, sum: +0.9999998
  fc                    (Linear           ): min: +0.0000000, max: +1.0000000, sum: +1.0000000

Model: vgg11
  input        (input            ): min: -0.0002063, max: +0.0002277, sum: +0.9999397
  features.0   (Conv2d           ): min: -0.0000620, max: +0.0000807, sum: +0.9999397
  features.1   (ReLU             ): min: -0.0000620, max: +0.0000807, sum: +0.9999397
  features.2   (MaxPool2d        ): min: -0.0000620, max: +0.0000807, sum: +0.9999397
  features.3   (Conv2d           ): min: -0.0001189, max: +0.0001306, sum: +0.9999405
  features.4   (ReLU             ): min: -0.0001189, max: +0.0001306, sum: +0.9999405
  features.5   (MaxPool2d        ): min: -0.0001189, max: +0.0001306, sum: +0.9999405
  features.6   (Conv2d           ): min: -0.0001095, max: +0.0001081, sum: +0.9999412
  features.7   (ReLU             ): min: -0.0001095, max: +0.0001081, sum: +0.9999412
  features.8   (Conv2d           ): min: -0.0002049, max: +0.0002393, sum: +0.9999417
  features.9   (ReLU             ): min: -0.0002049, max: +0.0002393, sum: +0.9999417
  features.10  (MaxPool2d        ): min: -0.0002049, max: +0.0002393, sum: +0.9999418
  features.11  (Conv2d           ): min: -0.0001591, max: +0.0001764, sum: +0.9999424
  features.12  (ReLU             ): min: -0.0001591, max: +0.0001764, sum: +0.9999424
  features.13  (Conv2d           ): min: -0.0006126, max: +0.0006749, sum: +0.9999429
  features.14  (ReLU             ): min: -0.0006126, max: +0.0006749, sum: +0.9999429
  features.15  (MaxPool2d        ): min: -0.0006126, max: +0.0006749, sum: +0.9999428
  features.16  (Conv2d           ): min: -0.0031778, max: +0.0024524, sum: +0.9999434
  features.17  (ReLU             ): min: -0.0031778, max: +0.0024524, sum: +0.9999434
  features.18  (Conv2d           ): min: -0.0809088, max: +0.0687538, sum: +0.9999435
  features.19  (ReLU             ): min: -0.0809088, max: +0.0687538, sum: +0.9999435
  features.20  (MaxPool2d        ): min: -0.0809088, max: +0.0687538, sum: +0.9999434
  avgpool      (AdaptiveAvgPool2d): min: -0.0809091, max: +0.0687540, sum: +0.9999459
  classifier.0 (Linear           ): min: -0.1945583, max: +0.1917277, sum: +0.9999526
  classifier.1 (ReLU             ): min: -0.1945583, max: +0.1917277, sum: +0.9999526
  classifier.2 (Dropout          ): min: -0.1945583, max: +0.1917277, sum: +0.9999526
  classifier.3 (Linear           ): min: -0.1157313, max: +0.1835920, sum: +0.9999599
  classifier.4 (ReLU             ): min: -0.1157313, max: +0.1835920, sum: +0.9999599
  classifier.5 (Dropout          ): min: -0.1157313, max: +0.1835920, sum: +0.9999599
  classifier.6 (Linear           ): min: +0.0000000, max: +1.0000000, sum: +1.0000000
```

If you are on master: do you have a more concrete example of the bug? It would help me pin the issue down. For example, if the bug is present, the snippet above should produce relevance sums that are not equal to one. You can also try commenting out the `layer_map=` line in the composite to enable relevance through the residual connection.

Maybe skipping the residual branch already fixes your issue?

MikiFER commented 7 months ago

Hi @chr5tphr, thank you for your response. I have since realized that I had misunderstood the EpsilonPlusFlat composite, and have decided that a pure Alpha1Beta0 rule is actually what I need, because I want a relevance map that shows only positive influence in the input.