- Add default functorch implementations of many quantities, directly in `CurvatureInterface`:
  - Jacobians
  - Batch gradients
  - Diag GGN (both stochastic and exact) and diag EF
  - Full GGN (both stochastic and exact) and full EF
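As a rough illustration of what such functorch-based defaults can compute, here is a minimal sketch using `torch.func` (the module functorch became part of in PyTorch 2.0) on a toy classifier with cross-entropy loss. The model, the data, and the helpers `flat`, `f`, and `loss` are hypothetical and are not the code added in this PR. The sketch only shows how Jacobians, batch gradients, the EF, and the exact GGN (full and diagonal) follow from `jacrev`, `grad`, and `vmap`. The stochastic GGN is left out.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.func import functional_call, grad, jacrev, vmap

torch.manual_seed(0)

# Toy classifier and data (illustrative only; not laplace-torch code).
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 2))
X, y = torch.randn(5, 3), torch.randint(0, 2, (5,))
params = {k: v.detach() for k, v in model.named_parameters()}
names = list(params)


def flat(tree, lead):
    # Concatenate per-parameter tensors, keeping the first `lead` dims intact.
    return torch.cat([tree[k].flatten(start_dim=lead) for k in names], dim=-1)


def f(p, x):
    # Logits for a single input x of shape (D,) -> (C,)
    return functional_call(model, p, (x.unsqueeze(0),)).squeeze(0)


def loss(p, x, t):
    # Per-example cross-entropy loss
    return F.cross_entropy(f(p, x).unsqueeze(0), t.unsqueeze(0))


# Jacobians of the outputs w.r.t. all parameters: (N, C, P)
Js = flat(vmap(jacrev(f), in_dims=(None, 0))(params, X), lead=2)

# Batch (per-example) gradients of the loss: (N, P)
Gs = flat(vmap(grad(loss), in_dims=(None, 0, 0))(params, X, y), lead=1)

# Empirical Fisher: sum of outer products of per-example gradients
ef_full = Gs.T @ Gs
ef_diag = Gs.pow(2).sum(0)

# Exact GGN: sum_n J_n^T H_n J_n, where H_n = diag(p_n) - p_n p_n^T is the
# Hessian of the cross-entropy w.r.t. the logits
probs = vmap(f, in_dims=(None, 0))(params, X).softmax(-1)
H = torch.diag_embed(probs) - probs.unsqueeze(-1) * probs.unsqueeze(-2)
ggn_full = torch.einsum("ncp,ncd,ndq->pq", Js, H, Js)
ggn_diag = torch.einsum("ncp,ncd,ndp->p", Js, H, Js)
```

The full GGN and EF are P×P matrices and `Js` holds N·C·P numbers, so this direct route is only practical for models with a modest number of parameters.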
Rationale for (1): BackPACK and ASDL are no longer actively maintained and are hard to extend. Note that we use Curvlinops mainly for K-FAC.
Rationale for (2): functorch is a first-class PyTorch module, so it offers the highest compatibility with arbitrary PyTorch models. This future-proofs laplace-torch and reduces its dependence on third-party backends.
Together, these changes make laplace-torch very flexible and compatible with most current (and future) architectures.
With this PR, `Laplace(backend=CurvlinopsBackend)` supports:
- Diag-GGN-exact, diag-GGN-MC, and diag-EF for both classification and regression
- Kron-GGN-exact, kron-GGN-MC, and kron-EF for both classification and regression
- Full-GGN-exact, full-GGN-MC, full-EF, and full-Hessian-exact for both classification and regression
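The GGN-exact, GGN-MC, and EF variants listed above differ only in which outer products are accumulated. Below is a minimal sketch for cross-entropy with `torch.func`, assuming a hypothetical toy classifier. The model, the data, the sample count `S`, and the helper names are all assumptions, not this PR's code, and only the full matrices are formed (the diag variants are their diagonals). The EF uses gradients at the observed labels. GGN-MC uses gradients at labels sampled from the model's own predictive distribution. GGN-exact uses J^T H J. For cross-entropy, E over sampled labels of (p - e_ŷ)(p - e_ŷ)^T equals diag(p) - pp^T, so GGN-MC is an unbiased estimate of GGN-exact, while the EF generally is not.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.func import functional_call, grad, jacrev, vmap

torch.manual_seed(0)

# Toy classifier and data (illustrative only; not laplace-torch code).
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 2))
X, y = torch.randn(5, 3), torch.randint(0, 2, (5,))
params = {k: v.detach() for k, v in model.named_parameters()}
names = list(params)


def flat(tree, lead):
    # Concatenate per-parameter tensors, keeping the first `lead` dims intact.
    return torch.cat([tree[k].flatten(start_dim=lead) for k in names], dim=-1)


def f(p, x):
    # Logits for a single input x of shape (D,) -> (C,)
    return functional_call(model, p, (x.unsqueeze(0),)).squeeze(0)


def loss(p, x, t):
    # Per-example cross-entropy loss
    return F.cross_entropy(f(p, x).unsqueeze(0), t.unsqueeze(0))


def batch_grads(targets, inputs):
    # Per-example gradients (M, P) for the given (input, target) pairs
    g = vmap(grad(loss), in_dims=(None, 0, 0))(params, inputs, targets)
    return flat(g, lead=1)


probs = vmap(f, in_dims=(None, 0))(params, X).softmax(-1)  # (N, C)

# EF: gradients at the observed labels y
G_ef = batch_grads(y, X)
ef = G_ef.T @ G_ef

# GGN-MC: gradients at labels sampled from the model's predictive
S = 4000  # number of Monte Carlo label samples per input (assumed)
y_mc = torch.distributions.Categorical(probs=probs).sample((S,))  # (S, N)
G_mc = batch_grads(y_mc.reshape(-1), X.repeat(S, 1))  # (S*N, P)
ggn_mc = G_mc.T @ G_mc / S

# GGN-exact: sum_n J_n^T H_n J_n with H_n = diag(p_n) - p_n p_n^T
Js = flat(vmap(jacrev(f), in_dims=(None, 0))(params, X), lead=2)
H = torch.diag_embed(probs) - probs.unsqueeze(-1) * probs.unsqueeze(-2)
ggn = torch.einsum("ncp,ncd,ndq->pq", Js, H, Js)

rel_err = (ggn_mc - ggn).norm() / ggn.norm()
print(f"relative error of GGN-MC vs. GGN-exact: {rel_err:.3f}")
```

The Monte Carlo error shrinks roughly like 1/sqrt(S), which is the usual trade-off between the cheaper stochastic GGN and the exact one.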
Note that the previous backends each support only a subset of these.
I'd like to ask @runame to review this, since he's involved in Curvlinops, in particular the K-FAC backend proposed here.