chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

Potential issue with learning BatchNorm parameters #204

Open MikiFER opened 7 months ago

MikiFER commented 7 months ago

Hi, in my project I have encountered an issue, and I'm not sure whether it is caused by invalid usage of the library or by a bug in the library code. I cannot provide minimal code for reproduction because the bug occurs during training, so I will describe it as best as I can. The pseudo-code for my training looks something like this:

# model, composite, dataloaders, optimizer and the loss helpers are defined elsewhere
for epoch in range(number_epochs):
    # train loop
    model.train()
    for inputs, gt in train_dataloader:
        inputs.requires_grad_(True)  # attribution is taken w.r.t. the input
        with composite.context(model) as modified_model:
            model_out = modified_model(inputs)
            task_loss = get_task_loss(model_out, gt)
            # attribution maps obtained through the LRP-modified backward pass
            gt_maps = torch.autograd.grad(
                model_out, inputs, torch.ones_like(model_out),
                retain_graph=True, create_graph=True,
            )[0].sum(1)
            # salience loss is computed on the attribution maps
            salience_loss = get_salience_loss(gt_maps, gt)
            total_loss = task_loss + salience_loss

        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # validation loop
    model.eval()
    for inputs, gt in val_dataloader:
        with composite.context(model) as modified_model:
            model_out = modified_model(inputs)
            ...

For the model I have tested VGG16, ResNet34 and VGG16_bn with the appropriate canonizers, and for the composite I have used EpsilonPlusFlat. All models have their heads changed to 20 outputs and are randomly initialized. I have noticed that the models with BatchNorm show a significant difference between their outputs in train mode and in eval mode.
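For reference, the composite/canonizer setup looks roughly like the following (a minimal sketch; the exact model and canonizer vary, e.g. ResNet34 uses ResNetCanonizer, and the 20-output head is only illustrated here):

```python
import torch
from torchvision.models import vgg16_bn
from zennit.composites import EpsilonPlusFlat
from zennit.torchvision import VGGCanonizer

# minimal sketch of the setup; ResNet34 would use zennit.torchvision.ResNetCanonizer instead
model = vgg16_bn(weights=None)  # randomly initialized
model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, 20)

canonizer = VGGCanonizer()  # merges the BatchNorm layers into the preceding convolutions
composite = EpsilonPlusFlat(canonizers=[canonizer])
```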

I have logged the sum of the model output during training to show this for the different models.
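The logged quantity is simply the sum over the raw model output, roughly like this (a sketch; the actual logging/plotting code is omitted):

```python
# sketch of the logged quantity: the sum over the raw model output,
# recorded inside the train loop (model.train()) and the validation loop (model.eval())
out_sum = model_out.detach().sum().item()
print(f"epoch {epoch}: output sum = {out_sum:.3e}")  # print is a stand-in for the actual logger
```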

For VGG16 we can see that the output sums stay within the same order of magnitude, which is expected: [plot: VGG16 output sums]

For ResNet34 we see a drastic change in the output sums, a difference of around 4 orders of magnitude: [plot: ResNet34 output sums]

For VGG16_bn we again see a difference in the output sums, but this time "only" around 1 order of magnitude: [plot: VGG16_bn output sums]

I realize that this behaviour is very strange, but it all points to something being wrong with BatchNorm. The version of Zennit I'm using is 0.5.2.dev5. I would really appreciate your help with this one. Thanks in advance.
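If it helps to narrow this down: a minimal check (just a sketch, reusing model and composite from the setup above and assuming the BatchNorm-merging canonizer is the suspect) would be to compare the state of one BatchNorm module before and after the composite context:

```python
import copy
import torch

# diagnostic sketch: does the canonizer leave the BatchNorm state unchanged after the context?
model.eval()  # eval mode so the forward pass itself does not update the running statistics
bn = dict(model.named_modules())['features.1']  # first BatchNorm2d in torchvision's vgg16_bn
before = copy.deepcopy(bn.state_dict())

with composite.context(model) as modified_model:
    out = modified_model(torch.randn(4, 3, 224, 224))
    out.sum().backward()

after = bn.state_dict()
for key in before:
    diff = (before[key].float() - after[key].float()).abs().max().item()
    print(f'{key}: max abs diff after context = {diff:.3e}')
```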