In #144, we added support for dict-like inputs. However, in curvature/curvlinops.py and curvature/asdl.py, the input tensor is assumed to be under the "input_ids" key. We should make the key user-specified (defaulting to "input_ids").
Note that these are the only places where an explicit input_ids key is assumed. It is used only to compute the batch size of the input tensor, so this should be straightforward to change.
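A rough sketch of what the batch-size lookup could look like with a configurable key. The helper name `get_batch_size` and the parameter name `input_key` are hypothetical, not existing names in the codebase; the only assumption about the input is that it exposes `.shape`, as torch tensors do.

```python
from collections.abc import Mapping
from typing import Any


def get_batch_size(x: Any, input_key: str = "input_ids") -> int:
    """Return the batch size of a tensor or of a dict-like batch.

    `input_key` is a hypothetical parameter name; it selects which entry
    of a dict-like batch holds the input tensor.
    """
    if isinstance(x, Mapping):
        if input_key not in x:
            raise KeyError(
                f"Key {input_key!r} not found in input; "
                f"available keys: {list(x.keys())}"
            )
        x = x[input_key]
    # Assumes the batch dimension comes first, as it does for "input_ids".
    return x.shape[0]
```

Keeping "input_ids" as the default means existing callers would not need to change, while something like `get_batch_size(batch, input_key="pixel_values")` would cover batches that use a different key.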