Closed by pmelchior 1 week ago.
This commit should fix the problem reported in #87, which was introduced by a change in jax 0.4.31 and by an earlier change in optax.
Tested on Python 3.10 with jax 0.4.26 and optax 0.2.2, and on Python 3.13 with jax 0.4.34 and optax 0.2.3.