chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

Bug in AlphaBeta rule? #107

Closed nkoenen closed 2 years ago

nkoenen commented 2 years ago

Hi,

I've been working extensively with the LRP method lately and have also tried to implement the most commonly used rules myself. To check the correctness of my implementation, I compared some results with existing implementations (like yours ;) ). In doing so, I always get different relevances with the AlphaBeta rule (I already opened an issue in innvestigate with the same example). Maybe you can explain the following behavior or confirm that this is really a bug on your side:

Let's suppose we have only one layer with two inputs, one output, and no activation function. The layer has the weights W = (1, -1) and the bias b = -1. For the input x = (1, 1), the layer output is r_out = -1, and the formula of the AlphaBeta rule (Eq. (60) in Bach et al.) reduces to

$$
R_i = \left( \alpha \frac{z_i^+}{\sum_{i'} z_{i'}^+ + b^+} - \beta \frac{z_i^-}{\sum_{i'} z_{i'}^- + b^-} \right) r_\text{out}, \qquad z_i = x_i w_i ,
$$

where $(\cdot)^+$ and $(\cdot)^-$ denote the positive and negative parts.

This yields a relevance of (-1, 0) for the Alpha1_Beta0-rule and (-2, 0.5) for the Alpha2_Beta1-rule. But with your implementation I get (0.5, -0.5) both times, and for other choices of alpha and beta I also always get the same result. Here is my code snippet for the Alpha1_Beta0 rule:

import torch
import torch.nn as nn
from zennit.rules import AlphaBeta

input = torch.tensor([[1.,1.]], requires_grad = True)

model = nn.Sequential(
    nn.Linear(2, 1)
)
model.get_submodule("0").weight.data = torch.tensor([[1., -1.]])
model.get_submodule("0").bias.data = torch.tensor([-1.])

rule = AlphaBeta(alpha = 1, beta = 0)
rule.register(model)
output = model(input)

grad_out = torch.ones_like(output) * output

attr, = torch.autograd.grad(
    output, input, grad_outputs=grad_out
)

attr
# tensor([[ 0.5000, -0.5000]])
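
For reference, here is a minimal sketch of the hand computation above (assuming the bias is split into its positive and negative parts, just like the contributions z_i = x_i * w_i):

```python
import torch

# toy layer from above: W = (1, -1), b = -1, input x = (1, 1)
x = torch.tensor([1., 1.])
w = torch.tensor([1., -1.])
b = torch.tensor(-1.)

z = x * w            # contributions z_i = x_i * w_i -> (1., -1.)
r_out = z.sum() + b  # layer output, used as the output relevance -> -1.

z_pos, z_neg = z.clamp(min=0), z.clamp(max=0)
b_pos, b_neg = b.clamp(min=0), b.clamp(max=0)

for alpha, beta in [(1., 0.), (2., 1.)]:
    relevance = (
        alpha * z_pos / (z_pos.sum() + b_pos)
        - beta * z_neg / (z_neg.sum() + b_neg)
    ) * r_out
    print(alpha, beta, relevance)
# expected: (-1, 0) for alpha=1, beta=0 and (-2, 0.5) for alpha=2, beta=1
```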

I hope you can help me to clarify this behavior.

Best Niklas

chr5tphr commented 2 years ago

Hey Niklas,

the problem is that you are trying to register the rule on the Sequential model itself instead of on the Linear layer, which will not work. I'm actually surprised this does not crash.
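
Applied directly to your snippet, the minimal change would be to register the rule on the Linear submodule (a quick sketch, using the returned handle to remove the hook again afterwards):

```python
# register the rule on the Linear submodule, not on the Sequential container
rule = AlphaBeta(alpha=1, beta=0)
handle = rule.register(model.get_submodule("0"))

output = model(input)
attr, = torch.autograd.grad(
    output, input, grad_outputs=torch.ones_like(output) * output
)

# remove the hook again when done
handle.remove()
# attr should now be tensor([[-1., 0.]])
```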

Here are two ways to do it, one with only a single layer:

Single Linear Code

```python
import torch
from torch.nn import Linear

from zennit.rules import AlphaBeta

layer = Linear(2, 1)
layer.weight.data = torch.tensor([[1., -1.]])
layer.bias.data = torch.tensor([-1.])

input = torch.tensor([[1., 1.]], requires_grad=True)

print('Layer-only:')
for alpha, beta in [(1., 0.), (2., 1.)]:
    print(f'  alpha={alpha}, beta={beta}')
    # create hook and immediately register to layer
    handle = AlphaBeta(alpha=alpha, beta=beta).register(layer)

    output = layer(input)
    relevance, = torch.autograd.grad(output, input, grad_outputs=output)

    # remove the hook from the layer
    handle.remove()

    print(f'    {relevance}')
```

which prints:

Single Linear Output

```
Layer-only:
  alpha=1.0, beta=0.0
    tensor([[-1.0000, 0.0000]])
  alpha=2.0, beta=1.0
    tensor([[-2.0000, 0.5000]])
```

and one with the full Sequential, using a custom Composite and Attributor:

Sequential Code

```python
import torch
from torch.nn import Linear, Sequential

from zennit.rules import AlphaBeta
from zennit.attribution import Gradient
from zennit.composites import LayerMapComposite

layer = Linear(2, 1)
layer.weight.data = torch.tensor([[1., -1.]])
layer.bias.data = torch.tensor([-1.])

# create a simple Sequential model with a single layer
model = Sequential(layer)

input = torch.tensor([[1., 1.]], requires_grad=True)

print('Custom Composite:')
for alpha, beta in [(1., 0.), (2., 1.)]:
    print(f'  alpha={alpha}, beta={beta}')
    # create a custom composite, which maps Linear layers to AlphaBeta
    composite = LayerMapComposite([((Linear,), AlphaBeta(alpha=alpha, beta=beta))])

    # use the Gradient attributor on the model with our custom composite
    with Gradient(model, composite) as attributor:
        out, relevance = attributor(input)

    print(f'    {relevance}')
```

which prints:

Sequential Output

```
Custom Composite:
  alpha=1.0, beta=0.0
    tensor([[-1.0000, 0.0000]])
  alpha=2.0, beta=1.0
    tensor([[-2.0000, 0.5000]])
```

both of which match what you computed by hand.

Have a look at the documentation, where there is also an example with only a single layer.

nkoenen commented 2 years ago

Awesome! Thank you for the quick reply.