f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

fix: ScaleModule and SumModule for DiagHessian. #317

Open hlzl opened 10 months ago

hlzl commented 10 months ago

Partially fixes #316. ScaleModule is also used for torch.nn.Identity.

I am not sure whether hessian_is_zero() should always return True for these two modules. The same applies to accumulate_backpropagated_quantities(), which concatenates dicts instead of tensors, as is done for the DiagGGN.
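For intuition on the hessian_is_zero() question: a scaling module computes a purely linear map y = weight * x (with weight = 1 for torch.nn.Identity), so its elementwise second derivative vanishes everywhere, which is what hessian_is_zero() returning True would express. A minimal pure-Python sketch (function names `scale` and `second_derivative` are illustrative, not BackPACK API):

```python
def scale(x, weight=2.0):
    # ScaleModule-style forward: a purely linear map y = weight * x.
    # weight = 1.0 corresponds to torch.nn.Identity.
    return weight * x

def second_derivative(f, x, h=1e-3):
    # Central finite-difference estimate of f''(x).
    return (f(x + h) - 2.0 * f(x) + f(x - h)) / (h * h)

# The second derivative of a linear map is zero everywhere,
# so hessian_is_zero() returning True looks consistent.
print(abs(second_derivative(scale, 1.5)) < 1e-6)
```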

hlzl commented 10 months ago

Commit 74c41735a12b72861e17e5ed4c2b0a97d40283c0 makes it possible to compute the Hessian diagonal even if the network contains a batch norm layer, simply by not computing the Hessian elements for that layer.

I am not sure whether this is a reasonable approach, but it can serve as a quick fix.

The other diagonal elements can then be extracted as follows:

import torch

# Collect the per-sample Hessian diagonals from every parameter that
# received a diag_h_batch attribute (skipped layers never get one).
hessian_diag_wo_bn = torch.cat(
    [
        p.diag_h_batch.view(batch.shape[0], -1)
        for p in model.parameters()
        if "diag_h_batch" in p.__dict__.keys()
    ],
    dim=1,
)

hlzl commented 10 months ago

Commit 48f03e92e12fd97970a78335fb645ac5ee9a77f2 tries to actually compute the diagonal elements of the Hessian for the batch norm layer.

If one of you could take a quick look at the commits to see whether they make sense, I would really appreciate it. @f-dangel @fKunstner

Thank you!

f-dangel commented 10 months ago

Hi,

just wanted to let you know that I read your message above. Please don't expect any reaction before the ICLR deadline (Sep 28).