f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/

Parameter `grad`s don't get initialized with `BatchL2Grad` and BatchNorm #239

Open thomasahle opened 2 years ago

thomasahle commented 2 years ago

BatchL2Grad, perhaps naturally, raises an error when it sees a BatchNorm, since batch normalization mixes gradients in a way that makes the individual contributions hard to discern. The error says I can ignore it if I know what I'm doing. I can't say I completely do, but if I ignore it, I do indeed get both grads and batch_l2s on the top levels of my model, which aren't using batch norm. I'm happy with that.

My problem is that the lower-level parameters - which do use batch norm - don't just have a None batch_l2, but also a None grad. So my model doesn't train at all. This seems wrong, since grad is indeed computable, as witnessed by plain PyTorch computing it just fine without BackPACK.

Is there a way I can get batch_l2s on as many of my parameters as possible, but grads on everything?

I can do this now by first calling backward() without BackPACK and then calling it again inside a with backpack(BatchL2Grad()): block, but that seems wasteful.
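
Concretely, the wasteful version amounts to something like the sketch below (assuming the extended model, data, labels and squared-error loss from the example I post further down; the plain_grads bookkeeping is just one way of combining the two passes):

# First pass: a plain PyTorch backward fills .grad on every parameter.
model.zero_grad(set_to_none=True)
loss = torch.sum((model(data) - labels) ** 2)
loss.backward()
plain_grads = {name: p.grad.clone() for name, p in model.named_parameters()}

# Second pass: the BackPACK backward fills batch_l2 (and .grad) where supported.
model.zero_grad(set_to_none=True)
loss = torch.sum((model(data) - labels) ** 2)
with backpack.backpack(BatchL2Grad()):
    try:
        loss.backward()
    except NotImplementedError:
        pass  # raised at the BatchNorm layer; .grad below it stays None

# Put the plain gradients back on the parameters the aborted backward skipped.
for name, p in model.named_parameters():
    if p.grad is None:
        p.grad = plain_grads[name]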

f-dangel commented 2 years ago

Hi @thomasahle,

I'm not sure I fully understand your question. Individual gradients don't exist in a network with a BN layer in train mode, because the individual losses depend on all samples in the mini-batch. The purpose of the warning you're seeing is exactly to point out this caveat. And are you aware that the batch_l2 you're getting for the other parameters would not correspond to individual gradient l2 norms, since individual gradients don't exist with BN?

Best, Felix

thomasahle commented 2 years ago

Hi Felix, I'm indeed interested in individual gradient l2 norms. If I can't get them on all my parameters, that's fine.

But it would be nice if the normal, non-batched grad were still computed for all my parameters, without having to run backprop again without BackPACK.

Right now the non-batched grad is only computed for the parameters for which batch_l2 is computed.

f-dangel commented 2 years ago

Hi Thomas,

> Right now the non-batched grad is only computed for the parameters for which batch_l2 is computed.

That seems odd to me, because BackPACK does not intervene in PyTorch's gradient computation. Are you sure that these parameters have requires_grad = True?
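
For example, something like this (assuming model is the module in question) would list any parameters autograd skips by design:

# List parameters that autograd would skip regardless of BackPACK.
for name, param in model.named_parameters():
    if not param.requires_grad:
        print(f"{name} has requires_grad=False")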

> I'm indeed interested in individual gradient l2 norms. If I can't get them on all my parameters, that's fine.

For the loss of a neural network with batch normalization, individual gradients, and hence their l2 norms, don't exist. BackPACK only detects this when it encounters a batch norm module during the backward pass, so the batch_l2 results for the layers it has already processed are not per-sample gradient l2 norms either. Maybe this post I wrote is helpful to understand this in more detail.
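
To make the caveat concrete, here is a small plain-PyTorch check (not BackPACK code): with BatchNorm in train mode, the gradient coming from sample 0's loss changes when the other samples in the batch change, so a per-sample gradient is simply not well-defined.

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)  # train mode by default
x = torch.randn(4, 3)

def grad_of_sample0_loss(batch):
    # Gradient of the loss of sample 0 only, w.r.t. the BatchNorm weight.
    bn.zero_grad(set_to_none=True)
    bn(batch)[0].sum().backward()
    return bn.weight.grad.clone()

g_original = grad_of_sample0_loss(x)
x_perturbed = x.clone()
x_perturbed[1:] += 1.0  # change only the *other* samples
g_perturbed = grad_of_sample0_loss(x_perturbed)

# False: sample 0's "own" gradient depends on the whole batch.
print(torch.allclose(g_original, g_perturbed))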

thomasahle commented 2 years ago

> That seems odd to me, because BackPACK does not intervene in PyTorch's gradient computation. Are you sure that these parameters have requires_grad = True?

Yes, it is only the parameters that BatchL2Grad does not support that don't get grad. The other parameters get both batch_l2 and grad.

I was thinking this might be a matter of how the exception is handled: when the NotImplementedError is thrown, the backward computation somehow gets aborted and grad isn't computed as it otherwise would have been.

thomasahle commented 2 years ago

Here is example code of what I mean:

import torch
import torch.nn as nn
import backpack
from backpack.extensions import BatchL2Grad

channels = 5
data = torch.randn(100, channels, 10, 10)
labels = torch.randn(100, 5)

model = nn.Sequential()
model.add_module('conv', nn.Conv2d(channels, 5, kernel_size=3, stride=1, padding=1, bias=False))
model.add_module('batch norm', nn.BatchNorm2d(5))
model.add_module('flat', nn.Flatten(1))
model.add_module('linear', nn.Linear(500, 5))

model = backpack.extend(model)

y = model(data)
loss = torch.sum((y - labels)**2)

# The backward raises NotImplementedError when it reaches the BatchNorm layer;
# ignoring it here is what the error message says one can do.
with backpack.backpack(BatchL2Grad()):
    try:
        loss.backward()
    except NotImplementedError:
        pass

for name, param in model.named_parameters():
    if not hasattr(param, 'batch_l2'):
        print(f'Param {name} has no batch_l2')
    if not hasattr(param, 'grad') or param.grad is None:
        print(f'Param {name} has no grad')

This outputs:

Param conv.weight has no batch_l2
Param conv.weight has no grad
Param batch norm.weight has no batch_l2
Param batch norm.bias has no batch_l2

In other words, the linear layer gets both batch_l2 and grad, which is great. The conv layer doesn't get batch_l2, since it is below the batch norm, which makes sense given what you wrote. However, I don't see why conv couldn't get grad, just like it would without BatchL2Grad.
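
For comparison, a plain backward appended to the snippet above (no backpack context, same extended model) fills every grad:

# Same extended model, but a plain PyTorch backward outside the backpack context.
model.zero_grad(set_to_none=True)
loss = torch.sum((model(data) - labels) ** 2)
loss.backward()
print(all(p.grad is not None for p in model.parameters()))  # prints True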

fKunstner commented 2 years ago

Hi Thomas,

I think the problem is that the error message is not strong enough. It should be

"Encountered BatchNorm module in training mode. Quantity to compute is undefined."

The batch_l2 fields of the parameters at the end of the network (after the batchnorm) get filled in because we only realise that the computation is meaningless when we hit the batchnorm layer, going backwards. Even for those parameters, it is not the l2 norm of the individual gradients. If your work involves gradients of individual samples, you should avoid batchnorm.

If you specifically want to look at the quantity that would be obtained by applying the same code used to get individual gradients to a batchnorm network, you can install from source (pip install -e backpack once you've extracted the source code) and remove the exception.

thomasahle commented 2 years ago

I guess you are right about batch_l2 not being defined. I could fall back and just compute the grad with normal PyTorch in this case, but it still seems like BatchL2Grad might as well do it: even though batch_l2 makes no sense here, BatchL2Grad normally computes grad as well, and grad is defined.