Closed: mdmould closed this 6 months ago
This replaces two instances of the new `jax.tree` API with the equivalent `jax.tree_util` functions. `jax.tree` is only available since `jax==0.4.25`, so the other option would be to require `jax>=0.4.25` in the dependencies.
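For illustration, here is a minimal sketch of this kind of substitution. It assumes the two replaced calls were along the lines of `jax.tree.map` and `jax.tree.leaves`; the description does not say which functions were involved, and the `params` pytree is a hypothetical example. The `jax.tree_util` versions shown also work on JAX releases older than 0.4.25.

```python
import jax.numpy as jnp
from jax import tree_util

# Hypothetical pytree of parameters, used only for illustration.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Before (requires jax>=0.4.25): doubled = jax.tree.map(lambda x: 2 * x, params)
doubled = tree_util.tree_map(lambda x: 2 * x, params)

# Before (requires jax>=0.4.25): leaves = jax.tree.leaves(params)
# Dict keys are flattened in sorted order, so "b" comes before "w".
leaves = tree_util.tree_leaves(params)
```

Both forms return the same results. `jax.tree` is a newer namespace that exposes these `tree_util` functions under shorter names.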