aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Handle logits with more than 2 axes in `torch.func` diag and full backends #178

Open · wiseodd opened this issue 2 months ago

wiseodd commented 2 months ago

Currently, we still assume that `logits.shape == (batch_size, n_classes)`. For example: https://github.com/aleximmer/Laplace/blob/a4d3ed631c87ddb55a1bc7aa053da249180480f7/laplace/curvature/curvature.py#L332-L334
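One common way to lift this assumption is to collapse every leading axis into a single batch axis before the curvature code sees the logits. For example, sequence-model outputs of shape `(batch_size, seq_len, n_classes)` would become `(batch_size * seq_len, n_classes)`. The following is a minimal sketch, not the Laplace implementation. It assumes the class axis is the *last* axis (note that `torch.nn.functional.cross_entropy` instead expects the class axis at dim 1 for inputs with more than two axes) and that targets hold one class index per position. `flatten_logits` is a hypothetical helper name.

```python
from typing import Optional, Tuple

import torch


def flatten_logits(
    logits: torch.Tensor, targets: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Hypothetical helper: merge all leading axes of `logits` into one batch axis.

    Assumes logits has shape (batch_size, *extra_dims, n_classes), i.e. the
    class axis is the LAST axis, and targets has shape (batch_size, *extra_dims).
    """
    if logits.ndim < 2:
        raise ValueError("expected logits with at least 2 axes")
    n_classes = logits.shape[-1]
    # (batch_size, *extra_dims, n_classes) -> (batch_size * prod(extra_dims), n_classes).
    # For 2-axis logits this is a no-op.
    flat_logits = logits.reshape(-1, n_classes)
    if targets is None:
        return flat_logits, None
    # Class-index targets have no class axis, so they flatten to a 1-D vector
    # whose rows line up with the rows of flat_logits.
    flat_targets = targets.reshape(-1)
    if flat_targets.shape[0] != flat_logits.shape[0]:
        raise ValueError("targets must match the leading axes of logits")
    return flat_logits, flat_targets


if __name__ == "__main__":
    logits = torch.randn(4, 10, 5)  # e.g. (batch_size, seq_len, n_classes)
    targets = torch.randint(0, 5, (4, 10))
    flat_logits, flat_targets = flatten_logits(logits, targets)
    print(flat_logits.shape, flat_targets.shape)  # torch.Size([40, 5]) torch.Size([40])
```

Because `reshape` preserves row-major order, row `i * seq_len + j` of the flattened logits is position `(i, j)` of the original tensor. A per-position loss summed over the flattened batch is therefore the same quantity as the loss summed over all axes of the original tensor. Whether the diag and full backends should instead keep the extra axes and treat them explicitly (for example, to normalize per sequence) is a separate design choice that this sketch does not address.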