wiseodd opened 1 day ago
Currently, `KronLaplace` doesn't work with arbitrary batch shapes, since `matrix.py` checks for tensors with at most 3 dimensions. This is problematic when multiplying with Jacobians of shape `(..., n_classes, n_params)`.
See https://github.com/aleximmer/Laplace/blob/glm-multidim/examples/lm_example.py for a use case.
@runame, @aleximmer, any pointers on what to do? Is the fix simply to reshape into `(-1, n_classes, n_params)`?
https://github.com/aleximmer/Laplace/blob/e0a68e56d1f69ed393bbc870bd8da969f0c3d14a/laplace/utils/matrix.py#L208-L215
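A minimal sketch of the reshape idea, assuming the 3-dim-only routine treats each batch element independently. `functional_variance_3d` and `functional_variance_any_batch` are hypothetical names for illustration, not Laplace's API: the former stands in for a product that, like the check in `matrix.py`, rejects anything other than `(B, n_classes, n_params)`, and the latter flattens all leading batch dims into one, calls it, and restores the original batch shape.

```python
import torch


def functional_variance_3d(J: torch.Tensor, Sigma: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a routine that only accepts Jacobians
    of shape (B, n_classes, n_params).

    Returns J @ Sigma @ J^T with shape (B, n_classes, n_classes).
    """
    if J.ndim != 3:
        raise ValueError("Invalid shape for J")
    return J @ Sigma @ J.transpose(-2, -1)


def functional_variance_any_batch(J: torch.Tensor, Sigma: torch.Tensor) -> torch.Tensor:
    """Hypothetical wrapper accepting J of shape (..., n_classes, n_params)."""
    *batch, n_classes, n_params = J.shape
    # (..., K, P) -> (B, K, P) with B = prod(batch); B = 1 if J is 2-dim
    out = functional_variance_3d(J.reshape(-1, n_classes, n_params), Sigma)
    # (B, K, K) -> (..., K, K)
    return out.reshape(*batch, *out.shape[1:])
```

The reshape is only safe if the 3-dim routine never mixes entries across the batch dimension; a product against a fixed `(n_params, n_params)` matrix, as in this sketch, satisfies that.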