chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

How to add support to DenseNet from torchvision #165

Open mstaczek opened 2 years ago

mstaczek commented 2 years ago

I wanted to use LRP with DenseNet121 from torchvision. So far, I have only found canonizers for ResNet and VGG in zennit.torchvision, and I wonder whether I can use them (to get some results) or whether I need to write my own custom canonizer (because the network has layers that the existing canonizers do not cover).

Thanks for your help!

chr5tphr commented 2 years ago

Hey @mstaczek

the truth is, DenseNet is quite tricky with LRP: the DenseLayers within the DenseBlocks end with a linear (convolution) layer without an activation and start with a BatchNorm. Because of the dense skip connections (feature concatenations), multiple BatchNorm layers are connected to multiple linear layers, so the BatchNorms cannot simply be merged into a single adjacent linear layer as is normally done.
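As a rough illustration of what "merged as is normally done" means, here is a minimal sketch (not zennit's canonizer code, just illustrative) of folding a BatchNorm that directly follows a single convolution into that convolution's parameters. In DenseNet this one-to-one pairing does not exist, so no single fold is possible:

```python
# Minimal sketch (not zennit's implementation): a BatchNorm that directly
# follows a single Conv2d is a per-channel affine map, so its parameters can
# be folded into the conv weights and bias.
import torch

def fold_bn_into_conv(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d):
    """Return (weight, bias) such that conv'(x) == bn(conv(x))."""
    scale = bn.weight / (bn.running_var + bn.eps).sqrt()
    weight = conv.weight * scale[:, None, None, None]
    bias = torch.zeros(conv.out_channels) if conv.bias is None else conv.bias
    bias = (bias - bn.running_mean) * scale + bn.bias
    return weight, bias

# In DenseNet, each conv output is concatenated with earlier features and then
# normalized by the first BatchNorm of every later DenseLayer (and the
# transition/norm5), each with different statistics, so there is no single
# Conv->BatchNorm pair to fold.
```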

Image from the paper:

![image](https://user-images.githubusercontent.com/15217558/192507405-60f82150-e115-490f-bf84-8ae2032c1f30.png)

Text representation of the torchvision model:

Expand text ``` DenseNet( (features): Sequential( (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu0): ReLU(inplace=True) (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (denseblock1): _DenseBlock( (denselayer1): _DenseLayer( (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer2): _DenseLayer( (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer3): _DenseLayer( (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer4): _DenseLayer( (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer5): _DenseLayer( (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer6): _DenseLayer( (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) ) (transition1): _Transition( (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (pool): AvgPool2d(kernel_size=2, stride=2, padding=0) ) (denseblock2): _DenseBlock( (denselayer1): _DenseLayer( (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): 
ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer2): _DenseLayer( (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer3): _DenseLayer( (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer4): _DenseLayer( (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer5): _DenseLayer( (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer6): _DenseLayer( (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer7): _DenseLayer( (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer8): _DenseLayer( (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer9): _DenseLayer( (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer10): _DenseLayer( 
(norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer11): _DenseLayer( (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer12): _DenseLayer( (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) ) (transition2): _Transition( (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (pool): AvgPool2d(kernel_size=2, stride=2, padding=0) ) (denseblock3): _DenseBlock( (denselayer1): _DenseLayer( (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer2): _DenseLayer( (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer3): _DenseLayer( (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer4): _DenseLayer( (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer5): _DenseLayer( (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer6): _DenseLayer( (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer7): _DenseLayer( (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer8): _DenseLayer( (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer9): _DenseLayer( (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer10): _DenseLayer( (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer11): _DenseLayer( (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer12): _DenseLayer( (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer13): _DenseLayer( (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 
bias=False) ) (denselayer14): _DenseLayer( (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer15): _DenseLayer( (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer16): _DenseLayer( (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer17): _DenseLayer( (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer18): _DenseLayer( (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer19): _DenseLayer( (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer20): _DenseLayer( (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer21): _DenseLayer( (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer22): _DenseLayer( (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer23): _DenseLayer( (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer24): _DenseLayer( (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) ) (transition3): _Transition( (norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (pool): AvgPool2d(kernel_size=2, stride=2, padding=0) ) (denseblock4): _DenseBlock( (denselayer1): _DenseLayer( (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer2): _DenseLayer( (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer3): _DenseLayer( (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer4): _DenseLayer( (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer5): _DenseLayer( (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): 
Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer6): _DenseLayer( (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer7): _DenseLayer( (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer8): _DenseLayer( (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer9): _DenseLayer( (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer10): _DenseLayer( (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer11): _DenseLayer( (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer12): _DenseLayer( (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer13): _DenseLayer( (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer14): _DenseLayer( (norm1): 
BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer15): _DenseLayer( (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) (denselayer16): _DenseLayer( (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu1): ReLU(inplace=True) (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu2): ReLU(inplace=True) (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) ) ) (norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (classifier): Linear(in_features=1024, out_features=1000, bias=True) ) ```

Current State

Using DenseNet without canonizers does not work correctly:

Heatmaps without canonizers:

![image](https://user-images.githubusercontent.com/15217558/192508686-a07f4915-bc69-4b7d-8f48-f128e2ad85d9.png)

You can use the Epsilon rule for the BatchNorm layers to get slightly better results:

Heatmaps without canonizers, with Epsilon in BatchNorm:

![image](https://user-images.githubusercontent.com/15217558/192508834-14232890-a901-4bcd-a264-d3c9794d986d.png)
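For reference, assigning the Epsilon rule to the BatchNorm layers only takes the composite's `layer_map` argument; this is the relevant part of the full example further down:

```python
from zennit.composites import EpsilonGammaBox
from zennit.rules import Epsilon
from zennit.types import BatchNorm

# map all BatchNorm layers to the Epsilon rule instead of the composite's default
composite = EpsilonGammaBox(low=-3., high=3., layer_map=[(BatchNorm, Epsilon())])
```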

Here's some code to produce heatmaps with densenet121:

```python
import os

import torch
from torchvision.models import densenet121
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
from PIL import Image

from zennit.attribution import Gradient
from zennit.composites import EpsilonPlusFlat, EpsilonGammaBox
from zennit.types import BatchNorm
from zennit.image import imgify
from zennit.rules import Epsilon

fname = 'dornbusch-lighthouse.jpg'
if not os.path.exists(fname):
    torch.hub.download_url_to_file(
        'https://upload.wikimedia.org/wikipedia/commons/thumb/8/8b/2006_09_06_180_Leuchtturm.jpg/640px-2006_09_06_181_Leuchtturm.jpg',
        fname,
    )

# define the base image transform
transform_img = Compose([
    Resize(256),
    CenterCrop(224),
])
# define the normalization transform
transform_norm = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
# define the full tensor transform
transform = Compose([
    transform_img,
    ToTensor(),
    transform_norm,
])

# load the image
image = Image.open('dornbusch-lighthouse.jpg')
# transform the PIL image and insert a batch-dimension
data = transform(image)[None]

model = densenet121(weights='DEFAULT').eval()

# use the Epsilon rule for BatchNorm layers on top of the EpsilonGammaBox composite
composite = EpsilonGammaBox(low=-3., high=3., layer_map=[(BatchNorm, Epsilon())])

input = data.clone().requires_grad_(True)
target = torch.eye(1000)[[437]]

with Gradient(model, composite) as attributor:
    output, relevance = attributor(input, target)

transform_img(image).save('original.png')
imgify(relevance[0].detach().sum(0), cmap='bwr', symmetric=True).save('densenet121.png')
```

Implementing a DenseNet Canonizer

Ultimately, a canonizer also needs to be implemented for DenseNet, due to its problematic BatchNorms. We cannot merge the BatchNorms into the adjacent linear layers, since multiple BatchNorms use the same linear layer, and we cannot merge the adjacent linear layer into the BatchNorms, since a BatchNorm is not expressive enough to absorb it.
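To make the sharing concrete, here is a small sketch (with hypothetical tensors, following torchvision's `_DenseBlock` structure shown above) of how the input to a later DenseLayer's `norm1` is assembled; every earlier convolution's output reappears in it, so each convolution is followed by many BatchNorms and each BatchNorm normalizes channels coming from several convolutions:

```python
import torch

# hypothetical feature maps for denseblock1 (growth rate 32, block input 64 channels)
block_input = torch.randn(1, 64, 56, 56)   # output of pool0
new_feats_1 = torch.randn(1, 32, 56, 56)   # conv2 output of denselayer1
new_feats_2 = torch.randn(1, 32, 56, 56)   # conv2 output of denselayer2

# denselayer3.norm1 receives the concatenation of the block input and all
# previous layers' outputs (64 + 32 + 32 = 128 channels, matching the model
# dump above) -- merging this one BatchNorm would have to touch several convs
norm1_input = torch.cat([block_input, new_feats_1, new_feats_2], dim=1)
print(norm1_input.shape)  # torch.Size([1, 128, 56, 56])
```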

There are a few settings in which BatchNorms appear, each of which needs different handling.

There may be things I have overlooked, but by the very design of LRP, DenseNet is currently quite a challenge to get right. It needs careful thought to be handled in the way the definition of LRP implies.

I will try to discuss this in our lab and see whether there is a better solution, but until then you can try the Epsilon rule for the BatchNorm layers.

mstaczek commented 2 years ago

Wow, I did not expect it to be such a challenge!

Thank you for the explanation and sample heatmaps. They are really convincing that a custom canonizer is necessary for DenseNets. I will think about implementing one after reading more about DenseNet, its blocks, and LRP.