aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Optimize `torch.func` full/diag GGN/Fisher-MC #149

Open · wiseodd opened this issue 4 months ago

wiseodd commented 4 months ago

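Currently, we compute the Jacobians of the network outputs with respect to the parameters explicitly. We can improve this by using vector-Jacobian products (VJPs) instead, so the Jacobians never have to be materialized.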
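A minimal sketch of the VJP idea for Fisher-MC, assuming a softmax (cross-entropy) likelihood and a plain `nn.Module` without buffers: each sampled label ŷ gives the cotangent softmax(f) - onehot(ŷ), and one VJP per sample yields the gradient J^T (softmax(f) - onehot(ŷ)) without ever forming J. The toy model, `fisher_mc`, and its arguments are hypothetical names for illustration, not Laplace's API.

```python
import torch
from torch.func import functional_call, vjp, vmap

torch.manual_seed(0)

# Hypothetical toy classifier (no buffers); stands in for any nn.Module.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.Tanh(), torch.nn.Linear(8, 3)
)
params = {k: v.detach() for k, v in model.named_parameters()}
X = torch.randn(16, 4)


def fisher_mc(model, params, X, n_samples=1, diag=True):
    """MC-Fisher for a softmax likelihood from per-sample VJPs (no explicit Jacobians)."""

    def f(p, x):
        # Logits for a single input, as a function of the parameters.
        return functional_call(model, p, (x.unsqueeze(0),)).squeeze(0)

    def per_sample_grad(x, v):
        # One VJP: J(x)^T v, flattened into a single parameter vector.
        _, vjp_fn = vjp(lambda p: f(p, x), params)
        (g,) = vjp_fn(v)
        return torch.cat([t.reshape(-1) for t in g.values()])

    probs = torch.softmax(functional_call(model, params, (X,)), dim=-1)
    out = 0.0
    for _ in range(n_samples):
        # Sample labels from the model's own predictive distribution.
        y_hat = torch.multinomial(probs, 1).squeeze(-1)
        # Gradient of the cross-entropy w.r.t. the logits: softmax(f) - onehot(y_hat).
        v = probs - torch.nn.functional.one_hot(y_hat, probs.shape[-1]).to(probs.dtype)
        G = vmap(per_sample_grad)(X, v)  # (N, P) per-sample gradients
        out = out + ((G**2).sum(0) if diag else G.T @ G)
    return out / n_samples
```

The diagonal only needs the squared per-sample gradients, so each MC draw stores an (N, P) gradient matrix instead of (N, C, P) per-sample Jacobians.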

Reference for full GGN: https://github.com/f-dangel/curvlinops/blob/5852711aedf2728bc609fabfa95eac00da1beb63/curvlinops/examples/functorch.py#L72-L138
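For the full GGN, one possible approach (not necessarily what the linked curvlinops example does) is to factor the softmax cross-entropy Hessian with respect to the logits, H = diag(p) - pp^T, as S S^T with S = diag(√p) - p√p^T. Then J^T S costs C VJPs per sample, and the GGN is the sum over samples of (J^T S)(J^T S)^T. The sketch assumes a classification model without buffers; `ggn_full_vjp`, `f`, and `flat` are hypothetical helper names.

```python
import torch
from torch.func import functional_call, vjp, vmap

torch.manual_seed(0)

# Hypothetical toy classifier (no buffers); stands in for any nn.Module.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.Tanh(), torch.nn.Linear(8, 3)
)
params = {k: v.detach() for k, v in model.named_parameters()}
X = torch.randn(16, 4)


def f(p, x):
    """Logits for a single input, as a function of the parameters."""
    return functional_call(model, p, (x.unsqueeze(0),)).squeeze(0)


def flat(d):
    """Concatenate a dict of tensors into one flat vector."""
    return torch.cat([t.reshape(-1) for t in d.values()])


def ggn_full_vjp(params, X):
    """Full GGN sum_n J_n^T H_n J_n for softmax cross-entropy, without forming J_n."""

    def per_sample_factor(x):
        logits, vjp_fn = vjp(lambda p: f(p, x), params)
        prob = torch.softmax(logits, dim=-1)
        s = prob.sqrt()
        # H = diag(p) - p p^T factors as S S^T with S = diag(sqrt p) - p sqrt(p)^T.
        S = torch.diag_embed(s) - prob[:, None] * s[None, :]
        # J^T S, one VJP per column of S; returned as a (C, P) matrix.
        return vmap(lambda c: flat(vjp_fn(c)[0]))(S.T)

    A = vmap(per_sample_factor)(X)  # (N, C, P)
    A = A.reshape(-1, A.shape[-1])  # stack all N*C rows
    return A.T @ A  # sum over n, c of (J^T s_c)(J^T s_c)^T
```

This uses the same number of VJPs as building J with `torch.func.jacrev`, but folds H into the cotangents, so the explicit Jacobian and the extra J^T H J products are skipped.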

Not a high priority, since KFAC is usually used and the (diag/full) empirical Fisher (EF) implementations are already efficient.