fhvilshoj / TorchLRP

A PyTorch 1.6 implementation of Layer-Wise Relevance Propagation (LRP).

Apply rho and incr for weights in forward phase #1

Closed linhlpv closed 3 years ago

linhlpv commented 3 years ago

Hi. Thank you for your great work! I have a question: do you apply the rho and incr functions to the weights of a layer during the forward computation? Thank you so much!

fhvilshoj commented 3 years ago

No, they are meant for the backward pass. In particular, they are used to enable the different LRP rules such as the epsilon rule or gamma rule. See an example in the code here. For a more detailed description, see Section 1.2 here.
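For illustration, rho and incr for the gamma and epsilon rules could look roughly like this (a sketch with made-up constants, not this repository's exact definitions):

import torch

# Illustrative only: example rho / incr choices for two common LRP rules.
gamma, epsilon = 0.25, 1e-6   # made-up values for the sketch

def rho(w):
    # gamma rule: put extra emphasis on positive weights
    return w + gamma * w.clamp(min=0)

def incr(z):
    # epsilon rule: add a small stabiliser with the same sign as z
    return z + epsilon * torch.sign(z)

Both are applied only inside the relevance propagation step, never to the layer that produces the network's actual output.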

linhlpv commented 3 years ago

Yeah, thank you for your suggestion. I have read the heatmapping tutorial and the paper "Layer-Wise Relevance Propagation: An Overview". Both of them use pseudocode/code like in this image. In that code, they apply the incr and rho functions to the layer in the forward pass (and so in the backward pass too). I'm quite confused.

linhlpv commented 3 years ago

Sorry about the missing image. I have weak wifi right now 😭.
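Roughly, the four-step pattern from the tutorial looks like this (my own sketch of it, since the image did not come through; rho, incr and newlayer are the tutorial's helper names, filled in here with made-up choices):

import copy
import torch
import torch.nn as nn

rho = lambda w: w + 0.25 * w.clamp(min=0)   # e.g. gamma rule
incr = lambda z: z + 1e-6                   # e.g. epsilon stabiliser

def newlayer(layer, rho):
    # Copy of the layer with rho applied to its parameters
    layer = copy.deepcopy(layer)
    layer.weight = nn.Parameter(rho(layer.weight.data))
    if layer.bias is not None:
        layer.bias = nn.Parameter(rho(layer.bias.data))
    return layer

layer = nn.Linear(8, 4, bias=False)         # no bias, so relevance is conserved up to the epsilon term
A = torch.rand(1, 8, requires_grad=True)    # activations entering the layer
R_out = torch.rand(1, 4)                    # relevance coming from the layer above

z = incr(newlayer(layer, rho).forward(A))   # step 1: modified forward pass
s = (R_out / z).data                        # step 2: element-wise division
(z * s).sum().backward()                    # step 3: backward pass
R_in = (A * A.grad).data                    # step 4: multiply by the activations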

fhvilshoj commented 3 years ago

Well, yeah, they use it to recompute a "forward pass" when they do the backpropagation of the relevances. But it does not affect the output of the network's forward pass; only the backward pass. The reason is relevance conservation, which means that R.sum() should be the same for the input and the output of the relevance propagation function (the one in the image you sent).

The same thing applies to this library: i) the forward pass of the network is not altered (here in the code), and ii) the backward pass computes a new "forward" including rho and incr in order to conserve relevance (here in the code).
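If you run the sketch from your comment above, you can see the conservation property directly:

# Continuing the sketch above: the total relevance entering and leaving the
# layer should match; any small difference comes from the epsilon stabiliser
# (and from bias terms, if the layer had any).
print(R_in.sum().item(), R_out.sum().item())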

linhlpv commented 3 years ago

Thank you so much. I have read your code again and now it is clear to me.

linhlpv commented 3 years ago

I have one more question, about the implementation of the BatchNorm layer. Could you give me some ideas about how to implement it? Once again, thank you so much for your answers and your great work!

fhvilshoj commented 3 years ago

As indicated here (by BatchNorm's absence), there is no specific behaviour implemented for BatchNorm. Whenever that is the case, standard gradient computations are used for backpropagating relevance. This behaviour is not entirely correct, in the sense that the relevance conservation property is not preserved. However, the resulting heatmaps (after normalization) will still be the same.

If you want to implement it, here is some inspiration.
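One common way to handle it, which is not what this library does, is to fold an eval-mode BatchNorm into the preceding convolution so that the ordinary Conv2d rules apply to the merged layer. A rough sketch, where fold_batchnorm is just a made-up helper name:

import torch
import torch.nn as nn

def fold_batchnorm(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Merge an eval-mode BatchNorm2d into the preceding Conv2d, so that
    # y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta becomes a single conv.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, conv.dilation, conv.groups,
                      bias=True)
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight.data / std                        # per-channel gamma / std
    fused.weight.data = conv.weight.data * scale.view(-1, 1, 1, 1)
    conv_bias = conv.bias.data if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data = (conv_bias - bn.running_mean) * scale + bn.bias.data
    return fused

This is only valid at inference time (fixed running statistics), but then relevance can be propagated through the fused convolution with the usual conv rules.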

linhlpv commented 3 years ago

Yeah! Thank you! I will try to implement LRP for BatchNorm as you suggested.

linhlpv commented 3 years ago

And I have one more question, about the relevance calculation. I see that when you calculate the gradient of the input, you use a Linear layer with the weight transposed (in the Linear case) and a ConvTranspose layer (in the Conv case). I'm just a bit curious why you don't call backward on Z and use input.grad instead? I just want to understand your code a bit more xD. Thank you so much and have a nice day!

fhvilshoj commented 3 years ago

The short answer is that you are not allowed to mix autograd function definitions with the backward call.

When inheriting from torch.autograd.Function, you define new functions together with their "local derivatives", which autograd uses for automatic differentiation, essentially by repeatedly applying the chain rule. Specifically, you follow this pattern:

from torch.autograd import Function

class MyFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, bias, ...your params):
        # 1. Compute the output
        out = ...
        # 2. Store information needed for the backward pass
        ctx.save_for_backward(input, weight, bias)  # tensors
        ctx.rho = rho                               # constants, bools, etc.
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # 1. Compute dy/dx; use ctx.rho or other stored variables if necessary
        dydx = ...
        # Return one value per argument of forward (None for non-tensor params)
        return dydx, ...

So you define a static class which has a forward pass and an associated backward pass (gradient computation for ordinary gradient backpropagation, or in our case relevance propagation). The backward pass needs to be specified explicitly as a computation and cannot, in this case, be determined by autograd, because autograd relies on exactly this backward function to compute gradients.
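To make the pattern concrete, here is a minimal sketch of an epsilon-rule linear Function (a sketch only, not the code of this repository). The relevance of the input is computed explicitly by multiplying with the transposed weight, rather than by calling backward inside backward; for convolutions, the analogous explicit operation is a transposed convolution (conv_transpose2d).

import torch
import torch.nn.functional as F
from torch.autograd import Function

class EpsilonLinear(Function):
    # Sketch of an LRP epsilon rule: the forward pass is the ordinary linear
    # layer, while backward propagates relevance instead of gradients.

    @staticmethod
    def forward(ctx, input, weight, bias, epsilon):
        ctx.save_for_backward(input, weight, bias)
        ctx.epsilon = epsilon
        return F.linear(input, weight, bias)             # unmodified network output

    @staticmethod
    def backward(ctx, relevance_output):
        input, weight, bias = ctx.saved_tensors
        z = F.linear(input, weight, bias) + ctx.epsilon  # recomputed "forward"
        s = relevance_output / z
        c = s.matmul(weight)        # same as F.linear(s, weight.t()); no .backward() call
        relevance_input = input * c
        # One return value per argument of forward; None for weight, bias, epsilon.
        return relevance_input, None, None, None

# Usage: seed the propagation by calling backward with the output relevance.
x = torch.rand(2, 8, requires_grad=True)
w, b = torch.rand(4, 8), torch.rand(4)
out = EpsilonLinear.apply(x, w, b, 1e-6)
out.backward(out.detach())          # in practice, usually only the target class score
R = x.grad                          # relevance assigned to the input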