aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License
472 stars, 73 forks

Make `matrix.py` capable of handling weight-sharing dims #253

Open wiseodd opened 1 day ago

wiseodd commented 1 day ago

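Currently `KronLaplace` doesn't work with arbitrary batch shapes, since `matrix.py` only accepts tensors with at most 3 dimensions. This is problematic when multiplying with Jacobians of shape `(..., n_classes, n_params)`, where the leading `...` can contain extra weight-sharing dims (e.g. a sequence dimension in language models).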

See https://github.com/aleximmer/Laplace/blob/glm-multidim/examples/lm_example.py for a use case.

@runame, @aleximmer, any pointers on what to do? Is the fix simply to reshape the Jacobians into `(-1, n_classes, n_params)`?

https://github.com/aleximmer/Laplace/blob/e0a68e56d1f69ed393bbc870bd8da969f0c3d14a/laplace/utils/matrix.py#L208-L215
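
For concreteness, here is a minimal sketch of the reshape idea in plain PyTorch. It is not taken from `laplace`: the helpers `flatten_leading_dims` and `restore_leading_dims` are hypothetical names, and `three_dim_only_op` is just a stand-in for any routine that only accepts 3-dim input. Whether flattening the extra dims is also the right thing semantically for the Kronecker factors is a separate question.

```python
import torch


def flatten_leading_dims(J: torch.Tensor):
    """Collapse all leading dims of J into a single batch dim.

    J has shape (..., n_classes, n_params). Returns the 3-dim tensor and
    the original leading shape so that results can be restored later.
    """
    if J.ndim < 2:
        raise ValueError("expected a tensor of shape (..., n_classes, n_params)")
    lead_shape = J.shape[:-2]
    return J.reshape(-1, *J.shape[-2:]), lead_shape


def restore_leading_dims(out: torch.Tensor, lead_shape: torch.Size):
    """Undo the flattening on a result whose first dim is the flat batch."""
    return out.reshape(*lead_shape, *out.shape[1:])


# Hypothetical stand-in for a routine that rejects inputs with ndim > 3.
def three_dim_only_op(J3: torch.Tensor, M: torch.Tensor):
    assert J3.ndim == 3
    return J3 @ M  # (flat_batch, n_classes, n_params) @ (n_params, n_params)


batch, seq_len, n_classes, n_params = 2, 5, 3, 4
J = torch.randn(batch, seq_len, n_classes, n_params)
M = torch.randn(n_params, n_params)

J3, lead_shape = flatten_leading_dims(J)  # J3: (batch * seq_len, n_classes, n_params)
out = restore_leading_dims(three_dim_only_op(J3, M), lead_shape)
```

Wrapping the existing 3-dim code path this way would leave the current checks untouched; the alternative would be to relax the dimension check itself and rely on broadcasting in the matrix products.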