pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Undesirable behavior of LayerActivation in networks with inplace ReLUs #156

Closed: mrsalehi closed this issue 4 years ago

mrsalehi commented 4 years ago

Hi, I was trying to use captum.attr._core.layer_activation.LayerActivation to get the activation of the first convolutional layer in a simple model. Here is my code:

import numpy as np
import torch
import torch.nn as nn
from captum.attr import LayerActivation

torch.manual_seed(23)
np.random.seed(23)
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                      nn.ReLU(inplace=True),
                      nn.Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                      nn.ReLU(inplace=True))

layer_act = LayerActivation(model, model[0])
input = torch.randn(1, 3, 5, 5)
mylayer = model[0]
# Compare the layer's direct output with the activation reported by Captum.
print(torch.norm(mylayer(input) - layer_act.attribute(input), p=2))

This computes the activation in two different ways and compares the results. I expected a value close to zero to be printed, but this is what I got:

tensor(3.4646, grad_fn=<NormBackward0>)

My hypothesis is that the in-place ReLU following the convolutional layer overwrites the convolution's output, since the activation computed by Captum (i.e. layer_act.attribute(input)) contained many zeros. Indeed, when I changed the architecture of the network to the following:

model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                      nn.ReLU(),
                      nn.Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                      nn.ReLU(inplace=True))

then the outputs matched.
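The hypothesis above can be reproduced with a plain PyTorch forward hook, without Captum. The following is a minimal sketch, not Captum's implementation or the merged fix: the hook `save_output` and the `captured` dictionary are hypothetical names introduced for illustration. It assumes the mismatch comes from keeping a reference to the layer's output tensor, which the following in-place ReLU then overwrites, and shows that a copy made inside the hook is unaffected.

```python
import torch
import torch.nn as nn

torch.manual_seed(23)

conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
model = nn.Sequential(conv, nn.ReLU(inplace=True))

captured = {}  # hypothetical container, for illustration only

def save_output(module, inputs, output):
    # A bare reference aliases the tensor the in-place ReLU will overwrite.
    captured["reference"] = output
    # A clone is an independent copy taken before the ReLU runs.
    captured["copy"] = output.clone()

handle = conv.register_forward_hook(save_output)
x = torch.randn(1, 3, 5, 5)
with torch.no_grad():
    model(x)
handle.remove()  # remove the hook so the call below does not overwrite `captured`

with torch.no_grad():
    expected = conv(x)  # the convolution's output, computed directly

# Distance to the aliased reference: it now holds the ReLU'd values.
print(torch.norm(expected - captured["reference"], p=2))
# Distance to the cloned copy: it still holds the pre-ReLU values.
print(torch.norm(expected - captured["copy"], p=2))
```

Under these assumptions, the aliased reference ends up equal to the ReLU of the true activation, which would explain the zeros observed in the output of layer_act.attribute(input).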


vivekmig commented 4 years ago

Hi @mrsalehi, yes, this is a bug, thanks for pointing it out! We will push a fix for this soon.

vivekmig commented 4 years ago

Fix has been merged here: https://github.com/pytorch/captum/commit/5bf06ba94c3ea1c992b5cc2b29daa06fae527d34