f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

AdaptiveAvgPool not supported for 2nd order derivatives? #313

Closed: hlzl closed this issue 10 months ago

hlzl commented 10 months ago

I'm trying to use the current version of backpack-for-pytorch==1.6.0 to compute Hessian and GGN diagonals, but this does not seem to be supported for torch.nn.AdaptiveAvgPool2d layers, as I get the following error:

Extension saving to diag_h_batch does not have an extension for Module <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>

I could not find a disclaimer about this in the documentation, which lists adaptive pooling layers as supported. Am I missing something here?

Example to reproduce:

import torch
from backpack import backpack, extend
from backpack.extensions import BatchDiagGGNExact, BatchDiagHessian

# extend model and loss function so BackPACK can hook into the backward pass
model = extend(
    torch.nn.Sequential(
        *[
            torch.nn.Conv2d(3, 16, kernel_size=(3, 3)),
            torch.nn.AdaptiveAvgPool2d(output_size=1),
            torch.nn.Flatten(),
            torch.nn.Linear(16, 2),
        ]
    ).cuda()
)
criterion = extend(torch.nn.CrossEntropyLoss())

batch = torch.randn((2, 3, 8, 8)).cuda()
target = torch.tensor([[1.0, 0.0], [0.0, 1.0]]).cuda()

model.zero_grad()
loss = criterion(model(batch), target)

# compute per-sample GGN and Hessian diagonals during the backward pass
with backpack(BatchDiagGGNExact(), BatchDiagHessian()):
    loss.backward()

# stack the per-sample diagonals of all parameters into (N, num_params) matrices
hessian_diag = torch.cat(
    [p.diag_h_batch.view(batch.shape[0], -1) for p in model.parameters()], dim=1
)
ggn_diag = torch.cat(
    [p.diag_ggn_exact_batch.view(batch.shape[0], -1) for p in model.parameters()], dim=1
)
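
Under BackPACK's shape conventions, each p.diag_h_batch and p.diag_ggn_exact_batch should have shape (N, *p.shape), so the two concatenated matrices should both end up with shape (N, num_params). A minimal sanity check along those lines (assuming the snippet above ran through) could look like:

# sanity check: per-sample diagonals, flattened per parameter and concatenated,
# should give one row per sample and one column per parameter
num_params = sum(p.numel() for p in model.parameters())
assert hessian_diag.shape == (batch.shape[0], num_params)
assert ggn_diag.shape == (batch.shape[0], num_params)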
fKunstner commented 10 months ago

Thanks for the sample code!

I had a look; I think there was a missing link between AdaptiveAvgPool*d and the BatchDiagHessian extension.

It should be fixed on this branch if you want to install from source and give it a try.

f-dangel commented 10 months ago

Just merged the fix into the development branch.
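
If you want to try it before the next release, installing the development branch directly from GitHub should work, for example:

pip install git+https://github.com/f-dangel/backpack.git@development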

hlzl commented 10 months ago

Seems to work for me as well. Thanks for the quick fix!

One more question, if I may: I get a similar error for diag_h_batch when adding BatchNorm2d to the previous test script as shown below and setting the model to eval mode (as recommended). Computing diag_ggn_exact_batch works without problems with the batch norm layer.

model = extend(
    torch.nn.Sequential(
        *[
            torch.nn.Conv2d(3, 16, kernel_size=(3, 3)),
            torch.nn.BatchNorm2d(16),
            torch.nn.AdaptiveAvgPool2d(output_size=1),
            torch.nn.Flatten(),
            torch.nn.Linear(16, 2),
        ]
    ).cuda()
)
model.eval()  # put BatchNorm in eval mode, as recommended for second-order extensions
...

NotImplementedError: Extension saving to diag_h_batch does not have an extension for Module <class 'torch.nn.modules.batchnorm.BatchNorm2d'>

Could this be due to a similar issue, or is it related to #259?

Thank you!
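
In the meantime, a rough way to check which of the two extensions go through for a given model is to attempt one backward pass per extension and catch the NotImplementedError that BackPACK raises for unsupported modules. A minimal sketch, reusing model, criterion, batch and target from the script above:

# probe each extension separately; a fresh forward pass is needed for every
# backward call, and NotImplementedError signals an unsupported module
from backpack import backpack
from backpack.extensions import BatchDiagGGNExact, BatchDiagHessian

for ext in [BatchDiagGGNExact(), BatchDiagHessian()]:
    model.zero_grad()
    loss = criterion(model(batch), target)
    try:
        with backpack(ext):
            loss.backward()
        print(f"{ext.__class__.__name__}: works")
    except NotImplementedError as err:
        print(f"{ext.__class__.__name__}: not supported ({err})")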

hlzl commented 10 months ago

Opened a new issue, #316, with my follow-up question. The initial problem regarding adaptive pooling in the GGN and Hessian diagonal computation seems to be solved.