Currently, inside the wrappers, no arguments are passed to the `eqx.filter_` call, so it falls back to its default filter, which (I believe) is `eqx.is_array`. That predicate returns `False` for Python floats, so float leaves get excluded.
We should be able to fix this by defining `is_array_or_float = lambda leaf: isinstance(leaf, float) or eqx.is_array(leaf)` and passing that to the equinox call.