Closed by rachtibat 1 year ago
As far as I know, the zennit library requires a modified forward function for ResNet to work, because it needs a custom layer for summing the skip value. I found your project very useful for explaining my ResNet model, and all I needed to do to make the explanations work was to create a custom Canonizer, as below. I believe the workaround of manually selecting the names of the layers whose skip connection should be omitted while calculating the gradient should solve the problem, but correct me if I'm wrong. I'm sharing my workaround in case anyone finds it useful. The parameter for the class init is the list of layer names (the blocks that have a residual connection, not the conv layers) for which we don't want the gradient to flow through the skip connection.
import torch
from zennit.canonizers import SequentialMergeBatchNorm, AttributeCanonizer, CompositeCanonizer
from zennit.torchvision import ResNetBottleneckCanonizer, ResNetBasicBlockCanonizer
from torchvision.models.resnet import Bottleneck as ResNetBottleneck
from zennit.layer import Sum


class HackBottleneckCanonizer(ResNetBottleneckCanonizer):
    def __init__(self, overwrite_names):
        # Call AttributeCanonizer.__init__ directly so the custom attribute map is used.
        AttributeCanonizer.__init__(self, self.get_attribute_map(overwrite_names))

    @classmethod
    def get_attribute_map(cls, overwrite_names):
        def _attribute_map(name, module):
            if isinstance(module, ResNetBottleneck):
                if name in overwrite_names:
                    # Listed blocks get the forward that cuts the skip gradient.
                    attributes = {
                        'forward': cls.forward_no_grad.__get__(module),
                        'canonizer_sum': Sum(),
                    }
                    return attributes
                else:
                    # All other blocks keep zennit's regular canonized forward.
                    attributes = {
                        'forward': cls.forward.__get__(module),
                        'canonizer_sum': Sum(),
                    }
                    return attributes
            return None
        return _attribute_map

    @staticmethod
    def forward_no_grad(self, x):
        '''
        Modified Bottleneck forward for HackResNet.
        This forward doesn't propagate gradient through skip connections of given layers.
        '''
        identity = x.clone().detach()

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            # Downsample the detached tensor; using x here would let the
            # gradient flow back through the skip connection again.
            identity = self.downsample(identity)

        out = torch.stack([identity, out], dim=-1)
        out = self.canonizer_sum(out)
        out = self.relu(out)

        return out


class HackCanonizer(CompositeCanonizer):
    def __init__(self, grad_omit_skips):
        super().__init__((
            SequentialMergeBatchNorm(),
            HackBottleneckCanonizer(grad_omit_skips),
            ResNetBasicBlockCanonizer(),
        ))
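The core of the workaround is detaching the skip tensor before the sum. Here is a minimal sketch of that effect with plain PyTorch autograd. It assumes a toy `residual_block` function (a hypothetical name, not part of zennit or torchvision) in which a scalar multiplication stands in for the conv path:

```python
import torch


def residual_block(x, weight, detach_skip):
    """Toy residual block: 'conv path' plus skip connection (illustrative only)."""
    # Detaching removes the skip branch from the autograd graph.
    identity = x.detach() if detach_skip else x
    out = weight * x  # stand-in for conv/bn/relu
    return out + identity


x = torch.tensor([1.0, 2.0], requires_grad=True)
weight = torch.tensor(3.0)

# Gradient through both branches: d(3x + x)/dx = 4
grad_full, = torch.autograd.grad(residual_block(x, weight, False).sum(), x)

# Gradient through the conv path only: d(3x + const)/dx = 3
grad_no_skip, = torch.autograd.grad(residual_block(x, weight, True).sum(), x)

print(grad_full, grad_no_skip)
```

The forward output is identical in both cases; only the backward pass differs, which is why the canonizer can swap the forward without changing predictions.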
Hey,
I'm glad you were able to benefit from the project. Thanks for your effort; I like the creative idea, even though I haven't tested it yet. At the moment, I'm trying out another solution where the torch autograd graph is only partially differentiated, so that no additional overhead is generated for the users. But it will take a few more weeks, and it may never work.
There are two simpler methods to isolate the conditional heatmaps:
1. Add the parallel layers to the condition with an empty list:
condition = [{"my_layer": [2], "parallel_layer": [], "y": [target_class]}]
This way, the gradient is set to zero in the parallel layers.
2. Compute two attributions and subtract them:
condition = [{"my_layer": [2], "y": [target_class]}]
attr = attribution(model, composite, condition)
condition = [{"my_layer": [], "y": [target_class]}]
attr2 = attribution(model, composite, condition)
final_heatmap = attr.heatmap - attr2.heatmap
This way, we remove the influence of the parallel layers.
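To see why the subtraction isolates the conditioned path, here is a small sketch on a toy linear model. It uses gradient times input as a stand-in for the attribution, which is an assumption made for illustration; the names `w_main`, `w_skip` and `heatmap` are hypothetical and not part of the crp API.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
w_main = torch.tensor([0.5, -1.0, 2.0])  # path through the conditioned layer
w_skip = torch.tensor([1.0, 1.0, 1.0])   # parallel (shortcut) path


def heatmap(keep_main):
    """Gradient-times-input heatmap; keep_main=False masks the conditioned layer."""
    main = (w_main * x) * (1.0 if keep_main else 0.0)
    skip = w_skip * x
    grad, = torch.autograd.grad((main + skip).sum(), x)
    return (grad * x).detach()


with_concept = heatmap(keep_main=True)      # conditioned path + parallel path
without_concept = heatmap(keep_main=False)  # parallel path only
final_heatmap = with_concept - without_concept

print(final_heatmap)
```

In this linear toy the difference equals `w_main * x` exactly. For a real network with nonlinearities the cancellation is not guaranteed to be this clean.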
Best wishes
Issue fixed in #21. Please read the new attribution tutorial for more details. The previous solution mentioned below is discarded as it is not optimal and introduces unnecessary overhead.
Hey,
we tried to implement a partial backward pass with PyTorch, but then we would have to limit the flexibility of the tool, which in the end I was not willing to do. That's why we settled with the following solution:
Because of the AttributionGraph in crp.graph, I already implemented a method to trace models. Now, I implemented a new helper function that takes in condition sets and outputs a new condition set, where parallel layers are masked using the {"layer": []} syntax. It is still an experimental feature and might be buggy, but you can find it here: 8ec8e52 and use this code example:
from torchvision.models.resnet import resnet34
from crp.graph import trace_model_graph
from crp.helper import get_layer_names
import torch.nn as nn

model = resnet34().cuda()
model.eval()

layer_names = get_layer_names(model, [nn.Conv2d])
MG = trace_model_graph(model, (1, 3, 224, 224), layer_names)

conditions = [
    {"layer4.0.downsample.0": [23], "y": [112]},
    {"layer3.0.conv1": [55]},
    {"layer3.0.conv1": [12], "layer4.0.downsample.0": [77]},
]
new_conditions = MG.exclude_parallel_layers(conditions)
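For readers who want to see what "masking parallel layers with the {"layer": []} syntax" amounts to, here is a rough sketch in plain Python. It is not the crp implementation: `mask_parallel_layers` and the hand-written `parallel_map` are hypothetical, and the actual helper derives the parallel layers from the traced graph and may return a different set of layers.

```python
def mask_parallel_layers(conditions, parallel_map):
    """Return new condition sets in which layers parallel to a conditioned
    layer are added with an empty channel list (i.e. fully masked)."""
    new_conditions = []
    for condition in conditions:
        masked = dict(condition)  # copy, so the input is left untouched
        for layer in condition:
            for parallel in parallel_map.get(layer, ()):
                # Keep an explicit user condition if one already exists.
                masked.setdefault(parallel, [])
        new_conditions.append(masked)
    return new_conditions


# Assumed parallel structure for two resnet34 blocks (written by hand here).
parallel_map = {
    "layer4.0.downsample.0": ["layer4.0.conv1", "layer4.0.conv2"],
    "layer3.0.conv1": ["layer3.0.downsample.0"],
}

conditions = [
    {"layer4.0.downsample.0": [23], "y": [112]},
    {"layer3.0.conv1": [55]},
    {"layer3.0.conv1": [12], "layer4.0.downsample.0": [77]},
]

new_conditions = mask_parallel_layers(conditions, parallel_map)
for condition in new_conditions:
    print(condition)
```

Using `setdefault` means a layer that the user has already conditioned on is never overwritten with an empty list.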
When defining, for example, the condition set [{"features.40": [0, 2]}], channels 0 and 2 are passed through and all other channels are masked with zero. However, in models with several parallel connections (shortcuts), the parallel connections are not set to zero, so relevance still passes through them.
In the future, all concepts in parallel layers should also be set to zero, to get the sole contribution of the masked concept.
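The leak described above can be reproduced on a toy tensor. This is a sketch under assumptions: a plain gradient hook stands in for the channel masking and a simple addition stands in for the shortcut. It does not use crp, and all names are illustrative.

```python
import torch

# One sample, 4 channels, 2x2 spatial map.
x = torch.ones(1, 4, 2, 2, requires_grad=True)

layer_out = 2.0 * x  # stand-in for the conditioned layer

# Pass channels 0 and 2 through, mask all other channels with zero.
mask = torch.zeros(1, 4, 1, 1)
mask[:, [0, 2]] = 1.0
layer_out.register_hook(lambda grad: grad * mask)

out = layer_out + x  # shortcut running parallel to the layer
out.sum().backward()

# Per-channel gradient at one spatial position.
channel_grad = x.grad[0, :, 0, 0]
print(channel_grad)
```

Channels 1 and 3 are masked inside the layer but still receive gradient through the shortcut, which is the contribution the issue asks to set to zero.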