Closed · wiseodd closed this 3 months ago
By the way, I didn't really implement or test `SubnetLaplace`. Not sure if it's really popular nowadays; the same effect can be achieved by setting `requires_grad = False` on select parameters anyway (from #144). Nevertheless, @edaxberger, feel free to implement serialization for `SubnetLaplace`. It should be very straightforward for you.
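A minimal sketch of the `requires_grad = False` trick mentioned above: instead of a dedicated `SubnetLaplace`, freeze the parameters you want excluded, so that curvature is only computed over the remaining "subnetwork". The model architecture here is just an illustration.

```python
import torch.nn as nn

# Emulate a subnetwork by freezing everything except the last layer;
# backends that respect requires_grad then only see the unfrozen parameters.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

for p in model.parameters():
    p.requires_grad = False
for p in model[-1].parameters():
    p.requires_grad = True

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # only the last layer's weight and bias remain trainable
```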
Though I must say that a quick test on `SubnetLaplace` works with no problem. But I don't know about edge cases like different `subnetwork_indices` with the same `len(subnetwork_indices)`, etc.
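A hypothetical sketch of the edge case above: a checkpoint could carry a different `subnetwork_indices` tensor of the same length, so a length check alone would silently accept an incompatible state dict. Comparing the index tensors themselves catches it.

```python
import torch

# Same length, different subnetwork: a len() check passes, an
# element-wise comparison does not.
saved_indices = torch.tensor([0, 1, 2])
loaded_indices = torch.tensor([0, 1, 3])

same_len = len(saved_indices) == len(loaded_indices)       # True: would pass
same_indices = torch.equal(saved_indices, loaded_indices)  # False: should fail
print(same_len, same_indices)
```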
@runame updated. Please check and merge if everything's good.
I decided not to include `backend` and `backend_kwargs`, since Laplace doesn't care about them once `H` is obtained. I.e., `H` is just a tensor or a `Kron` object, and you can use whatever backend you want for `glm_predictive` and to do another `fit`.
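A minimal sketch of why the backend need not be serialized: once `fit()` has produced the Hessian approximation, it is just a tensor (or a `Kron` object), so a checkpoint only needs `H` plus a few scalars, not the backend object that computed it. The field names here are illustrative stand-ins, not the library's actual keys.

```python
import torch

# Stand-in Laplace state: only the fitted quantities, no backend object.
state = {
    "H": torch.eye(3),                    # placeholder for a fitted Hessian
    "prior_precision": torch.tensor(1.0),
    "sigma_noise": torch.tensor(0.5),
}
torch.save(state, "laplace_state.pt")

# Any backend can be used afterwards for predictions or a further fit();
# only the stored tensors matter.
loaded = torch.load("laplace_state.pt")
print(torch.equal(loaded["H"], state["H"]))
```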
I also ignore `model.state_dict`, since the user would already have it anyway. Even marglik training also outputs the resulting torch model. Including it in the serialized Laplace would be redundant.
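A sketch of the separation above: the torch model and the Laplace state live in two independent checkpoints, mirroring plain PyTorch practice, so the weights are never duplicated inside the Laplace state dict. File names are illustrative.

```python
import torch
import torch.nn as nn

# The model checkpoint is the user's responsibility; the Laplace checkpoint
# only holds the fitted quantities.
model = nn.Linear(4, 2)
torch.save(model.state_dict(), "model.pt")
torch.save({"H": torch.eye(10)}, "laplace.pt")

# Restoring: load the model first, then the Laplace quantities.
model2 = nn.Linear(4, 2)
model2.load_state_dict(torch.load("model.pt"))
la_state = torch.load("laplace.pt")
print(torch.equal(model.weight, model2.weight))
```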
I think the current checks are sufficient to handle continual learning:

- This checks whether the `hessian_structure` and `subset_weights` are compatible with the loaded instance: https://github.com/aleximmer/Laplace/blob/30e28f2d94590f7b8d265f68cf98489c3a7af5dd/laplace/baselaplace.py#L792-L796
- This checks that the network (the torch model) is correct: https://github.com/aleximmer/Laplace/blob/30e28f2d94590f7b8d265f68cf98489c3a7af5dd/laplace/baselaplace.py#L797-L802
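A hypothetical sketch of the two compatibility checks linked above; the function, keys, and error messages are illustrative, not the library's actual code.

```python
# 1) The Laplace flavor must match; 2) the network must be the same,
# checked here via its parameter count.
def check_compatible(saved: dict, current: dict) -> None:
    for key in ("hessian_structure", "subset_weights"):
        if saved[key] != current[key]:
            raise ValueError(f"Incompatible {key!r}: {saved[key]} vs {current[key]}")
    if saved["n_params"] != current["n_params"]:
        raise ValueError("Loaded state dict does not match the current model.")

check_compatible(
    {"hessian_structure": "kron", "subset_weights": "all", "n_params": 100},
    {"hessian_structure": "kron", "subset_weights": "all", "n_params": 100},
)
print("compatible")
```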
As for `model.state_dict`, we could add an option in `Laplace.state_dict()`, say `save_model: bool = False`. But this should be part of your future PR.
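A hypothetical sketch of the proposed `save_model` flag (not implemented in the library; the free function here stands in for the method):

```python
import torch
import torch.nn as nn

# When save_model=True, the model's state_dict is bundled into the
# Laplace state; otherwise only the fitted quantities are stored.
def laplace_state_dict(model: nn.Module, H: torch.Tensor,
                       save_model: bool = False) -> dict:
    state = {"H": H}
    if save_model:
        state["model"] = model.state_dict()
    return state

model = nn.Linear(3, 1)
slim = laplace_state_dict(model, torch.eye(4))
full = laplace_state_dict(model, torch.eye(4), save_model=True)
print("model" in slim, "model" in full)
```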
Added a test to check `fit(override=False)`!
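A toy sketch of what `fit(override=False)` is meant to guarantee (the real test lives in the library's suite; this `fit` function is a stand-in, not the library's): a second fit accumulates curvature instead of replacing it, which is what continual learning relies on.

```python
import torch

# With override=True the old Hessian is discarded; with override=False
# the new batch's curvature is added to it.
def fit(H, batch_H, override=True):
    return batch_H if override else H + batch_H

H = torch.eye(2)
H_override = fit(H, torch.eye(2), override=True)
H_accum = fit(H, torch.eye(2), override=False)
print(torch.equal(H_accum, 2 * torch.eye(2)),
      torch.equal(H_override, torch.eye(2)))
```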
Addressing #45. Very useful for large models like LLMs, where even doing forward passes over the training data (for `fit()`) is expensive. The API basically follows PyTorch.
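A usage sketch of the PyTorch-style serialization described in this PR. `FakeLaplace` is a stand-in for the real `Laplace` class; only the `state_dict()`/`load_state_dict()` shape of the API is the point, and the real signatures may differ.

```python
import torch

# Mimics the round trip: fit once, save, then restore without refitting.
class FakeLaplace:
    def __init__(self):
        self.H = None
    def fit(self, data):                  # the expensive step we do once
        self.H = torch.eye(data.shape[1])
    def state_dict(self):
        return {"H": self.H}
    def load_state_dict(self, state):
        self.H = state["H"]

la = FakeLaplace()
la.fit(torch.randn(16, 3))
torch.save(la.state_dict(), "la_state.pt")   # same calls as for an nn.Module

la2 = FakeLaplace()                          # later: skip the expensive fit()
la2.load_state_dict(torch.load("la_state.pt"))
print(torch.equal(la.H, la2.H))
```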