aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Laplace Subnetwork with timm library model #128

Open nilsleh opened 1 year ago

nilsleh commented 1 year ago

Hi,

I would like to apply the Laplace subnetwork approach to a timm library model (a standard ResNet-18). I think the problem I am encountering is not unique to timm models per se, but rather related to in-place operations. I have made a small reproducible example in this Google Colab. The error

```
Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
```

only occurs when using the subnetwork approach, not with the default Laplace parameters. I have also tried changing the timm ResNet activation functions to not be in-place, but maybe it is also related to the skip connections? Even though the error does not occur inside the Laplace library itself, I was wondering if you had any suggestions or pointers for making this approach work.
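For reference, a minimal sketch of the kind of setup in question, following the subnetwork-selection pattern from the Laplace README (the dummy data and subnetwork size are placeholders):

```python
import timm
import torch
from torch.utils.data import DataLoader, TensorDataset

from laplace import Laplace
from laplace.utils import LargestMagnitudeSubnetMask

# standard timm ResNet-18
model = timm.create_model("resnet18", num_classes=10)

# dummy data standing in for the real training set
X = torch.randn(8, 3, 224, 224)
y = torch.randint(0, 10, (8,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=4)

# select a subnetwork, here the 1000 largest-magnitude weights
subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=1000)
subnetwork_indices = subnetwork_mask.select()

# subnetwork Laplace requires a full Hessian structure
la = Laplace(
    model,
    "classification",
    subset_of_weights="subnetwork",
    hessian_structure="full",
    subnetwork_indices=subnetwork_indices,
)
la.fit(train_loader)  # raises the view/in-place error quoted above
```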

nilsleh commented 1 year ago

I suppose this is related: https://docs.backpack.pt/en/1.4.0/use_cases/example_resnet_all_in_one.html#any-resnet-with-backpack-s-converter
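For context, the converter described in that example is enabled through BackPACK's `extend`; a minimal sketch, with torchvision's resnet18 standing in for the timm model:

```python
from backpack import extend
from torchvision.models import resnet18

model = resnet18()
# use_converter=True rewrites the network (via torch.fx) into modules
# that BackPACK supports, which is what the linked ResNet example relies on
model = extend(model, use_converter=True)
```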

wiseodd commented 1 year ago

If it's a BackPACK issue, then switching the backend might help. Can you try the following?

```python
from laplace import Laplace
from laplace.curvature import AsdlGGN

la = Laplace(model, ..., backend=AsdlGGN)
```

runame commented 1 year ago

Was just about to suggest the same thing. However, if you want to use AsdlGGN together with regression, you will have to install the Laplace library from source and check out the integrate-latest-asdl branch. You will also have to install ASDL from source and use its master branch. Let us know if you run into any issues with this!
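For reference, one way to install both from source (the ASDL repository URL here is an assumption):

```bash
# Laplace from source, on the integrate-latest-asdl branch
pip install git+https://github.com/aleximmer/Laplace.git@integrate-latest-asdl

# ASDL from source, master branch (repository URL assumed)
pip install git+https://github.com/kazukiosawa/asdl.git@master
```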

nilsleh commented 1 year ago

Thanks for the recommendation. I believe I installed everything as you suggested @runame; however, I get:

```
laplace/curvature/asdl.py", line 132, in diag
    fisher_maker = get_fisher_maker(self.model, cfg, self.kfac_conv)
TypeError: get_fisher_maker() takes 2 positional arguments but 3 were given
```

runame commented 1 year ago

Ah right, can you try removing the `self.kfac_conv` argument from the `get_fisher_maker` call?
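That is, line 132 of laplace/curvature/asdl.py would then read:

```python
fisher_maker = get_fisher_maker(self.model, cfg)
```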

Edit: I also just fixed this on the integrate-latest-asdl branch, so you can just pull again.

nilsleh commented 10 months ago

Thank you for the help. For some models I get the following PyTorch warning, and I am not sure what it implies:

```
UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
```

From this trace:

```
lib/python3.9/site-packages/laplace/baselaplace.py:377: in fit
    loss_batch, H_batch = self._curv_closure(X, y, N)
lib/python3.9/site-packages/laplace/baselaplace.py:777: in _curv_closure
    return self.backend.kron(X, y, N=N)
lib/python3.9/site-packages/laplace/curvature/asdl.py:164: in kron
    f, _ = fisher_maker.forward_and_backward()
lib/python3.9/site-packages/asdl/fisher.py:116: in forward_and_backward
    self.call_model()
lib/python3.9/site-packages/asdl/grad_maker.py:258: in call_model
    self._model_output = self._model_fn(*self._model_args, **self._model_kwargs)
lib/python3.9/site-packages/torch/nn/modules/module.py:1571: in _call_impl
    self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
```