f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Missing implementation of supported layers for DiagHessian and BatchDiagHessian #316

Status: open. Opened by hlzl 10 months ago.

Several layers that are listed as supported for second-order derivatives do not actually work when computing the Hessian diagonal with DiagHessian or BatchDiagHessian in backpack-for-pytorch<=1.6.0.

So far, I've run into this problem with the following layers:

The problem can be reproduced with a script such as the following:

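```python
import torch
from backpack import backpack, extend
from backpack.extensions import DiagHessian, BatchDiagHessian
from backpack.custom_module.branching import Parallel, SumModule

# Fall back to the CPU so the reproduction also runs without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = extend(
    torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, kernel_size=(3, 3)),
        # Residual-style branch: the Identity and BatchNorm2d outputs are summed.
        Parallel(
            torch.nn.Identity(), torch.nn.BatchNorm2d(16), merge_module=SumModule()
        ),
        torch.nn.AdaptiveAvgPool2d(output_size=1),
        torch.nn.Flatten(),
        torch.nn.Linear(16, 2),
    ).to(device)
)
criterion = extend(torch.nn.CrossEntropyLoss())

batch = torch.randn((2, 3, 8, 8), device=device)
# Class indices 0 and 1, i.e. the one-hot rows [[1, 0], [0, 1]].
target = torch.tensor([0, 1], device=device)

model.eval()  # BatchNorm2d uses its running statistics
model.zero_grad()
loss = criterion(model(batch), target)

with backpack(DiagHessian(), BatchDiagHessian()):
    loss.backward()

# diag_h has the same shape as the parameter: flatten and concatenate along dim 0.
hessian_diag = torch.cat([p.diag_h.reshape(-1) for p in model.parameters()], dim=0)
# diag_h_batch has an extra leading batch dimension, which is kept.
hessian_diag_batch = torch.cat(
    [p.diag_h_batch.reshape(batch.shape[0], -1) for p in model.parameters()], dim=1
)
```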

I suspect these layers require independent fixes, but I think it is useful to collect all layers with missing support here in one place.
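Once a layer is supported, its `diag_h` values can be cross-checked against a brute-force reference built from plain `torch.autograd`. The sketch below is not part of BackPACK; the helper name `hessian_diagonal` is hypothetical, and it assumes DiagHessian stores the exact second derivative of the loss with respect to each parameter entry. It needs one extra backward pass per parameter entry, so it only suits tiny models like the one in the script.

```python
import torch


def hessian_diagonal(loss, params):
    """Hypothetical helper: d^2 loss / d p_i^2 for every entry of every tensor in params.

    Brute force: one extra backward pass per parameter entry.
    """
    params = list(params)
    # First derivatives with a graph attached, so they can be differentiated again.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    diagonals = []
    for p, g in zip(params, grads):
        g_flat = g.reshape(-1)
        diag = torch.zeros(g_flat.numel(), dtype=p.dtype, device=p.device)
        # A gradient without a graph is constant, so its second derivatives are zero.
        if g_flat.requires_grad:
            for i in range(g_flat.numel()):
                # Row i of the Hessian block for p; keep only its diagonal entry.
                (row,) = torch.autograd.grad(
                    g_flat[i], p, retain_graph=True, allow_unused=True
                )
                if row is not None:  # None means g_flat[i] does not depend on p
                    diag[i] = row.reshape(-1)[i]
        diagonals.append(diag.reshape(p.shape))
    return diagonals


if __name__ == "__main__":
    w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    # The cubic terms give d^2/dw_i^2 = 6 * w_i; the cross term w0 * w1 only
    # contributes off-diagonal Hessian entries.
    loss = (w**3).sum() + w[0] * w[1]
    print(hessian_diagonal(loss, [w])[0].tolist())  # [6.0, 12.0, 18.0]
```

For the reproduction script, `loss.backward()` has already freed the graph, so a fresh forward pass (`loss = criterion(model(batch), target)`) would be needed before calling `hessian_diagonal(loss, model.parameters())`; each returned tensor could then be compared with the matching `p.diag_h` via `torch.allclose`.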