google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0
1.65k stars, 181 forks

[JAX] Update users of jax.tree.map() to be more careful about how they handle Nones. #983

Closed by copybara-service[bot] 3 months ago

copybara-service[bot] commented 3 months ago

[JAX] Update users of jax.tree.map() to be more careful about how they handle Nones.

Due to a bug, JAX previously permitted jax.tree.map(f, None, x) where x is not None, effectively treating None as if it were a pytree prefix of any value. But None is a pytree container (an empty one, with no leaves), so it is only a prefix of None itself.

Fix user code that was relying on this bug. Most commonly, the fix is to write jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None).
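As an illustration of the fix described above, here is a minimal sketch. The helper name map_skipping_none and the example trees params and updates are hypothetical and not part of this change; the sketch assumes a JAX version that provides jax.tree.map.

```python
import jax


def map_skipping_none(f, x, y):
    """Hypothetical helper: apply f leaf-wise, keeping None wherever x has None."""
    return jax.tree.map(
        # Propagate None explicitly instead of relying on prefix matching.
        lambda a, b: None if a is None else f(a, b),
        x,
        y,
        # Treat None as a leaf so the mapped function gets to see it.
        is_leaf=lambda t: t is None,
    )


# Example trees (assumed for illustration): "b" is None in the first tree only.
params = {"w": 1.0, "b": None}
updates = {"w": 2.0, "b": 3.0}

result = map_skipping_none(lambda a, b: a + b, params, updates)
print(result)
```

Because is_leaf marks None as a leaf of the first tree, the matching entry of the second tree is passed to the lambda as is, and the lambda decides to return None rather than calling f.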