chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

LinearAttention Module #169

Open rachtibat opened 1 year ago

rachtibat commented 1 year ago

Hi Christopher,

Hope you're fine, and I'm really glad that the Zennit community is growing, congratulations! With a growing community, more nn.Modules need to be explained, and that's why I'm writing this issue. A student in our department is trying to explain a LinearAttention module (the implementation is below for reference).

It contains a series of torch.einsum and torch.transpose operations.

It also uses the rearrange function from the einops library, which offers a compact syntax for basic tensor operations such as transpose and reshape.
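For reference, the two rearrange patterns in the module below are plain reshapes: no values are mixed, only their layout changes. A small sketch in plain PyTorch, assuming arbitrary sizes (einops is not needed to run it):

```python
import torch

# Arbitrary sizes for illustration: batch, heads, channels per head, spatial x/y.
b, heads, c, x, y = 2, 4, 8, 5, 6
t = torch.randn(b, heads * c, x, y)

# "b (h c) x y -> b h c (x y)": split channels into (heads, c), flatten spatial dims.
q = t.reshape(b, heads, c, x * y)

# Inverse pattern "b h c (x y) -> b (h c) x y", as used at the end of forward().
back = q.reshape(b, heads * c, x, y)
```

Channel index `hh * c + cc` maps to head `hh` and channel `cc`, and spatial position `(xx, yy)` maps to flat index `xx * y + yy`.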

I think Zennit should be able to analyse a series of reshaping and transposing operations, but I am not completely sure. I'd be glad if you could give your opinion on analysing such a linear attention module. If you don't know, that's no problem either (: then it's the beginning of a new research topic.

(The softmax function is also a problem, but maybe Arras et al. have a solution for this, which the student could implement...)

Best, Reduan

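```python
import torch
from torch import nn
from einops import rearrange


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # 1x1 convolution producing queries, keys and values in one pass
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        # split channels into heads and flatten the spatial dimensions:
        # (b, heads * dim_head, h, w) -> (b, heads, dim_head, h * w)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        # queries are normalized over the feature dimension, keys over positions
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        # context[d, e] = sum_n k[d, n] * v[e, n], one (dim_head x dim_head) matrix per head
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        # out[e, n] = sum_d context[d, e] * q[d, n]
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        # merge the heads again and restore the spatial layout
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)
```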
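To read the two einsum strings, it may help to write them as batched matrix products. A sketch with random tensors, assuming arbitrary sizes and leaving out the softmax and scaling (they do not change the index structure):

```python
import torch

torch.manual_seed(0)
b, h, d, e, n = 2, 4, 8, 8, 16  # batch, heads, dim_head (d and e), positions n = x * y
q = torch.randn(b, h, d, n)
k = torch.randn(b, h, d, n)
v = torch.randn(b, h, e, n)

# context[d, e] = sum_n k[d, n] * v[e, n]  ->  k @ v^T, shape (b, h, d, e)
context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
context_mm = k @ v.transpose(-1, -2)

# out[e, n] = sum_d context[d, e] * q[d, n]  ->  context^T @ q, shape (b, h, e, n)
out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
out_mm = context.transpose(-1, -2) @ q
```

Since `context` has shape (dim_head, dim_head) independently of the number of positions n, the cost grows linearly in n rather than quadratically, which is where the "linear" in linear attention comes from.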
chr5tphr commented 1 year ago

Hey Reduan,

thank you for the issue! You can have a look at this work, where they introduce LRP for Transformers (i.e., also for attention heads). I have talked to @tschnake before about bringing Transformers to Zennit, which is still very much a work in progress.

About the implementation details:

The rearrange operation is just a re-indexing, so the correct propagation for it is simply the gradient, which Zennit already supports. The einsum is linear in each of its operands (with the other one held fixed), so it can be handled like a linear layer in LRP. The softmax is a little trickier: in the work above, they handle it by treating the gating terms as constants.
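A plain-PyTorch sketch of the first two points (this does not use Zennit's API; the shapes, the epsilon value and the positive inputs are arbitrary assumptions). First, the vector-Jacobian product of a reshape/permute, standing in for an einops rearrange, simply routes the output relevance back to the input positions. Second, with the gating operand k held fixed, the einsum is linear in v, so an epsilon-LRP step can be written with the usual gradient-times-input trick:

```python
import torch

torch.manual_seed(0)

# 1) Re-indexing: the gradient of reshape/permute moves values back unchanged.
x = torch.randn(2, 6, 5, requires_grad=True)
y = x.reshape(2, 2, 3, 5).permute(0, 1, 3, 2)  # stands in for an einops rearrange
relevance_y = torch.randn_like(y)              # hypothetical relevance at the output
(relevance_x,) = torch.autograd.grad(y, x, grad_outputs=relevance_y)

# 2) Einsum with k held fixed: z[d, e] = sum_n k[d, n] * v[e, n] is linear in v.
d, e, n = 4, 3, 5
k = torch.rand(d, n, dtype=torch.float64)  # gating operand, e.g. post-softmax keys
# positive values (an assumption) keep z away from zero in this toy example
v = (torch.rand(e, n, dtype=torch.float64) + 0.5).requires_grad_()
relevance_z = torch.randn(d, e, dtype=torch.float64)  # hypothetical upstream relevance

eps = 1e-6
z = torch.einsum("dn,en->de", k, v)
sign = torch.where(z >= 0, torch.ones_like(z), -torch.ones_like(z))
# Epsilon rule: R_v[e, n] = v[e, n] * sum_d k[d, n] * R_z[d, e] / (z[d, e] + eps * sign)
ratio = (relevance_z / (z + eps * sign)).detach()
(grad_v,) = torch.autograd.grad(z, v, grad_outputs=ratio)
relevance_v = v.detach() * grad_v
```

Without a bias term, the relevance is conserved up to the epsilon stabilizer: `relevance_v.sum()` matches `relevance_z.sum()`.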

In code, we may get away with requiring the use of torch.nn.Softmax and implementing a Constant rule, which sets the gradient to zero, although I need to think a little more about whether this would work as intended.
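One way to prototype the Constant idea outside Zennit is a custom autograd function that computes the softmax in the forward pass but returns a zero gradient in the backward pass. The name `ZeroGradSoftmax` is hypothetical (not part of Zennit or PyTorch), and the shapes are arbitrary. Once the gate is constant, `gate * v` is linear in v, so gradient times input on v reproduces the output exactly:

```python
import torch


class ZeroGradSoftmax(torch.autograd.Function):
    """Hypothetical stand-in for a 'Constant' rule: softmax forward, zero gradient backward."""

    @staticmethod
    def forward(ctx, scores, dim):
        return scores.softmax(dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient (hence no relevance) flows into the scores; None for `dim`.
        return torch.zeros_like(grad_output), None


torch.manual_seed(0)
scores = torch.randn(3, 5, requires_grad=True)
v = torch.randn(3, 5, requires_grad=True)

gate = ZeroGradSoftmax.apply(scores, -1)
out = gate * v                       # gated values, linear in v with a constant gate
out.sum().backward()
relevance_v = (v * v.grad).detach()  # gradient x input on v
```

Calling `.detach()` on the softmax output would have the same effect; the custom function only makes the constant softmax an explicit, swappable step. Requiring `torch.nn.Softmax` presumably matters because Zennit attaches rules to `nn.Module` instances, so a functional `q.softmax(...)` call offers nothing to attach a rule to.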

Otherwise, we could also implement a canonizer (or a meta-rule) for the most popular library implementing attention layers.