google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Failure when running tests (due to new jax release?) #904

Closed: fabian-sp closed this issue 3 months ago

fabian-sp commented 3 months ago

While developing my branch, I noticed that since today the tests no longer pass on my machine. I ran the tests successfully two days ago, so I suspect the failure comes from a change in JAX (a new version was released on April 03).
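To correlate a failure like this with a specific JAX release, it helps to record the exact package versions in the environment. The sketch below is a hypothetical helper (`installed_versions` is not part of optax or JAX); it only uses the standard-library `importlib.metadata` module and reports `None` for packages that are not installed.

```python
import importlib.metadata


def installed_versions(packages=("jax", "jaxlib", "optax")):
    """Hypothetical helper: return {package: version string or None}."""
    versions = {}
    for pkg in packages:
        try:
            versions[pkg] = importlib.metadata.version(pkg)
        except importlib.metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[pkg] = None
    return versions


if __name__ == "__main__":
    for pkg, version in installed_versions().items():
        print(f"{pkg}: {version}")
```

Pasting this output into a bug report makes it easier for maintainers to see which JAX version triggered the failure.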

When running test.sh, I get the following error message:

************* Module clipping
optax/_src/clipping.py:153:16: E1102: jnp.greater is not callable (not-callable)
optax/_src/clipping.py:235:17: E1102: jnp.greater is not callable (not-callable)
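These are pylint `not-callable` (E1102) messages, not runtime failures: they suggest pylint could no longer infer that `jnp.greater` is a function after the JAX update. The thread does not say what #908 changed, so the sketch below only illustrates two common workarounds under that assumption: use the `>` operator, which computes the same elementwise comparison for JAX arrays, or keep `jnp.greater` and silence the false positive with a line-scoped `# pylint: disable=not-callable`. The helper `clip_by_global_norm_sketch` is hypothetical and is not optax's actual `clipping.py` code.

```python
import jax
import jax.numpy as jnp


def clip_by_global_norm_sketch(updates, max_norm):
    """Hypothetical helper: rescale a pytree so its global L2 norm is <= max_norm."""
    leaves = jax.tree_util.tree_leaves(updates)
    g_norm = jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))
    # Workaround 1: the `>` operator is equivalent to jnp.greater(g_norm, max_norm)
    # for arrays, and pylint does not report it as a call on a non-callable.
    trigger = g_norm > max_norm
    # Workaround 2 (alternative): keep jnp.greater and disable the check locally:
    # trigger = jnp.greater(g_norm, max_norm)  # pylint: disable=not-callable
    # Divide by max(g_norm, max_norm) so updates are only ever scaled down.
    denom = jnp.where(trigger, g_norm, max_norm)
    scale = max_norm / denom
    return jax.tree_util.tree_map(lambda x: x * scale, updates)
```

Using the operator changes only how the comparison is spelled; a scoped disable comment instead keeps the original call and documents that the warning is a false positive.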

I merged the latest optax main into my branch, so the failure should not be caused by my branch being out of date.

vroulet commented 3 months ago

Hello @fabian-sp,

I've upstreamed the bug internally to the JAX team and will keep you posted. Thanks for all the work you've done on Momo. Looking forward to benchmarking it!

vroulet commented 3 months ago

Fixed by #908. Thanks again!