aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Add native serialization support #148

Closed wiseodd closed 3 months ago

wiseodd commented 4 months ago

Addressing #45. This is very useful for large models like LLMs, where even doing forward passes over the training data (for fit()) is expensive.

The API basically follows PyTorch's state_dict convention.

import torch
from laplace import Laplace

# model: a trained torch.nn.Module; train_loader: a DataLoader over the training data
la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
la.fit(train_loader)
la.optimize_prior_precision()  # Or via marglik, which also optimizes sigma_noise

# Serialization for fitted quantities
state_dict = la.state_dict()
torch.save(state_dict, 'state_dict.bin')

la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
# Load serialized, fitted quantities
la2.load_state_dict(torch.load('state_dict.bin'))
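After loading, la2 is ready to predict without re-running fit(). A minimal sketch, assuming the regression setting above (x is a placeholder batch; calling the Laplace object uses the default GLM predictive and returns the predictive mean and variance):

x = torch.randn(5, n_features)  # placeholder; n_features must match the model input
f_mu, f_var = la2(x)            # predictive mean and variance, no refit needed
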

wiseodd commented 4 months ago

By the way, I didn't really implement or test SubnetLaplace. I'm not sure if it's really popular nowadays; the same effect can be achieved by setting requires_grad = False on select parameters anyway (from #144).

Nevertheless, @edaxberger, feel free to implement serialization for SubnetLaplace. It should be very straightforward for you.

Though I must say that a quick test of SubnetLaplace works without problems. I just don't know about edge cases, like different subnetwork_indices with the same len(subnetwork_indices), etc.
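For reference, a minimal sketch of the requires_grad alternative, assuming the behavior from #144 (parameters with requires_grad = False are excluded from the approximation); the model here is just a placeholder:

import torch
from laplace import Laplace

# Placeholder model: freeze everything except the last layer so that only
# its parameters enter the Laplace approximation.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 50), torch.nn.ReLU(), torch.nn.Linear(50, 1)
)
for p in model[:-1].parameters():
    p.requires_grad = False

la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')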

wiseodd commented 3 months ago

@runame updated. Please check and merge if everything's good.

I decided not to include backend and backend_kwargs, since Laplace doesn't care about them once H is obtained. That is, H is just a tensor or a Kron object, and you can use whatever backend you want for glm_predictive and for another fit().
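Concretely, a sketch of the backend swap this allows (AsdlGGN is one of the library's curvature backends; treat the exact choice as illustrative):

from laplace.curvature import AsdlGGN

# The backend at load time need not match the one used for fitting,
# since H is already materialized in the state dict.
la2 = Laplace(model, 'regression', subset_of_weights='all',
              hessian_structure='full', backend=AsdlGGN)
la2.load_state_dict(torch.load('state_dict.bin'))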

wiseodd commented 3 months ago

I also ignore model.state_dict, since the user would already have it anyway. Even marglik training outputs the resulting torch model. Including it in the serialized Laplace would be redundant.
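So the intended pattern is to save the two state dicts side by side; a sketch (file names are arbitrary):

# Save the model weights and the Laplace quantities separately.
torch.save(model.state_dict(), 'model.bin')
torch.save(la.state_dict(), 'la_state_dict.bin')

# Later: restore the model first, then load the Laplace quantities on top.
model.load_state_dict(torch.load('model.bin'))
la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
la2.load_state_dict(torch.load('la_state_dict.bin'))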

wiseodd commented 3 months ago

I think the current checks are sufficient to handle continual learning:

This will check whether the hessian_structure and subset_of_weights are compatible with the loaded instance: https://github.com/aleximmer/Laplace/blob/30e28f2d94590f7b8d265f68cf98489c3a7af5dd/laplace/baselaplace.py#L792-L796

This will check that the network (the torch model) is correct: https://github.com/aleximmer/Laplace/blob/30e28f2d94590f7b8d265f68cf98489c3a7af5dd/laplace/baselaplace.py#L797-L802

As for model.state_dict, we could add an option in Laplace.state_dict(), say save_model: bool = False. But this should be part of your future PR.
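For instance, loading into an instance built with a different hessian_structure should fail against the first check; a sketch (the exact error raised is an implementation detail, not specified here):

la_kron = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='kron')
# Expected to raise: the saved state dict was fitted with hessian_structure='full'.
la_kron.load_state_dict(torch.load('state_dict.bin'))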

wiseodd commented 3 months ago

Added a test to check fit(override=False)!
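For completeness, a sketch of the continual-learning flow this supports, assuming fit's override flag accumulates the new Hessian into the loaded one (the task loaders are placeholders):

la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
la.fit(train_loader_task1)
torch.save(la.state_dict(), 'state_dict.bin')

# Later, possibly in a new process: load and keep fitting on new data.
la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
la2.load_state_dict(torch.load('state_dict.bin'))
la2.fit(train_loader_task2, override=False)  # adds to the loaded H instead of resetting it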