Prsodk closed this 2 weeks ago.
Looks like this is a known issue with Optax that has already been fixed at the latest HEAD: https://github.com/google-deepmind/optax/issues/1093
If this is blocking you, I'd suggest nudging them for a new release :)
Thank you. I have now installed Optax from GitHub, and it works. However, there is some performance regression at the moment: mnist_training.txt
I've seen a few similar reports for recent versions of JAX. If you can produce a minimal working example (MWE), I'd recommend raising the performance difference on the JAX GitHub.
After creating a new environment, I got the following errors. After downgrading to jax 0.4.33, the code worked again. The problem is most likely in Optax 0.2.3 (the latest version available via pip), possibly related to weight decay.
jax_error.txt
I have not reported this to Optax.
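Since the behaviour here depends on the exact JAX and Optax versions (pip release vs. GitHub HEAD), it may help to attach the installed versions to any report. The sketch below is an assumption, not part of the original thread: `report_versions` is a hypothetical helper built on the standard library's `importlib.metadata`. JAX's own `jax.print_environment_info()` gives a more detailed dump if JAX imports cleanly.

```python
import platform
from importlib import metadata


def report_versions(packages=("jax", "jaxlib", "optax")) -> dict:
    """Return installed versions of `packages`, or None for missing ones."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name}: {version or 'not installed'}")
```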
:)