nilsleh opened this issue 1 year ago
If it's a BackPACK issue, then maybe switching backend will help. Can you try the following?
from laplace import Laplace
from laplace.curvature import AsdlGGN
la = Laplace(model, ..., backend=AsdlGGN)
Was just about to suggest the same thing. However, if you want to use AsdlGGN together with regression, you will have to install the Laplace library from source and check out the integrate-latest-asdl branch. Also, you will have to install ASDL from source and just use the master branch. Let us know if you run into any issues with this!
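For reference, a from-source install of both might look like the following. The branch names are the ones mentioned above; the repository URLs are my assumption of the canonical locations:

```shell
# Assumed repository URLs; the branch names come from this thread
pip install git+https://github.com/aleximmer/Laplace.git@integrate-latest-asdl
pip install git+https://github.com/kazukiosawa/asdl.git@master
```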
Thanks for the recommendation. I believe I installed as you suggested @runame; however, I get:

laplace/curvature/asdl.py", line 132, in diag
    fisher_maker = get_fisher_maker(self.model, cfg, self.kfac_conv)
TypeError: get_fisher_maker() takes 2 positional arguments but 3 were given
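This error is consistent with a fisher-maker factory that now takes only the model and a config object, so a third positional argument fails with exactly this message. A stdlib-only sketch (the function here is a hypothetical stand-in that only mirrors the arity, not ASDL's real implementation):

```python
# Hypothetical stand-in mirroring a two-positional-argument signature
def get_fisher_maker(model, config):
    return (model, config)

try:
    # Three positional arguments, as in the old call site
    get_fisher_maker("model", "cfg", "kfac_conv")
except TypeError as e:
    print(e)  # takes 2 positional arguments but 3 were given
```

Dropping the extra argument at the call site, as suggested below, is the matching fix.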
Ah right, can you try to remove the argument self.kfac_conv to get_fisher_maker?

Edit: I also just fixed this on the integrate-latest-asdl branch, so you can just pull again.
Thank you for the help. For some models I get the following PyTorch warning, and I am not sure what it implies:
UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
From this trace:
lib/python3.9/site-packages/laplace/baselaplace.py:377: in fit
loss_batch, H_batch = self._curv_closure(X, y, N)
lib/python3.9/site-packages/laplace/baselaplace.py:777: in _curv_closure
return self.backend.kron(X, y, N=N)
lib/python3.9/site-packages/laplace/curvature/asdl.py:164: in kron
f, _ = fisher_maker.forward_and_backward()
lib/python3.9/site-packages/asdl/fisher.py:116: in forward_and_backward
self.call_model()
lib/python3.9/site-packages/asdl/grad_maker.py:258: in call_model
self._model_output = self._model_fn(*self._model_args, **self._model_kwargs)
lib/python3.9/site-packages/torch/nn/modules/module.py:1571: in _call_impl
self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
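The warning comes from the deprecated register_backward_hook API: when a module's forward creates several autograd nodes, such hooks can miss some grad_input entries. It is a deprecation warning rather than an error. The replacement API it points to looks like this (a standalone sketch, not the backend's actual internals):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

captured = []

def hook(module, grad_input, grad_output):
    # register_full_backward_hook delivers complete grad_input/grad_output
    captured.append(grad_output[0].shape)

handle = model[0].register_full_backward_hook(hook)
model(torch.randn(2, 4)).sum().backward()
handle.remove()
print(captured)  # [torch.Size([2, 8])]
```

Since the hook registration happens inside the backend library, the warning is not something user code can silence directly; it only signals that some gradient entries seen by the old-style hooks may be incomplete.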
Hi,
I would like to apply the Laplace subnetwork approach to a timm library model (a standard resnet18). I think the problem I am encountering is not unique to timm models per se, but is related to inplace operations. I have made a small reproducible example in this google colab. The error
Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
only occurs when using the subnetwork approach, not with the default Laplace parameters. I have also tried changing the timm resnet activation functions to not be inplace, but maybe it is also related to the skip connections? Even though the error does not originate inside the Laplace library itself, I was wondering if you had any suggestions or pointers to make this approach work.
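One workaround worth trying for this class of error is ensuring no module mutates its input in place before the backend attaches its hooks. A small helper along these lines (my own sketch, not part of the Laplace library) swaps in-place ReLUs for out-of-place ones; note that it cannot reach the `out += identity` additions written directly into a ResNet's forward, which may be the remaining culprit with the skip connections:

```python
import torch.nn as nn

def relu_out_of_place(module: nn.Module) -> None:
    """Recursively replace ReLU(inplace=True) with ReLU(inplace=False)."""
    for name, child in module.named_children():
        if isinstance(child, nn.ReLU) and child.inplace:
            setattr(module, name, nn.ReLU(inplace=False))
        else:
            relu_out_of_place(child)

# Small stand-in model; with timm you would pass the created resnet18 instead
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(inplace=True))
relu_out_of_place(model)
print(model[1].inplace)  # False
```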