chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

Obtaining gradient of LRP output w.r.t. network parameters #183

Closed MikiFER closed 1 year ago

MikiFER commented 1 year ago

Hi, first of all thank you for all the hard work that was put into developing this framework and then making it available to everyone. I was wondering if there is a way to obtain gradient of the explanation obtained using LRP with respect to the network parameters in order to optimize it. I stumbled across your overview paper and would like to use the framework in my own EGL research.

chr5tphr commented 1 year ago

Hey @MikiFER, sorry for the delayed response. It is possible, although currently a little more involved (see below for a PoC). Also see this discussion. I started working on supporting this in #168 at the end of last year, but unfortunately did not yet have the time to finalize the PR. If you are using VGG or something similar, it may work, but for ResNet, the current implementation has a few issues.

Otherwise, you can try to use this proof of concept I quickly put together:

Code

```python
from itertools import islice

import torch
from torchvision.models import AlexNet

from zennit.core import BasicHook, ParamMod
from zennit.rules import Epsilon, Gamma, ZBox
from zennit.composites import EpsilonGammaBox
from zennit.attribution import Gradient
from zennit.types import Convolution


class ParamBasicHook(BasicHook):
    '''Hook to also get the relevance wrt. Parameters'''
    def backward(self, module, grad_input, grad_output):
        '''Backward hook to compute LRP based on the class attributes.'''
        original_input = self.stored_tensors['input'][0].clone()
        inputs = []
        outputs = []
        params = {key: [] for key, _ in module.named_parameters()}

        for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
            input = in_mod(original_input).requires_grad_()
            with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
                # remember the gradient state
                grad_states = [param.requires_grad for param in modified.parameters()]
                # require the gradients to compute the relevance
                for param in modified.parameters():
                    param.requires_grad_()

                output = modified.forward(input)
                output = out_mod(output)

                # keep track of the params for later gradient computation
                for key, param in modified.named_parameters():
                    params[key].append(param)

                # reset the gradient state
                for param, grad_state in zip(modified.parameters(), grad_states):
                    param.requires_grad = grad_state

            inputs.append(input)
            outputs.append(output)

        grad_outputs = self.gradient_mapper(grad_output[0], outputs)
        if isinstance(grad_outputs, torch.Tensor):
            grad_outputs = [grad_outputs]

        gradients = torch.autograd.grad(
            outputs * (1 + len(params)),
            inputs + sum(params.values(), []),
            grad_outputs=grad_outputs * (1 + len(params)),
            create_graph=grad_output[0].requires_grad
        )
        grad_groups = [list(islice(elem, len(inputs))) for elem in [iter(gradients)] * (1 + len(params))]

        relevance = self.reducer(inputs, grad_groups[0])

        # set the .grad of the original parameter
        for (key, param), gradient in zip(params.items(), grad_groups[1:]):
            getattr(module, key).grad = self.reducer(param, gradient)

        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)

    @classmethod
    def inject(cls, hook_type):
        '''Create a subclass of hook_type and this class, injecting this class before BasicHook
        in order to give this class' backward a higher priority. May also be done manually with
        e.g. ``class EpsilonParam(Epsilon, ParamBasicHook, BasicHook): pass``.
        '''
        return type(f'{hook_type.__name__}Param', (hook_type, cls, BasicHook), {})


ZBoxParam = ParamBasicHook.inject(ZBox)
GammaParam = ParamBasicHook.inject(Gamma)
EpsilonParam = ParamBasicHook.inject(Epsilon)


def main():
    torch.manual_seed(0xdeadbeef)

    net = AlexNet().eval()

    layer_map = [
        (torch.nn.Linear, EpsilonParam(epsilon=1e-6)),
        (Convolution, GammaParam(gamma=0.25)),
    ]
    first_map = [
        (Convolution, ZBoxParam(low=-3., high=3.)),
    ]
    composite = EpsilonGammaBox(low=-3., high=3., layer_map=layer_map, first_map=first_map)

    data = torch.randn((1, 3, 224, 224))

    # not needed for LRP, only using this to compute the gradients
    for param in net.parameters():
        param.requires_grad = True
    net(data).sum().backward()
    weight_grad = net.features[0].weight.grad[:]
    for param in net.parameters():
        del param.grad
        param.requires_grad = False

    # compute LRP
    with Gradient(net, composite=composite) as attributor:
        out2, relevance = attributor(data)

    weight_relevance = net.features[0].weight.grad[:]

    # demonstrate that the gradient was modified
    print((weight_grad - weight_relevance).abs().sum())


if __name__ == '__main__':
    main()
```

Here, the trick is to create a new BasicHook which also computes the relevances of the parameters, to create new subclasses of the existing hooks by injecting our new class via multiple inheritance, and then to use those hooks wherever you would like to compute the relevance w.r.t. parameters. In the example, I have used Cooperative Layer Maps, but MixedComposite or any custom composite also works. Let me know if you have more questions, or in case there is something wrong with the PoC.

@MaxH1996 may also be interested in this code.

MikiFER commented 1 year ago

Hi @chr5tphr, thank you for the response. I think we may have misunderstood each other. I do not need the relevance of the parameters; I need the gradient of the relevance map (in the input space) w.r.t. the network parameters. For example, let's say I want to compare the relevance map obtained using LRP with some ground-truth relevance map, and I want to optimize the network parameters to minimize some loss between them. In your documentation I found this page, but I am not sure whether that is what I need.

chr5tphr commented 1 year ago

Hey @MikiFER,

sorry for the confusion! This should indeed work without problems, as described in the documentation. You should be able to use Tensor.backward() together with any Optimizer, just as you would when training with the unmodified gradients. Just make sure you call .backward() outside the Attributor/Composite context, or within the Attributor.inactive() context as shown in the documentation.
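To illustrate, here is a rough sketch of such a training loop; it is not taken from the documentation, and `model`, `dataset` and `ExplanationLoss` are placeholders for your own setup:

```python
import torch
from zennit.composites import EpsilonPlusFlat

# placeholders: `model` is your classifier, `dataset` yields (input, target, gt_mask) triples,
# and `ExplanationLoss` compares the attribution with the ground-truth mask
composite = EpsilonPlusFlat()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for input, target, gt_mask in dataset:
    input = input.detach().requires_grad_(True)
    with composite.context(model) as modified_model:
        output = modified_model(input)
        # with the composite active, the gradient w.r.t. the input is the LRP attribution;
        # create_graph=True keeps the graph so the attribution itself can be differentiated
        attribution, = torch.autograd.grad(
            output, input, grad_outputs=torch.eye(output.shape[1])[target], create_graph=True
        )
    # the explanation loss is backpropagated outside the composite context
    loss = ExplanationLoss(attribution, gt_mask)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```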

Let me know in case you have issues with this, so we can try to figure it out together.

MikiFER commented 1 year ago

Thank you so much. Will try it and will get back to you if there are issues :)

MikiFER commented 1 year ago

Hi @chr5tphr, by diving a little deeper into the code I arrived at a question I cannot answer. I would like to use the ResNet architecture in my experiments. To obtain valid LRP explanations, a canonized version of the network must first be created; during canonization, each batch norm is merged into the linear layer it is attached to. If I then obtain an explanation and compute its gradient w.r.t. the parameters, would those gradients still be accurate once the network is de-canonized back to its original state? Here is pseudo-code of what I am trying to do.

composite = EpsilonPlusFlat(canonizers=canonizer)
for input, gt_value, gt_mask in dataset:
    model_out = model(input)
    classification_loss = Loss(model_out, gt_value)
    with composite.context(model) as canonized_model:
        # take the gradient of the output w.r.t. the input as the explanation
        explanation, = torch.autograd.grad(model_out, input, gt_value, create_graph=True)
    explanation_loss = ExplanationLoss(explanation, gt_mask)
    combined_loss = classification_loss + explanation_loss
    combined_loss.backward()
    ...

I am afraid that calling combined_loss.backward() will produce gradients for the canonized network, while I want to optimize the parameters of the "normal" network, i.e. the batch-norm and corresponding linear layer parameters would never be optimized.

Is there something that I am not understanding correctly?

chr5tphr commented 1 year ago

Hey @MikiFER,

theoretically, this should not be a problem, as the canonized parameters are computed from the original parameters in such a way that the gradient stays the same. However, when I tested this, I found that it does not behave as expected, since the current implementation detaches the gradient. I looked into this and added #185, where the gradients now seem to be computed correctly.
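For intuition, the merge works roughly like the following sketch (illustrative only, not the actual SequentialMergeBatchNorm implementation): the canonized weight and bias are differentiable functions of the original parameters, so gradients taken through the canonized forward pass still reach the originals.

```python
import torch

# fold an eval-mode batch norm into the preceding linear layer (illustrative sketch)
linear = torch.nn.Linear(16, 16)
bn = torch.nn.BatchNorm1d(16).eval()

scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
merged_weight = linear.weight * scale[:, None]
merged_bias = (linear.bias - bn.running_mean) * scale + bn.bias

data = torch.randn(4, 16)
out_merged = torch.nn.functional.linear(data, merged_weight, merged_bias)
out_original = bn(linear(data))

# the merged parameters are built from the original ones without detaching,
# so backpropagating through the merged layer still populates the original grads
print((out_merged - out_original).abs().max())
out_merged.sum().backward()
print(linear.weight.grad is not None, bn.weight.grad is not None)
```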

You can check it out by directly installing with pip:

pip install git+https://github.com/chr5tphr/zennit.git@canonizer-merge-batchnorm-gradfix

Let me know whether it works for you.

Here's a proof of concept check:

```python
import torch

from zennit.core import Composite
from zennit.canonizers import SequentialMergeBatchNorm


def main():
    torch.manual_seed(0xdeadbeef)

    net = torch.nn.Sequential(
        torch.nn.Linear(32, 32),
        torch.nn.BatchNorm1d(32),
    )
    weight = net[0].weight
    net.eval()
    net[1].running_mean += 1.
    net[1].running_var *= 3.

    canonizers = [
        SequentialMergeBatchNorm()
    ]
    composite = Composite(canonizers=canonizers)

    data = torch.randn((1, 32))

    weight.requires_grad = True

    out_base = net(data).sum()
    grad_base, = torch.autograd.grad(out_base, weight)

    with composite.context(net) as modified:
        out_canon = modified(data).sum()
        grad_canon, = torch.autograd.grad(out_canon, weight)

    print((out_base - out_canon).abs().sum())
    print((grad_base - grad_canon).abs().sum())


if __name__ == '__main__':
    main()
```

MikiFER commented 1 year ago

Hi @chr5tphr thank you so much, I will try it out and get back to you if there are any more issues.

chr5tphr commented 1 year ago

Assuming there were no more issues, I am closing this for now after merging #185. Feel free to reopen once something pops up.

MikiFER commented 11 months ago

Hi @chr5tphr, I have a question regarding the explanation obtained when using the ResNetCanonizer in combination with the EpsilonPlusFlat composite. I noticed that the sum of the attributions over the input image is not 1, even though, when LRP starts with a relevance of 1 at the output layer, the relevance in every layer should sum to 1. Here is a piece of code I used to reproduce this behavior.

import torch
from torchvision.models import resnet18

from zennit.composites import EpsilonPlusFlat
from zennit.torchvision import ResNetCanonizer

model = resnet18(weights=None)
canonizer = ResNetCanonizer()

composite = EpsilonPlusFlat(canonizers=[canonizer])

target = torch.eye(1000)[[437]]
input_data = torch.rand(1, 3, 224, 224)
input_data.requires_grad = True
output = model(input_data)
with composite.context(model) as modified_model:
    attribution, = torch.autograd.grad(output, input_data, target)

print(attribution.shape, attribution.sum())

Am I not understanding something correctly or is this an error?

chr5tphr commented 11 months ago

Hey MikiFER,

usually, the attributions will not sum to one unless you make sure that no attribution is lost to the bias terms, which you can do by passing zero_params='bias', e.g., in your case

composite = EpsilonPlusFlat(canonizers=[canonizer], zero_params='bias')

While investigating your issue, I noticed that, although #185 increased the overall attribution stability within ResNet, it led to a negative attribution sum in the input (which can happen when some attribution is lost to biases in combination with skip connections), for which I have opened #194. While there is a quick fix at least for EpsilonGammaBox, there does not seem to be a workaround for EpsilonPlusFlat until I fix the problem.

MikiFER commented 11 months ago

Hi @chr5tphr, thanks for the response. I find it a little weird that almost 90% of the attribution is lost to the stability parameters (when inference is done before the composite context), and I feel like there is more to it. I have also noticed that a different attribution is obtained when the model inference is done inside the composite context than when it is done outside (before) it. Is that the intended behavior? I believe the attribution should be the same, because the canonized model and the original model should be equivalent.

Also, one unrelated question: have you tried pairing your library with pytorch-lightning? I get some weird results when trying to use half-precision (fp16) training, where the model inference returns NaN inside the composite context.

chr5tphr commented 11 months ago

Hey @MikiFER

it's not the stability parameters, but the bias term, which silently receives attribution. For example, in the Epsilon-Rule, we have

$R_i = \sum_j \frac{x_i w_{ji}}{\sum_{i'} w_{ji'} x_{i'} + b_j + \varepsilon} R_j$

where the denominator includes not only $\varepsilon$, but also the bias $b_j$. Since the biases are constant inputs to the network (or one could imagine a constant 1 in the input with another column for the bias in the weights), they will also receive relevance, which will result in a reduced relevance for the inputs.

This lost relevance can be avoided by removing the bias term from the denominator, which is what zero_params='bias' is for.
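To make this concrete, here is a small sketch comparing the input relevance sum for a single linear layer with and without zeroing the bias (assuming the Epsilon rule accepts zero_params in the same way as the composites):

```python
import torch
from zennit.attribution import Gradient
from zennit.composites import LayerMapComposite
from zennit.rules import Epsilon

torch.manual_seed(0)
layer = torch.nn.Linear(8, 4)
data = torch.randn(1, 8)
# relevance of 1 at a single output
target = torch.eye(4)[[0]]

for zero_params in (None, 'bias'):
    composite = LayerMapComposite(layer_map=[(torch.nn.Linear, Epsilon(zero_params=zero_params))])
    with Gradient(layer, composite=composite) as attributor:
        output, relevance = attributor(data, target)
    # with zero_params='bias' the input relevance sums to (almost) 1;
    # without it, part of the relevance is attributed to the constant bias
    print(zero_params, relevance.sum().item())
```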

There is, however, as you also pointed out, currently something wrong with the changes introduced by #185, and my investigation so far points to the ResNet canonizer.

To keep a better overview, feel free to create a new issue when the topics are not directly related.