aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License
436 stars 63 forks

Add `asdl_fisher_kwargs` argument #134

Closed runame closed 11 months ago

runame commented 11 months ago

Add an `asdl_fisher_kwargs` argument so that ASDL's `FisherConfig` values, such as `kfac_linear`, `diag_A/B`, and `autocast`, can be specified.
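As a rough illustration of the idea (not the code in this PR), the sketch below forwards a separate dictionary of Fisher options to a config object. `FisherConfigSketch` and `LaplaceSketch` are hypothetical stand-ins, and the field names `diag_A`, `diag_B`, and all defaults are assumptions read off the description above rather than ASDL's actual `FisherConfig` definition.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class FisherConfigSketch:
    """Hypothetical stand-in for ASDL's FisherConfig (fields and defaults assumed)."""
    kfac_linear: bool = False
    diag_A: bool = False
    diag_B: bool = False
    autocast: bool = False


class LaplaceSketch:
    """Hypothetical stand-in for a Laplace class; not the library's real signature."""

    def __init__(
        self,
        backend_kwargs: Optional[Dict[str, Any]] = None,
        asdl_fisher_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Keep the two dictionaries separate so neither has to know about the other.
        self.backend_kwargs = dict(backend_kwargs or {})
        self.asdl_fisher_kwargs = dict(asdl_fisher_kwargs or {})

    def make_fisher_config(self) -> FisherConfigSketch:
        # Unknown keys raise a TypeError here instead of being silently ignored.
        return FisherConfigSketch(**self.asdl_fisher_kwargs)


la = LaplaceSketch(asdl_fisher_kwargs={"autocast": True})
config = la.make_fisher_config()
```

In this sketch a misspelled option fails loudly when the config is built, which is one possible benefit of passing the dictionary through unchanged.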

@AlexImmer Can you check that my changes in lines 848 and 852 of `baselaplace.py` don't break anything? I'm not sure why this was implemented differently here than in other places.

~Also, I have noticed that the tests are broken, which is really weird since I'm almost certain that I ran them successfully for my last PR. I went back to this commit and they are already broken at this point. We will have to look into this.~

~Edit: Ok, just checked and this PR seems to break a few more tests than the ones that were broken already, so maybe do not merge yet, will fix this tomorrow.~

runame commented 11 months ago

Everything fixed.

wiseodd commented 11 months ago

Would it make sense to just include `asdl_kwargs` as part of `backend_kwargs`?

runame commented 11 months ago

I think it is better to keep them separate; otherwise we would have to add all potential `FisherConfig` args to the backend signatures or extract the args somehow (see this line). But I'm happy to be convinced otherwise if you have something concrete in mind.
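To make the "extract the args somehow" option concrete, here is a minimal sketch of what splitting a merged dictionary could look like. Everything in it is hypothetical: `FisherConfigSketch` stands in for ASDL's `FisherConfig` with assumed fields, and `some_backend_option` is a made-up backend argument.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Tuple


@dataclass
class FisherConfigSketch:
    """Hypothetical stand-in for ASDL's FisherConfig (fields assumed)."""
    kfac_linear: bool = False
    autocast: bool = False


def split_kwargs(merged: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split one merged dict into (backend kwargs, Fisher config kwargs)."""
    config_names = {f.name for f in fields(FisherConfigSketch)}
    fisher = {k: v for k, v in merged.items() if k in config_names}
    backend = {k: v for k, v in merged.items() if k not in config_names}
    return backend, fisher


backend_kwargs, fisher_kwargs = split_kwargs(
    {"some_backend_option": 1, "autocast": True}
)
```

The split depends on the config's field names, so a key shared by both sides, or a typo in a config key, would end up in the backend dictionary without any warning. Keeping the two dictionaries separate avoids that ambiguity.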

wiseodd commented 11 months ago

Then LGTM! @AlexImmer all yours now

aleximmer commented 11 months ago

LGTM. I think we will eventually have to clean it up a little before we bring it into main, though. For example, the ASDL config should be part of the backend config instead of being passed to Laplace, which has to abstract away from the backends.
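One possible shape for that cleanup, sketched under assumptions: the backend owns its config, and the Laplace-level class only passes an opaque config object through. All class names below are hypothetical and do not correspond to existing classes in this repository or in ASDL.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Type


@dataclass
class BackendConfigSketch:
    """Hypothetical backend config that owns the ASDL-specific options."""
    asdl_fisher_kwargs: Dict[str, Any] = field(default_factory=dict)


class BackendSketch:
    """Hypothetical backend; the only place that knows about ASDL options."""

    def __init__(self, config: BackendConfigSketch) -> None:
        self.config = config


class LaplaceSketch:
    """Hypothetical wrapper that never inspects backend-specific settings."""

    def __init__(
        self, backend_cls: Type[BackendSketch], backend_config: BackendConfigSketch
    ) -> None:
        # The config is handed over untouched, so no ASDL argument leaks in here.
        self.backend = backend_cls(backend_config)


la = LaplaceSketch(BackendSketch, BackendConfigSketch({"autocast": True}))
```

With this layout, adding a new backend-specific option would only touch the backend and its config, not the Laplace-level signature.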