chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

Layer-wise LRP score #167

Open shivmgg opened 1 year ago

shivmgg commented 1 year ago

Hi,

Thanks a lot for the awesome work! May I know how I can extract layer-wise LRP attribution scores for ResNet-18?

chr5tphr commented 1 year ago

Hey Shivam,

Check out the documentation; there is a section on how to extract the attribution scores per layer.
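
For reference, here is a minimal sketch of that approach (my own illustration, not the exact snippet from the documentation; the identifiers store_output and intermediate are illustrative): plain PyTorch forward hooks retain the gradient of each layer's output, and inside the composite context that gradient is the layer-wise relevance. The ResNetCanonizer is used so BatchNorm and the residual connections are handled correctly for ResNet-18.

import torch
from torchvision.models import resnet18
from zennit.composites import EpsilonPlusFlat
from zennit.torchvision import ResNetCanonizer

model = resnet18().eval()
data = torch.randn(1, 3, 224, 224, requires_grad=True)

# store each layer's output tensor so its gradient (the relevance
# flowing out of that layer) can be read after the backward pass
intermediate = {}

def store_output(name):
    def hook(module, inputs, output):
        output.retain_grad()
        intermediate[name] = output
    return hook

handles = [
    module.register_forward_hook(store_output(name))
    for name, module in model.named_modules()
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
]

composite = EpsilonPlusFlat(canonizers=[ResNetCanonizer()])
with composite.context(model):
    out = model(data)
    # inside the context, the backward pass computes the LRP-modified gradients
    out.sum().backward()

for handle in handles:
    handle.remove()

# the retained gradients are the layer-wise relevance scores
for name, output in intermediate.items():
    print(name, output.grad.sum().item())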

shivmgg commented 1 year ago

Thanks Christopher for the documentation. I can now obtain activation gradients for each intermediate layer.

Is it possible to get relevance scores for weight matrices in each layer in a similar fashion? I want to quantify the relevance/importance of each weight value.

chr5tphr commented 1 year ago

Hey Shivam,

It turns out this is currently a little more involved to do in Zennit. In this paper, they directly modified zennit/core.py to also compute the modified gradients w.r.t. the parameters correctly.

Since yesterday I have looked a bit deeper into how this could be done in a more general way in Zennit, and decided to change the Hook implementation so that the attribution scores for multiple inputs and parameters can be computed. I have added #168, which is still a rough draft, although computing the attribution scores of parameters seems to be working (so far I have only verified correctness for the single-dense-layer case).

Feel free to play around with the implementation and report any problems you encounter. Here's an example of how to compute the attribution scores for all parameters (by default, parameters require a gradient, so calling .backward() will store the attribution scores in e.g. model.features[0].weight.grad; note that subsequent calls accumulate into .grad).

import torch
import zennit
from torchvision.models import vgg11

# a single dense layer as a minimal alternative model:
# model = torch.nn.Linear(4, 4).eval()
model = vgg11().eval()
data = torch.randn(1, 3, 224, 224, requires_grad=True)

# matching epsilon-rule composite for the single dense layer above:
# composite = zennit.composites.NameMapComposite([([''], zennit.rules.Epsilon())])
composite = zennit.composites.EpsilonPlusFlat()

with composite.context(model):
    out = model(data)
    # the input attribution can still be computed explicitly:
    # relevance, = torch.autograd.grad(out, data, torch.ones_like(out))
    out.sum().backward()

# attribution scores of the first convolution's weight
print(model.features[0].weight.grad)
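
Since subsequent backward calls accumulate into .grad, a small usage sketch (my own addition, assuming the same model, data shape, and composite as above) would clear the gradients before attributing a second input:

# clear the accumulated attribution scores before attributing another input
model.zero_grad()

other = torch.randn(1, 3, 224, 224, requires_grad=True)
with composite.context(model):
    model(other).sum().backward()

# .grad now holds the attribution scores for the second input only
print(model.features[0].weight.grad)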

Edit: You can try it out directly by installing from GitHub using pip:

$ pip install git+https://github.com/chr5tphr/zennit.git@hook-multi-input-param

MaxH1996 commented 1 year ago

Hi @chr5tphr, thank you for providing this feature. I have been working with this version of Zennit to get relevance scores w.r.t. the parameters. However, when using a ResNet20 architecture, I encounter the following error, which I cannot really explain:

result = hook.backward(module, grad_output, hook.stored_tensors['grad_output'], grad_sink=grad_sink)
TypeError: Pass.backward() got an unexpected keyword argument 'grad_sink'

originating from zennit/core.py in wrapper. This happens when I call out.sum().backward(). Do you by any chance know why this error appears?

chr5tphr commented 1 year ago

Hey @MaxH1996

I noticed I forgot to add the grad_sink parameter for some of the rules that are directly based on Hook. (I am thinking of doing something else and not using grad_sink, since it changes the interface of Hook by introducing a new mandatory argument.) I have fixed this, but unfortunately I noticed a few more involved problems with ResNet for this implementation, and I will not be able to work on it for at least a week.

MaxH1996 commented 1 year ago

Hey @chr5tphr,

thank you very much for your response, and for your quick fix of grad_sink.

Too bad about ResNet, but thanks for taking the time to work on it. Could you perhaps let me know once you have fixed this particular issue? That would help me a lot.