aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Add Curvlinops backend & add default `functorch` implementations of many curvature quantities #146

Closed. wiseodd closed this 4 months ago.

wiseodd commented 4 months ago

Changelog:

  1. Add Curvlinops backend.
  2. Add default `functorch` implementations of many quantities (in `CurvatureInterface` itself):
    1. Jacobians
    2. Batch gradients
    3. Diag GGN (both stochastic and exact) and diag EF
    4. Full GGN (both stochastic and exact) and full EF
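A minimal sketch of how such default implementations can be written with `torch.func` (the successor of `functorch` in PyTorch 2.x), on a hypothetical toy classifier. This is not the `CurvatureInterface` code from this PR. The model, data, and helpers (`f_single`, `loss_single`, `flatten`) are illustrative assumptions. It shows Jacobians, batch (per-sample) gradients, and the full and diagonal empirical Fisher (EF) built from them.

```python
import torch
from torch.func import functional_call, vmap, grad, jacrev

# Hypothetical toy classifier and data (assumptions, for illustration only).
torch.manual_seed(0)
C = 2  # number of classes
model = torch.nn.Sequential(
    torch.nn.Linear(3, 4), torch.nn.Tanh(), torch.nn.Linear(4, C)
).double()
X = torch.randn(5, 3, dtype=torch.float64)
y = torch.randint(0, C, (5,))
n = X.shape[0]

# Detached parameter dict so torch.func can differentiate w.r.t. it.
params = {k: v.detach() for k, v in model.named_parameters()}

def f_single(p, x):
    # Network output for one input: add, then remove, a batch dimension.
    return functional_call(model, p, (x.unsqueeze(0),)).squeeze(0)

def loss_single(p, x, t):
    # Cross-entropy of a single sample (a batch of size one).
    logits = f_single(p, x)
    return torch.nn.functional.cross_entropy(logits.unsqueeze(0), t.unsqueeze(0))

def flatten(tree, batch):
    # Concatenate per-parameter tensors into one (batch, num_params) matrix.
    return torch.cat([v.reshape(batch, -1) for v in tree.values()], dim=1)

# Jacobians of the outputs w.r.t. the parameters: (n, C, num_params).
J_tree = vmap(jacrev(f_single), in_dims=(None, 0))(params, X)
J = torch.cat([v.reshape(n, C, -1) for v in J_tree.values()], dim=2)

# Batch (per-sample) gradients of the loss: (n, num_params).
G = flatten(vmap(grad(loss_single), in_dims=(None, 0, 0))(params, X, y), n)

# Empirical Fisher: full matrix and its diagonal.
EF_full = G.T @ G
EF_diag = (G ** 2).sum(dim=0)
```

Since the cross-entropy gradient w.r.t. the logits is `softmax(f) - onehot(y)`, each row of `G` equals `J[i].T @ (softmax(f_i) - onehot(y_i))`. The two quantities are therefore consistent by the chain rule.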

Rationale for (1): BackPACK and ASDL are no longer actively maintained and are hard to extend. Note that we use Curvlinops mainly for K-FAC.

Rationale for (2): functorch is a first-class PyTorch module, so it has the broadest compatibility with PyTorch models. This future-proofs laplace-torch and reduces its dependence on third-party backends.

Together, these changes make laplace-torch flexible and compatible with most current (and future) architectures.

With this PR, `Laplace(backend=CurvlinopsBackend)` supports:

  1. Diag-GGN-exact, diag-GGN-MC, diag-EF for both classification and regression
  2. Kron-GGN-exact, kron-GGN-MC, kron-EF for both classification and regression
  3. Full-GGN-exact, full-GGN-MC, full-EF, full-Hessian-exact for both classification and regression
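To make the curvature types listed above concrete, here is a small sketch. It computes the exact GGN, Monte Carlo (MC) GGN, empirical Fisher, and exact Hessian for classification directly with `torch.func`, on a hypothetical softmax-regression model with a flat parameter vector. It does not use Curvlinops or the laplace-torch API. The model, sizes, and sample count `S` are assumptions chosen for illustration.

```python
import torch
from torch.func import jacrev, hessian, vmap

# Hypothetical softmax regression: logits = x @ W, W = theta.view(D, C), no bias.
torch.manual_seed(0)
N, D, C = 8, 3, 3
X = torch.randn(N, D, dtype=torch.float64)
y = torch.randint(0, C, (N,))
theta = 0.1 * torch.randn(D * C, dtype=torch.float64)

def logits_fn(theta, x):
    # Model output f(x; theta) for a single input x of shape (D,).
    return x @ theta.view(D, C)

def total_loss(theta):
    # Cross-entropy summed over the dataset.
    logits = X @ theta.view(D, C)
    return torch.nn.functional.cross_entropy(logits, y, reduction="sum")

# Per-sample Jacobians J_n = df(x_n)/dtheta: (N, C, D*C).
J = vmap(jacrev(logits_fn), in_dims=(None, 0))(theta, X)
probs = torch.softmax(X @ theta.view(D, C), dim=-1)  # (N, C)

# Exact GGN: sum_n J_n^T H_n J_n, with H_n = diag(p_n) - p_n p_n^T being
# the Hessian of the cross-entropy w.r.t. the logits.
H_out = torch.diag_embed(probs) - probs.unsqueeze(2) * probs.unsqueeze(1)
GGN_exact = torch.einsum("nca,ncd,ndb->ab", J, H_out, J)

# Empirical Fisher: outer products of per-sample gradients
# g_n = J_n^T (p_n - e_{y_n}), using the observed labels.
R = probs - torch.nn.functional.one_hot(y, C).to(probs.dtype)
G = torch.einsum("nca,nc->na", J, R)
EF = G.T @ G

# MC GGN: same outer products, but with S labels per input sampled
# from the model's own predictive distribution.
S = 20000
y_mc = torch.multinomial(probs, S, replacement=True)  # (N, S)
R_mc = probs.unsqueeze(1) - torch.nn.functional.one_hot(y_mc, C).to(probs.dtype)
G_mc = torch.einsum("nca,nsc->nsa", J, R_mc)
GGN_mc = torch.einsum("nsa,nsb->ab", G_mc, G_mc) / S

# Exact Hessian of the summed loss w.r.t. theta.
Hess = hessian(total_loss)(theta)
```

Because this toy model is linear in its parameters, the exact GGN coincides with the Hessian, which makes a convenient sanity check. For nonlinear networks the two differ. The EF generally differs from both because it uses the observed labels rather than labels sampled from the model.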

Note that the previous backends support only subsets of these.

I'd like @runame to review this, since he's involved in Curvlinops, in particular the K-FAC backend proposed here.

wiseodd commented 4 months ago

@aleximmer please give a final check and feel free to merge!