aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Generalized dict-input mechanism #165

Closed: wiseodd closed this issue 3 months ago

wiseodd commented 5 months ago

In #144, we added support for dict-like inputs. However, curvature/curvlinops.py and curvature/asdl.py assume that the input tensor is stored under the "input_ids" key. We should improve on this and make the key user-specified (defaulting to "input_ids").

Note that these are the only places where an explicit "input_ids" key is assumed, and there it is used only to compute the batch size of the input tensor. So this should be straightforward to change.
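A minimal sketch of the proposed change follows. It assumes that the curvature backends need the dict input only to read off the batch size. The names `dict_key_x`, `batch_size_of`, and `HypotheticalCurvatureBackend` are illustrative and do not come from the Laplace codebase. Tensors are detected through their `shape` attribute, so the sketch runs without PyTorch.

```python
from collections.abc import Mapping
from typing import Any


def batch_size_of(x: Any, dict_key_x: str = "input_ids") -> int:
    """Return the leading (batch) dimension of ``x``.

    ``x`` may be a tensor-like object with a ``shape`` attribute, a plain
    sequence, or a dict-like batch (e.g. a Hugging Face ``BatchEncoding``,
    which is a Mapping). For dict-like batches, the entry under
    ``dict_key_x`` is used. The key name is a hypothetical parameter.
    """
    if isinstance(x, Mapping):
        if dict_key_x not in x:
            raise KeyError(
                f"dict input has no key {dict_key_x!r}; available keys: {list(x.keys())}"
            )
        x = x[dict_key_x]
    # Tensor-like objects expose their batch dimension as shape[0].
    shape = getattr(x, "shape", None)
    if shape is not None:
        return int(shape[0])
    # Fall back to len() for plain Python sequences.
    return len(x)


class HypotheticalCurvatureBackend:
    """Stand-in for a curvature backend that only needs the batch size."""

    def __init__(self, dict_key_x: str = "input_ids") -> None:
        # User-specified key; the default keeps the old hard-coded behaviour.
        self.dict_key_x = dict_key_x

    def batch_size(self, x: Any) -> int:
        return batch_size_of(x, self.dict_key_x)


if __name__ == "__main__":
    # A vision-style batch whose inputs live under a different key.
    batch = {"pixel_values": [[0.0, 1.0]] * 8, "labels": [0] * 8}
    backend = HypotheticalCurvatureBackend(dict_key_x="pixel_values")
    print(backend.batch_size(batch))
```

Keeping the default at "input_ids" means existing callers behave exactly as before. The explicit `KeyError` lists the available keys, so a misconfigured key fails loudly instead of silently producing a wrong batch size.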