f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Second order computations for nn.Upsample #254

FrederikWarburg opened this issue 2 years ago (status: open)

FrederikWarburg commented 2 years ago

Hi

I need to compute the approximate Hessian of a decoder network consisting of conv2d and upsample layers. Currently, BackPACK does not support nn.Upsample. Since it is a non-parametric layer, it might not be too difficult to implement?
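A quick way to see why the non-parametric case should be manageable: with mode="nearest" and an integer scale factor, upsampling only copies entries. It is therefore a fixed linear map whose Jacobian is a constant 0/1 matrix. The sketch below is plain PyTorch, not BackPACK code, and the tensor sizes are arbitrary choices for illustration. It checks this structure with torch.autograd.functional.jacobian.

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import jacobian

scale = 2
x = torch.randn(1, 1, 3, 3)  # small input (N=1, C=1, H=W=3), chosen for illustration

# nn.Upsample has no trainable parameters.
assert len(list(torch.nn.Upsample(scale_factor=scale, mode="nearest").parameters())) == 0

def upsample(inp):
    return F.interpolate(inp, scale_factor=scale, mode="nearest")

# Jacobian flattened to (output numel, input numel) = (36, 9).
J = jacobian(upsample, x).reshape(-1, x.numel())

# Each output pixel is a copy of exactly one input pixel ...
assert torch.equal(J.sum(dim=1), torch.ones(J.shape[0]))
# ... and each input pixel is copied to scale**2 output pixels.
assert torch.equal(J.sum(dim=0), torch.full((x.numel(),), float(scale**2)))

# The map is linear, so the Jacobian does not depend on the input values.
J2 = jacobian(upsample, torch.randn_like(x)).reshape(-1, x.numel())
assert torch.equal(J, J2)
```

Because the Jacobian is input-independent and contains no parameters, only its transpose needs to be applied during backpropagation.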

Here I define my model and a data point.

import torch
from backpack import backpack, extend
from backpack.extensions import DiagGGNExact

model = torch.nn.Sequential(
    torch.nn.Conv2d(1,8, kernel_size=3, padding=1),
    torch.nn.MaxPool2d(2),
    torch.nn.ReLU(),
    torch.nn.Conv2d(8,8, kernel_size=3, padding=1),
    torch.nn.Upsample(scale_factor=2, mode="nearest"),
    torch.nn.ReLU(),
    torch.nn.Conv2d(8,1, kernel_size=3, padding=1),
    torch.nn.Flatten(),
)
lossfunc = torch.nn.MSELoss()

model = extend(model)
lossfunc = extend(lossfunc)

X = torch.zeros(1,1,8,8)
print(model(X).shape)

b = X.shape[0]
loss = lossfunc(model(X), X.view(b, -1))

with backpack(DiagGGNExact()):
    loss.backward()

for param in model.parameters():
    print(param.diag_ggn_exact)

This raises the following error:

NotImplementedError: Extension saving to diag_ggn_exact does not have an extension for Module <class 'torch.nn.modules.upsampling.Upsample'>

Could you help implement this feature?

f-dangel commented 2 years ago

Hi, thanks for your feature request.

We have an example of how to add new parameterized layers to first-order extensions, which is a good starting point. Since nn.Upsample has no parameters, you only have to implement how the information for DiagGGNExact is backpropagated through the layer.
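As I understand it, DiagGGNExact backpropagates a stack of V vectors (a factorization of the loss Hessian), each shaped like a layer's output, by multiplying them with that layer's transposed Jacobian. For nearest upsampling with an integer scale factor, that product sums each scale x scale output block back onto the input pixel it was copied from. The sketch below assumes a (V, N, C, H_out, W_out) layout. The function name upsample_nearest_jac_t_mat_prod is hypothetical and not BackPACK's actual derivatives interface, and wiring it into BackPACK's extension classes is not shown. The result is checked against per-vector autograd products.

```python
import torch
import torch.nn.functional as F

def upsample_nearest_jac_t_mat_prod(mat: torch.Tensor, scale: int) -> torch.Tensor:
    """Hypothetical helper: apply the transposed Jacobian of 2d nearest upsampling.

    mat: (V, N, C, H * scale, W * scale), i.e. V vectors shaped like the output.
    Returns (V, N, C, H, W), i.e. V vectors shaped like the input.
    """
    V, N, C, H_out, W_out = mat.shape
    H, W = H_out // scale, W_out // scale
    # Output row h_out = h * scale + s comes from input row h (same for columns),
    # so split each spatial axis into (input index, offset within block) ...
    blocks = mat.reshape(V, N, C, H, scale, W, scale)
    # ... and sum over the offsets within each scale x scale block.
    return blocks.sum(dim=(4, 6))

# Check against autograd: one vector-Jacobian product per stacked vector.
scale, V = 2, 3
x = torch.randn(2, 4, 5, 5, requires_grad=True)
out = F.interpolate(x, scale_factor=scale, mode="nearest")
mat = torch.randn(V, *out.shape)

result = upsample_nearest_jac_t_mat_prod(mat, scale)
reference = torch.stack([
    torch.autograd.grad(out, x, grad_outputs=m, retain_graph=True)[0]
    for m in mat
])
assert result.shape == (V, *x.shape)
assert torch.allclose(result, reference)
```

The reshape-and-sum avoids materializing the 0/1 Jacobian, which matters because its size grows with the product of input and output sizes.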

To do that, you would

It would be great if you gave it a shot and submitted a PR! I can provide more pointers to help.

Best, Felix