[JAX] Update users of jax.tree.map() to be more careful about how they handle Nones.
Due to a bug, JAX previously permitted jax.tree.map(f, None, x) where x is not None, effectively treating None as if it were a pytree prefix of any value. But None is a pytree container (one with no children), and it is a prefix only of None itself.
Fix user code that was relying on this bug. Most commonly, the fix is to write
jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None).
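A minimal sketch of that pattern, assuming a JAX version that provides jax.tree.map (0.4.25 or newer). The helper name add_where_present and the params/grads values are hypothetical, chosen only for illustration. Passing is_leaf=lambda t: t is None makes each None in the first tree a leaf. The matching subtree of the second tree, whatever its structure, is then handed to the lambda as b, rather than JAX requiring a None at that position.

```python
import jax


def add_where_present(x, y):
    """Hypothetical helper: add y's leaves to x's, keeping None wherever x has None."""
    return jax.tree.map(
        # Where x holds None, propagate None; otherwise combine the two leaves.
        lambda a, b: None if a is None else a + b,
        x,
        y,
        # Treat None in x as a leaf, so y's subtree at that position becomes b.
        is_leaf=lambda t: t is None,
    )


params = {"w": 1.0, "b": None}  # hypothetical: "b" is absent (None) in params
grads = {"w": 2.0, "b": 3.0}    # hypothetical: grads has a value at "b"
result = add_where_present(params, grads)
# result == {"w": 3.0, "b": None}
```

Because None in x is a leaf, the same call also works when y has an arbitrary subtree (for example a tuple) at that position.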