rachtibat opened this issue 1 year ago
Hey Reduan,
thank you for the issue! You can have a look at this work, where they introduce LRP for Transformers (i.e. also attention heads). I have talked to @tschnake before about bringing transformers to Zennit, which is still as WIP as it gets.
The `rearrange` operation is just a re-indexing, so the correct relevance propagation for it is simply the gradient, which means it is already supported by Zennit.
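For a pure re-indexing operation like `rearrange`, the Jacobian is a permutation matrix, so backpropagating relevance as if it were a gradient just moves each value back to its original position. A minimal sketch with plain autograd (not Zennit-specific; the shapes and the einops pattern are made up for illustration):

```python
import torch
from einops import rearrange

x = torch.randn(2, 3, 4, requires_grad=True)
y = rearrange(x, 'b c t -> b t c')   # pure re-indexing, no computation

relevance_out = torch.randn_like(y)  # stand-in for the incoming relevance
relevance_in, = torch.autograd.grad(y, x, grad_outputs=relevance_out)

# each relevance value arrives unchanged at the input position it came from
assert torch.equal(relevance_in, rearrange(relevance_out, 'b t c -> b c t'))
```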
The `einsum` is a linear operation, so it can be handled like a linear layer in LRP.
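For example, the epsilon rule can be applied to an einsum with a constant weight operand via the usual gradient trick. A minimal sketch (not Zennit's implementation; the shapes, names, and the einsum equation are made up for illustration):

```python
import torch

eps = 1e-6
weight = torch.randn(8, 16)                # fixed "weight" operand
x = torch.randn(4, 8, requires_grad=True)  # activations

z = torch.einsum('oi,bo->bi', weight, x)   # linear in x, like nn.Linear without bias
relevance_out = torch.randn_like(z)        # stand-in for the incoming relevance

s = relevance_out / (z + eps)              # stabilized relevance-to-activation ratio
c, = torch.autograd.grad(z, x, grad_outputs=s)
relevance_in = x * c                       # epsilon-rule relevance of the input
```

Note that in attention both einsum operands are activations, so one of them (the gating) has to be treated as a constant, which is exactly the view taken for the softmax below.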
The `softmax` is a little tricky. In the work above, they handle this by viewing the gating terms as constants.
In code, we may get away with requiring the use of `torch.nn.Softmax` and implementing a `Constant` rule, which would set the gradient to zero, although I need to think a little more about whether this would work as intended.
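In plain PyTorch, "viewing the gating terms as constants" amounts to detaching the softmax output from the graph, so that no gradient (and hence no gradient-based relevance) flows through it. A minimal sketch (not a finished Zennit rule; the function and shapes are made up for illustration):

```python
import torch

def attention_with_constant_gating(q, k, v):
    scores = torch.einsum('bqd,bkd->bqk', q, k) / q.shape[-1] ** 0.5
    gates = torch.softmax(scores, dim=-1).detach()   # gating treated as constant
    return torch.einsum('bqk,bkd->bqd', gates, v)    # relevance flows only through v
```

A `Constant` rule in Zennit would aim for the same effect by zeroing the gradient at the softmax module instead of detaching inside the model.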
Otherwise, we could also implement a canonizer (or a meta-rule) for the most popular library implementing attention layers.
Hi Christopher,
hope you're fine, and I'm really glad that the Zennit community is growing, congratulations! With a growing community, more nn.Modules desire to be explained, and that's why I'm writing this issue. A student in our department is trying to explain a LinearAttention module. (The implementation is below for reference.)
It contains a series of `torch.einsum` and `torch.transpose` operations. It also uses the `rearrange` function of the einops library, an alternative syntax for basic torch operations like transpose, reshape, etc.
I think Zennit should be able to analyse a series of reshaping and transposing operations, but I am not completely sure. I'd be glad if you could give your opinion on analysing such a linear attention module. If you don't know, that's also no problem (: then it's the beginning of a new research topic.
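For reference, the einops syntax is just a different spelling of the usual reshape/transpose calls. A small illustration (the shapes are made up):

```python
import torch
from einops import rearrange

x = torch.randn(2, 16, 8, 8)                    # (batch, channels, height, width)

y_torch = x.reshape(2, 16, 64).transpose(1, 2)  # (batch, height*width, channels)
y_einops = rearrange(x, 'b c h w -> b (h w) c')

assert torch.equal(y_torch, y_einops)
```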
(And the softmax function is also a problem, but maybe Arras et al. have a solution to this which the student could implement...)
Best, Reduan