patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0
2.12k stars 142 forks

Incompatibility with jax 0.4.34 (and 0.4.35) #892

Closed: Prsodk closed this issue 2 weeks ago

Prsodk commented 2 weeks ago

After creating a new environment, I got the errors attached below. After downgrading to jax 0.4.33, the code worked again. The problem most likely lies in Optax 0.2.3 (the latest version available via pip), possibly related to weight decay.

jax_error.txt

I have not reported this to Optax.

:)

patrick-kidger commented 2 weeks ago

Looks like this is a known issue with Optax, and it's already fixed on the latest HEAD: https://github.com/google-deepmind/optax/issues/1093

If this is blocking you then I'd suggest nudging them to ask for a new release :)
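Until a new Optax release is out, a small guard can flag environments that match the combination reported in this thread (JAX 0.4.34 or 0.4.35 with Optax 0.2.3 from PyPI). This is a minimal sketch, not part of Equinox or Optax: the helper names `is_reported_bad_combo` and `installed_version` are hypothetical, and the version sets contain only what this issue reports. It uses only the standard library's `importlib.metadata`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# Hypothetical guard. The version sets hold only the combination reported in
# this issue; other JAX/Optax versions are not covered here either way.
REPORTED_BAD_JAX = {"0.4.34", "0.4.35"}
REPORTED_BAD_OPTAX = {"0.2.3"}


def is_reported_bad_combo(jax_version: str, optax_version: str) -> bool:
    """True only if both version strings exactly match the reported combination."""
    return jax_version in REPORTED_BAD_JAX and optax_version in REPORTED_BAD_OPTAX


def installed_version(dist_name: str) -> Optional[str]:
    """Installed version string of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    jax_v = installed_version("jax")
    optax_v = installed_version("optax")
    print(f"jax={jax_v} optax={optax_v}")
    if jax_v and optax_v and is_reported_bad_combo(jax_v, optax_v):
        # Workarounds mentioned in this thread: pin jax==0.4.33, or install
        # Optax from its GitHub HEAD.
        print("Warning: this jax/optax combination was reported as broken.")
```

Exact string matching keeps the check conservative: it flags only the versions named in this thread and makes no claim about other releases.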

Prsodk commented 2 weeks ago

Thank you. I have now installed Optax from GitHub, and it works. However, there is some performance regression at the moment: mnist_training.txt

patrick-kidger commented 2 weeks ago

I've seen a few such reports for recent versions of JAX. If you can produce an MWE (minimal working example), then I'd recommend raising the performance difference on the JAX GitHub.
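A performance MWE is most useful when it runs the same timing code under each JAX version being compared (for example 0.4.33 vs 0.4.35). The sketch below assumes a median-of-N wall-clock measurement after a few warm-up calls; `time_fn` is a hypothetical helper, and the jitted function in the demo is an arbitrary stand-in, not the MNIST training code from this thread. Warm-up calls absorb JIT compilation, and `jax.block_until_ready` is used because JAX dispatches work asynchronously, so timing without it can measure only the dispatch.

```python
import statistics
import time


def time_fn(fn, *args, n_warmup=3, n_iters=20):
    """Median wall-clock seconds per call of fn(*args) (hypothetical helper).

    Warm-up calls absorb one-off costs such as JIT compilation. If JAX is
    installed, each result is passed to jax.block_until_ready so that
    asynchronous dispatch does not make calls look faster than they are.
    """
    try:
        import jax
        sync = jax.block_until_ready
    except ImportError:
        sync = lambda x: x  # Plain Python: results are already computed.
    for _ in range(n_warmup):
        sync(fn(*args))
    times = []
    for _ in range(n_iters):
        start = time.perf_counter()
        sync(fn(*args))
        times.append(time.perf_counter() - start)
    return statistics.median(times)


if __name__ == "__main__":
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        print("JAX not installed; skipping demo.")
    else:
        # Arbitrary stand-in workload; replace with the code that regressed.
        step = jax.jit(lambda x: jnp.tanh(x @ x).sum())
        x = jnp.ones((512, 512))
        print("jax", jax.__version__, "backend", jax.default_backend())
        print(f"median time per call: {time_fn(step, x) * 1e3:.3f} ms")
```

Reporting the printed JAX version and backend alongside the timings makes it easier to compare results across environments.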