chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

Smooth MaxPool2D rule #181

Open rachtibat opened 1 year ago

rachtibat commented 1 year ago

Hey,

we'd like to add a new rule that smooths the MaxPool2D operation by replacing its backward pass with that of an AveragePool2D:

import torch
import torch.nn.functional as F

from zennit.core import BasicHook, Stabilizer


class SmoothMaxPool2dRule(BasicHook):
    '''Rule that replaces the MaxPool2d backward pass with that of an AvgPool2d.'''

    def __init__(self, epsilon=1e-6, zero_params=None):
        stabilizer_fn = Stabilizer.ensure(epsilon)
        super().__init__(
            gradient_mapper=(lambda out_grad, outputs: out_grad / stabilizer_fn(outputs[0])),
            reducer=(lambda inputs, gradients: inputs[0] * gradients[0]),
        )

    def backward(self, module, grad_input, grad_output):
        '''Backward hook to compute LRP based on the class attributes.'''
        # input stored by BasicHook's forward hook
        original_input = self.stored_tensors['input'][0].clone()
        inputs, outputs = [], []
        kernel_size = module.kernel_size
        stride = module.stride
        padding = module.padding

        # re-run the forward pass with avg_pool2d in place of max_pool2d,
        # using the pooling parameters of the original module
        input = original_input.requires_grad_()
        with torch.autograd.enable_grad():
            output = F.avg_pool2d(input, kernel_size, stride, padding, ceil_mode=False, count_include_pad=True, divisor_override=None)
        inputs.append(input)
        outputs.append(output)

        grad_outputs = self.gradient_mapper(grad_output[0], outputs)
        gradients = torch.autograd.grad(
            outputs,
            inputs,
            grad_outputs=grad_outputs,
            create_graph=grad_output[0].requires_grad
        )
        # relevance is the element-wise product of the input and the modified gradient
        relevance = self.reducer(inputs, gradients)
        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)

You can test the code with

import torch
import torch.nn as nn
import torch.nn.functional as F

from zennit.rules import Norm
from zennit.core import BasicHook

if __name__ == "__main__":

    input = torch.linspace(0, 35, 36).view(1, 1, 6, 6).requires_grad_()

    # reference: attribute the MaxPool2d layer with the Norm rule
    layer = nn.MaxPool2d(2, 2, 0)
    norm_rule = Norm()
    h = norm_rule.register(layer)

    output = layer(input)
    grad, = torch.autograd.grad(output, input, torch.ones_like(output))
    h.remove()

    print(input)
    print(output)
    print(grad)

    print("###")

    # the proposed smoothed rule on the same layer
    rule = SmoothMaxPool2dRule()
    h = rule.register(layer)

    output = layer(input)
    grad, = torch.autograd.grad(output, input, torch.ones_like(output))
    h.remove()

    print(input)
    print(output)
    print(grad)

Do you think that's fine? I can create a pull request if you want.

Best, Reduan

chr5tphr commented 1 year ago

Hey Reduan,

thanks for the issue as always!

I think having a way to use the AvgPool2d gradient for MaxPool2d layers is a must-have. I have some proof-of-concept code which I implemented back in the day to directly and explicitly compute the avg-pool gradient with MaxPool parameters using transposed convolutions.
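
For illustration, computing the avg-pool gradient explicitly with a transposed convolution could look roughly like this (a sketch of the idea, not the actual proof-of-concept code; it assumes scalar kernel_size, stride and padding, as in the MaxPool2d example above):

import torch
import torch.nn.functional as F


def avgpool_gradient(out_grad, module, channels):
    '''Hypothetical helper: explicit avg-pool gradient using the module's pooling parameters.'''
    # uniform kernel summing to 1 over each pooling window, applied per channel via groups
    kernel = torch.full(
        (channels, 1, module.kernel_size, module.kernel_size),
        1. / module.kernel_size ** 2,
        dtype=out_grad.dtype,
        device=out_grad.device,
    )
    # the transposed convolution distributes each output gradient value uniformly
    # over its pooling window, which is exactly the avg-pool gradient
    return F.conv_transpose2d(
        out_grad, kernel, stride=module.stride, padding=module.padding, groups=channels
    )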

While going over your code and seeing the BasicHook.backward structure copied, I had the idea that we could also add a layer of abstraction above ParamMod: a ModuleMod or FuncMod, i.e. a general modifier of the forward function. This way, one could build very flexible custom rules on top of BasicHook that are not limited to modifying the parameters of the module, which would be especially useful for parameter-less modules like MaxPool.
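
Just to make the concept concrete, such a forward-function modifier could, purely hypothetically, be as simple as a context manager that temporarily swaps the module's forward (this is not existing zennit API, only a sketch of the idea):

import contextlib


@contextlib.contextmanager
def func_mod(module, func):
    '''Hypothetical sketch: temporarily replace a module's forward function.'''
    original_forward = module.forward
    # the replacement receives the module and its inputs, so it can reuse the
    # module's parameters and hyperparameters
    module.forward = lambda *args, **kwargs: func(module, *args, **kwargs)
    try:
        yield module
    finally:
        module.forward = original_forward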

I have a different approach to attributing MaxPool in the pipeline, which could benefit from this abstraction. Do you maybe know of another use-case for an arbitrary function override? Or maybe @sebastian-lapuschkin?

If it is only for MaxPool, implementing an explicit rule based on Hook may be better, where we could instead use my existing proof-of-concept code. Although (and I guess that's why you based this off BasicHook rather than Hook) the stabilizer would then not automatically be part of the rule, which I think may not be necessary for pooling anyway.

As for the name, maybe it's better to call it something like AvgPoolRule, since for AvgPool layers this would also be correct, although one could just use the EpsilonRule there.

rachtibat commented 1 year ago

Hey,

thank you for your prompt and thoughtful response, as always. I like the idea of adding a FuncMod.

I asked Sebastian, and he told me that another use-case would be to change the 1x1 downsampling convolution with stride 2 in ResNets, which also creates such a checkerboard pattern (see attached image).
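
For reference, this is the kind of layer in question, as used in the downsample path of torchvision's ResNet blocks (the channel sizes here are only an example):

import torch.nn as nn

# a 1x1 convolution with stride 2 skips every second pixel in each direction,
# so relevance propagated through it shows a checkerboard pattern
downsample = nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False)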

The question is whether we should implement it with a FuncMod.

A spontaneous idea that would change the backward pass function instead:

  1. Compute relevance normally. As a result, only every second or fourth pixel gets relevance, while the others are zero.
  2. Take only those relevance pixels and write them into a smaller image of 1/4 the size.
  3. Average-upsample that image back to the original size (see the sketch after this list).
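
A rough sketch of these three steps for a stride-2 layer (a hypothetical helper, assuming even spatial dimensions and that relevance already holds the sparse result of step 1):

import torch.nn.functional as F


def smooth_relevance(relevance, stride=2):
    # step 2: keep only the pixels that actually received relevance
    small = relevance[..., ::stride, ::stride]
    # step 3: average-upsample back to the original size; dividing by the block
    # area conserves the total relevance
    return F.interpolate(small, scale_factor=stride, mode='nearest') / stride ** 2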

With a FuncMod we could do:

  1. Take the pixels that would be selected by the downsample layer and repeat them 4 times up to the original input size, overwriting the ignored pixels.
  2. Do a 2x2 downsample with a 4 times bigger kernel but 1/4 of the kernel values (see the sketch after this list).
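
A sketch of this forward override for a 1x1 convolution with stride 2 (hypothetical names, assuming even spatial dimensions; the forward output matches the original layer, but the gradient is spread over 2x2 blocks):

import torch.nn.functional as F


def smooth_downsample_forward(module, input):
    # step 1: repeat the pixels the stride-2 layer would select over their 2x2 blocks
    repeated = F.interpolate(input[..., ::2, ::2], scale_factor=2, mode='nearest')
    # step 2: a 2x2 kernel with the original 1x1 weights spread over 4 positions
    weight = module.weight.repeat(1, 1, 2, 2) / 4
    return F.conv2d(repeated, weight, module.bias, stride=2)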

Best