rachtibat / zennit-crp

An eXplainable AI toolkit with Concept Relevance Propagation and Relevance Maximization
https://www.nature.com/articles/s42256-023-00711-8

Conditional Heatmaps ignore parallel connections in, e.g., ResNets #16

Closed rachtibat closed 1 year ago

rachtibat commented 1 year ago

When defining, for example, the condition set [{"features.40": [0, 2]}], channels 0 and 2 are passed through and all other channels of that layer are masked with zero. However, in models with parallel connections (shortcuts), the parallel branches are not set to zero, and relevance still flows through them.

In the future, all concepts in parallel layers should also be set to zero, so that we obtain the sole contribution of the selected concept.
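
For context, here is a minimal sketch of how such a condition set is applied. It assumes a torchvision vgg16_bn (where "features.40" is the last convolution) and the CondAttribution interface from crp.attribution; the class index 0 is only a placeholder:

import torch
from torchvision.models import vgg16_bn

from crp.attribution import CondAttribution
from zennit.composites import EpsilonPlusFlat
from zennit.canonizers import SequentialMergeBatchNorm

model = vgg16_bn().eval()
composite = EpsilonPlusFlat(canonizers=[SequentialMergeBatchNorm()])
attribution = CondAttribution(model)

# keep only channels 0 and 2 of "features.40" for output class 0;
# all other channels of that layer are masked with zero during the backward pass
conditions = [{"features.40": [0, 2], "y": [0]}]

sample = torch.randn(1, 3, 224, 224, requires_grad=True)
attr = attribution(sample, conditions, composite)
heatmap = attr.heatmap  # conditional heatmap of the selected channels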

Krystian-Krol commented 1 year ago

As far as I know, the zennit library requires a modified forward function for ResNet to work, because it needs a custom layer for summing the skip value. I found your project very useful for explaining my ResNet model, and all I needed to do to make the explanations work was to create a custom Canonizer as below. I believe the workaround of manually selecting the names of the layers whose skip connections we want to exclude from the gradient computation should solve the problem, but correct me if I'm wrong. I'm sharing my workaround below in case anyone finds it useful. The parameter for the class init is the list of layers (blocks that have a residual connection, not conv layers) for which we don't want the gradient to flow through the skip connection.

import torch 

from zennit.canonizers import SequentialMergeBatchNorm, AttributeCanonizer, CompositeCanonizer
from zennit.torchvision import ResNetBottleneckCanonizer, ResNetBasicBlockCanonizer
from torchvision.models.resnet import Bottleneck as ResNetBottleneck
from zennit.layer import Sum

class HackBottleneckCanonizer(ResNetBottleneckCanonizer):
    def __init__(self, overwrite_names):
        AttributeCanonizer.__init__(self, self.get_attribute_map(overwrite_names))

    @classmethod
    def get_attribute_map(cls, overwrite_names):

        def _attribute_map(name, module):
            if isinstance(module, ResNetBottleneck):
                if name in overwrite_names:
                    attributes = {
                        'forward': cls.forward_no_grad.__get__(module),
                        'canonizer_sum': Sum(),
                    }
                    return attributes
                else:
                    attributes = {
                        'forward': cls.forward.__get__(module),
                        'canonizer_sum': Sum(),
                    }
                    return attributes
            # leave modules that are not Bottlenecks untouched
            return None
        return _attribute_map

    @staticmethod
    def forward_no_grad(self, x):
        '''
        Modified Bottleneck forward for HackResNet.
        This forward doesn't propagate gradient through skip connections of given layers.
        '''
        # detaching the identity stops gradient (and hence relevance) flow
        # through the skip connection of this block
        identity = x.clone().detach()

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = torch.stack([identity, out], dim=-1)
        out = self.canonizer_sum(out)

        out = self.relu(out)

        return out

class HackCanonizer(CompositeCanonizer):
    def __init__(self, grad_omit_skips):
        super().__init__((
            SequentialMergeBatchNorm(),
            HackBottleneckCanonizer(grad_omit_skips),
            ResNetBasicBlockCanonizer(),
        ))
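
A hedged usage sketch for this workaround (the block names, the EpsilonPlusFlat composite and the condition set are illustrative only; the attribution call assumes the CondAttribution interface from crp.attribution):

import torch
from torchvision.models import resnet50
from zennit.composites import EpsilonPlusFlat
from crp.attribution import CondAttribution

model = resnet50().eval()

# cut the skip connections of these Bottleneck blocks only (names are illustrative)
canonizer = HackCanonizer(grad_omit_skips=["layer4.0", "layer4.1", "layer4.2"])
composite = EpsilonPlusFlat(canonizers=[canonizer])

attribution = CondAttribution(model)
conditions = [{"layer4.2.conv3": [0, 2], "y": [0]}]

sample = torch.randn(1, 3, 224, 224, requires_grad=True)
attr = attribution(sample, conditions, composite)
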
rachtibat commented 1 year ago

Hey,

I'm glad you were able to benefit from the project. Thanks for your effort, and I like the creative idea, even though I haven't tested it yet. At the moment, I'm trying out another solution in which the torch autograd graph is only partially differentiated, so that no additional overhead is generated for users. But it will take a few more weeks, and it might never work out.

In the meantime, there are two simple workarounds to isolate the conditional heatmaps:

  1. If we know the names of the parallel layers, we can write the following condition set to analyze channel 2 in layer "my_layer", where "parallel_layer" is a layer parallel to "my_layer":

condition = [{"my_layer": [2], "parallel_layer": [], "y": [class_id]}]

This way, the gradient is set to zero in the parallel layers.

  2. If we don't know the names, we can perform two backward passes and subtract the results (a complete sketch follows after this list):
condition = [{"my_layer": [2], "y": [class_id]}]
attr = attribution(model, composite, condition)

condition = [{"my_layer": [], "y": [class_id]}]
attr2 = attribution(model, composite, condition)

final_heatmap = attr.heatmap - attr2.heatmap

This way, we remove the influence of the parallel layers.
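
For completeness, here is a runnable sketch of the second method, assuming the CondAttribution interface from crp.attribution; the ResNet layer name and the class index are placeholders:

import torch
from torchvision.models.resnet import resnet34
from crp.attribution import CondAttribution
from zennit.composites import EpsilonPlusFlat
from zennit.torchvision import ResNetCanonizer

model = resnet34().eval()
composite = EpsilonPlusFlat(canonizers=[ResNetCanonizer()])
attribution = CondAttribution(model)

sample = torch.randn(1, 3, 224, 224, requires_grad=True)
class_id = 0  # placeholder class index

# first pass: channel 2 of the target layer plus everything flowing through parallel paths
attr = attribution(sample, [{"layer3.0.conv1": [2], "y": [class_id]}], composite)

# second pass: the target layer fully masked, so only the parallel paths contribute
attr2 = attribution(sample, [{"layer3.0.conv1": [], "y": [class_id]}], composite)

# the difference isolates the contribution of channel 2 alone
final_heatmap = attr.heatmap - attr2.heatmap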

Best wishes

rachtibat commented 1 year ago

Hey,

we tried to implement a partial backward pass with PyTorch, but that would have required limiting the flexibility of the tool, which in the end I was not willing to do. That's why we settled on the following solution:

For the AttributionGraph in crp.graph, I had already implemented a method to trace models. Building on that, I now implemented a new helper function that takes condition sets and outputs new condition sets in which parallel layers are masked using the {"layer": []} syntax. It is still an experimental feature and might be buggy, but you can find it in commit 8ec8e529f4a2b5150d90cb40e6aa10e6c5da5284 and use this code example:

from torchvision.models.resnet import resnet34
from crp.graph import trace_model_graph
from crp.helper import get_layer_names
import torch.nn as nn

model = resnet34().cuda()
model.eval()

layer_names = get_layer_names(model, [nn.Conv2d])
MG = trace_model_graph(model, (1, 3, 224, 224), layer_names)

conditions = [
    {"layer4.0.downsample.0": [23], "y": [112]},
    {"layer3.0.conv1": [55]},
    {"layer3.0.conv1": [12], "layer4.0.downsample.0": [77]},
]

new_conditions = MG.exclude_parallel_layers(conditions)
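
Presumably, the expanded condition sets are then used with the attribution as before. Continuing from the snippet above, a hedged sketch (the composite choice is illustrative):

import torch
from crp.attribution import CondAttribution
from zennit.composites import EpsilonPlusFlat
from zennit.torchvision import ResNetCanonizer

composite = EpsilonPlusFlat(canonizers=[ResNetCanonizer()])
attribution = CondAttribution(model)

sample = torch.randn(1, 3, 224, 224, requires_grad=True, device="cuda")
attr = attribution(sample, new_conditions, composite)
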
rachtibat commented 1 year ago

Issue fixed in #21. Please read the new attribution tutorial for more details. The solution described in my previous comment is discarded, as it is not optimal and introduces unnecessary overhead.
