LouisDesdoigts / zodiax

Object-oriented Jax framework extending Equinox for scientific programming
https://louisdesdoigts.github.io/zodiax/
BSD 3-Clause "New" or "Revised" License

Enable gradients through filter function wrt floats #33

Open LouisDesdoigts opened 1 year ago

LouisDesdoigts commented 1 year ago

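Currently, the wrappers call eqx.filter_ without passing any filter arguments, so it falls back to its default filter, which I believe is eqx.is_array. That filter excludes Python floats, so no gradients are taken with respect to float leaves.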

We should be able to fix this by passing a filter that also accepts floats to the Equinox call, for example:
is_array_or_float = lambda leaf: True if isinstance(leaf, float) else eqx.is_array(leaf)