chr5tphr / zennit

Zennit is a high-level framework in Python, using PyTorch, for explaining/exploring neural networks with attribution methods like LRP.

Densenet canonizations #171

Open gumityolcu opened 1 year ago

gumityolcu commented 1 year ago

Hello,

Here is a summary of the contributions:

  1. The epsilon values of the batch_norm layers used to be left untouched, although they need to be set to zero for perfect canonization. I fixed this. Note that the epsilon parameter cannot be added to the batch_norm_params field of the canonizer, because it is a plain float literal, not a torch tensor: when restoring it, the code would try to access batch_norm.eps.data, which does not exist. Therefore, I use a new class variable "batch_norm_eps" to remember it during canonization (see the first sketch after this list).
  2. CompositeCanonizer now returns the list of handles reversed. If two canonizers attach to the same module, they need to be detached in the reverse order of application in order to restore the original values. I opted to reverse the list inside the class because detaching the given handles in the returned order seemed more user-friendly, and I could not think of a use case where this would cause problems (see the second sketch below).
  3. The MergeBatchNormtoRight canonizer is added. It merges a batch normalization layer into a linear layer that comes after it. If the linear layer is a convolutional layer with padding, this is not straightforward: a full feature map needs to be added to the output of the layer instead of a simple bias. This is done by adding forward hooks (see the third sketch below).
  4. ThreshReLUMergeBatchNorm is added. This canonizer detects BN->ReLU->Linear chains and replaces the activation with a function that depends on the batch norm parameters, so that the BatchNorm effectively moves behind the activation. The batch norm is then merged into the adjacent linear layer, as described in https://github.com/AlexBinder/LRP_Pytorch_Resnets_Densenet/blob/master/canonization_doc.pdf (see the fourth sketch below).

  5. Furthermore, BN->ReLU->AvgPool->Linear chains are found and canonized using the same method, because batch normalization commutes with average pooling (a quick numeric check is sketched below).
  6. The full proposed canonizers are added to torchvision.py. Another addition is DenseNetAdaptiveAvgPoolCanonizer, which is needed before applying the other canonizers to DenseNets. It makes the final ReLU and AvgPool layers of torchvision DenseNet objects explicit: by default, these are applied inside the model's forward method rather than as nn.Module objects (see the last sketch below).
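
Below are a few sketches illustrating the points above. They are simplified, hedged illustrations rather than the actual implementation, and all helper names in them are my own.

For the first item: even an identity-parameterized BatchNorm (weight 1, bias 0, mean 0, variance 1) still divides by sqrt(1 + eps), so exact canonization requires eps = 0, and the float literal has to be remembered outside of batch_norm_params:

```python
import torch
from torch import nn

# an identity-parameterized, eval-mode BatchNorm still computes
# x / sqrt(1 + eps), so it only becomes the identity once eps is zero
bn = nn.BatchNorm2d(1).eval()
print(bn(torch.ones(1, 1, 2, 2)).flatten()[0].item())  # ~0.999995, not 1.0

class EpsHandlingSketch:
    """Sketch of remembering and restoring the eps literal."""
    def apply(self, batch_norm):
        # eps is a plain Python float, not a tensor, so it cannot be stored
        # via batch_norm.eps.data like the batch_norm_params entries
        self.batch_norm_eps = batch_norm.eps
        batch_norm.eps = 0.

    def remove(self, batch_norm):
        # put the original literal back on de-canonization
        batch_norm.eps = self.batch_norm_eps
```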
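For the second item, a simplified composite: when two canonizers touch the same module, the second one snapshots a state already modified by the first, so only detaching in reverse application order restores the original parameters.

```python
class CompositeCanonizerSketch:
    """Simplified composite that applies canonizers in order and returns
    the resulting instances reversed."""
    def __init__(self, canonizers):
        self.canonizers = canonizers

    def apply(self, root_module):
        instances = []
        for canonizer in self.canonizers:
            instances += canonizer.apply(root_module)
        # the second canonizer on a module saved state already modified by
        # the first; reversing makes plain-order removal restore the original
        instances.reverse()
        return instances
```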
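For the third item: writing the eval-mode batch norm per channel as bn(x) = a * x + b, with a = weight / sqrt(running_var + eps) and b = bias - running_mean * a, gives conv(bn(x)) = conv_scaled(x) + conv(b * ones), where conv_scaled folds a into the weights. Without padding, conv(b * ones) is constant per output channel and fits into the bias; with zero padding it varies near the borders, hence the feature map. A sketch, assuming groups == 1 and a fixed input size (the actual canonizer computes the correction via forward hooks):

```python
import torch
from torch import nn

def merge_bn_to_right_sketch(bn, conv, input_hw):
    """Merge a preceding eval-mode BatchNorm into a padded conv (sketch)."""
    a = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    b = bn.bias - bn.running_mean * a
    with torch.no_grad():
        # correction term conv(b * ones), using the original weights and
        # excluding the conv's own bias: with zero padding this varies near
        # the borders, so it is a feature map rather than a per-channel bias
        const = b[None, :, None, None] * torch.ones(1, 1, *input_hw)
        bias_map = nn.functional.conv2d(
            const, conv.weight, bias=None, stride=conv.stride,
            padding=conv.padding, dilation=conv.dilation,
        )
        # fold the per-channel BN scale into the conv weights
        conv.weight.mul_(a[None, :, None, None])
    # a forward hook that returns a value replaces the module's output
    return conv.register_forward_hook(lambda mod, inp, out: out + bias_map)

# quick numeric check with hypothetical shapes
bn = nn.BatchNorm2d(3).eval()
bn.running_mean.normal_()
bn.running_var.uniform_(0.5, 2.0)
conv = nn.Conv2d(3, 4, 3, padding=1)
x = torch.randn(1, 3, 8, 8)
expected = conv(bn(x))
merge_bn_to_right_sketch(bn, conv, (8, 8))
print(torch.allclose(conv(x), expected, atol=1e-5))  # True
```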
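For the fourth item, the idea in the linked note can be summarized as follows: with bn(x) = a * x + b per channel, ReLU(bn(x)) = bn(h(x)) for h(x) = max(x, -b / a) whenever a > 0, so the BatchNorm moves behind the new thresholded activation and can then be merged into the following linear layer. A sketch of the replacement activation (assuming a > 0 for brevity; the general case splits on the sign of a):

```python
import torch
from torch import nn

class ThreshReLUSketch(nn.Module):
    """Replacement activation for BN -> ReLU -> Linear chains (sketch)."""
    def __init__(self, bn):
        super().__init__()
        a = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        b = bn.bias - bn.running_mean * a
        # per-channel threshold, broadcast over batch and spatial dims
        self.register_buffer('threshold', (-b / a)[None, :, None, None])

    def forward(self, x):
        return torch.maximum(x, self.threshold)

# sanity check: ReLU(bn(x)) == bn(thresh_relu(x)) for a > 0
bn = nn.BatchNorm2d(3).eval()
bn.running_mean.normal_()
x = torch.randn(2, 3, 4, 4)
print(torch.allclose(torch.relu(bn(x)), bn(ThreshReLUSketch(bn)(x)),
                     atol=1e-6))  # True
```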
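For item 5, the commutation is simply the fact that averaging preserves per-channel affine maps: pool(a * x + b) = a * pool(x) + b. A quick numeric check:

```python
import torch
from torch import nn

# an eval-mode BatchNorm is a per-channel affine map, and averaging a
# constant gives the same constant, so BN commutes with average pooling
bn = nn.BatchNorm2d(3).eval()
bn.running_mean.normal_()
bn.running_var.uniform_(0.5, 2.0)
pool = nn.AvgPool2d(2)
x = torch.randn(1, 3, 8, 8)
print(torch.allclose(pool(bn(x)), bn(pool(x)), atol=1e-6))  # True
```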
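For item 6, what the canonizer amounts to can be sketched like this (hedged; the actual canonizer also restores the model on removal): torchvision's DenseNet.forward applies the final ReLU and pooling as functional calls, which module-based canonizers and hooks cannot see, so they are appended to features as modules and the forward is replaced accordingly.

```python
import types
import torch
from torch import nn
from torchvision.models import densenet121

model = densenet121()

# expose the functional tail of torchvision's DenseNet.forward
# (F.relu and F.adaptive_avg_pool2d) as explicit nn.Module instances
model.features.add_module('final_relu', nn.ReLU(inplace=True))
model.features.add_module('final_pool', nn.AdaptiveAvgPool2d((1, 1)))

def forward(self, x):
    # the ReLU and pooling now run inside self.features, where module-based
    # canonizers and hooks can find them
    out = self.features(x)
    out = torch.flatten(out, 1)
    return self.classifier(out)

model.forward = types.MethodType(forward, model)

# the classifier still receives the same flattened vector as before
print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])
```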

Thank you very much and I am looking forward to any kind of feedback!