chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

Module with Multiple Inputs #176

Open · rachtibat opened 1 year ago

rachtibat commented 1 year ago

Hey Chris,

hope you're well. I noticed an implementation detail where I am unsure whether it was programmed on purpose, and why.

At this line, you take only the first input in the backward pass, while previously saving all inputs at this line.

I see that you defined the summation layer using a concat operation at this line, so I assume restricting the inputs is on purpose.

So do you think it will be possible in the future to attribute a summation layer defined in the following way? And why did you restrict layers to having only one input?

import torch

class Sum(torch.nn.Module):

    def forward(self, input_1, input_2):
        # element-wise sum of two separate input tensors
        return input_1 + input_2
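
For illustration (standard PyTorch behavior, not Zennit internals): backward hooks receive one gradient per positional forward argument, so a two-input module like this yields a grad_input tuple with two entries, and a hook that only reads the first entry silently drops the second.

# uses the Sum module defined above
def print_grad_shapes(module, grad_input, grad_output):
    # PyTorch passes one gradient per positional forward argument,
    # so grad_input has two entries for this module
    print([grad.shape for grad in grad_input])

module = Sum()
module.register_full_backward_hook(print_grad_shapes)
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
module(a, b).sum().backward()  # the hook sees two gradients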

Thanks a lot!

chr5tphr commented 1 year ago

Hey Reduan,

thanks for the issue as always.

Currently, I am indeed restricting Zennit to only attribute single inputs. I started out with single inputs, since most layers that need to be attributed have only a single input anyway, and for most cases there exists an equivalent module structure with only a single input (e.g. concatenated inputs). See for example here that the backward hook is also only attached to the first input.
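
As a sketch of such an equivalent single-input structure (illustrative names, not Zennit's actual API): the summands are stacked along a new dimension beforehand, so the module receives a single tensor and reduces over that dimension.

import torch

class StackedSum(torch.nn.Module):
    """Single-input equivalent of a two-input sum: expects the
    summands stacked along `dim` and reduces over it."""

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        return input.sum(dim=self.dim)

x1, x2 = torch.randn(4, 8), torch.randn(4, 8)
out = StackedSum()(torch.stack([x1, x2], dim=-1))  # equals x1 + x2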

I planned from the beginning to also support multiple inputs (and, along the way, also parameters), and am working on getting this done in #168, although I have not gotten to work on it recently. You can see here that I define multiple gradient_sinks, which can be attributed.

The work remaining in the PR focuses more on the parameters, as it turned out somewhat tricky to reliably hook to Parameters: hooking to the tensor will always trigger when its gradient is computed, i.e. also at the wrong time, while creating a function to hook to is a little tricky, since the parameter is not passed to a function but obtained as an attribute (which is probably where I will intercept).
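
To illustrate the first point (a plain PyTorch sketch, not code from the PR): a hook registered directly on a parameter tensor fires on every gradient computation that involves it, with no way to tell which module call or attribution pass triggered it.

import torch

linear = torch.nn.Linear(2, 2)
handle = linear.weight.register_hook(lambda grad: print('weight grad computed'))

# fires for this backward pass through the module ...
linear(torch.randn(3, 2)).sum().backward()
# ... and equally for any unrelated graph using the same weight
(linear.weight * 2).sum().backward()
handle.remove()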

For the future, your proposed Sum module is intended to work, even with the BasicHook. If you are curious, you can see in the PR that the attribution will be computed differently for each specified sink, e.g. here, although the way of addressing the sinks may change.

rachtibat commented 1 year ago

Hey,

awesome, many thanks for the detailed explanation. I am very excited about the future development and will have a look at the PR to see if I can adapt it for my purposes. I also noticed - and this is a great strength of Zennit - that you can define PyTorch functions with a custom backward method (https://pytorch.org/docs/stable/notes/extending.html), which can be overridden to compute a complex attribution method that might not yet be supported by Zennit Hooks, while still integrating perfectly into the Zennit workflow.
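
As a sketch of that pattern (the guided-backprop-style rule here is only an example, not anything Zennit prescribes): a torch.autograd.Function keeps the usual forward pass but replaces the backward pass with a custom attribution rule.

import torch

class GuidedReLU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        # custom rule: only propagate positive gradients at
        # positions where the input was positive
        input, = ctx.saved_tensors
        return grad_output * (input > 0) * (grad_output > 0)

x = torch.randn(5, requires_grad=True)
GuidedReLU.apply(x).sum().backward()  # x.grad follows the custom rule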

Best