google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Replace deprecated `jax.tree_*` functions with `jax.tree.*` #963

Closed: copybara-service[bot] closed this pull request 1 month ago

copybara-service[bot] commented 1 month ago

Replace deprecated `jax.tree_*` functions with `jax.tree.*`

The top-level `jax.tree_*` aliases have long been deprecated and will soon be removed. The replacement APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, which was added in JAX version 0.4.25.
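For illustration, here is a minimal sketch of what this kind of migration looks like at a call site, assuming JAX 0.4.25 or newer. It is not the actual diff from this PR. The `params` pytree and the scaling lambda are hypothetical examples. The deprecated spellings appear only in comments, because recent JAX releases may no longer provide them.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree, similar in shape to what an Optax optimizer consumes.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

# Deprecated:              jax.tree_map(fn, tree)
# Long-form replacement:   jax.tree_util.tree_map(fn, tree)
# Short alias (>= 0.4.25): jax.tree.map(fn, tree)
scaled = jax.tree.map(lambda x: 0.5 * x, params)
scaled_util = jax.tree_util.tree_map(lambda x: 0.5 * x, params)

# Deprecated: jax.tree_leaves(tree)  ->  jax.tree.leaves(tree)
leaves = jax.tree.leaves(params)

# Deprecated: jax.tree_flatten / jax.tree_unflatten
#          -> jax.tree.flatten / jax.tree.unflatten
flat, treedef = jax.tree.flatten(params)
rebuilt = jax.tree.unflatten(treedef, flat)
```

The `jax.tree.*` and `jax.tree_util.tree_*` forms take the same arguments and produce the same results. The choice between them mostly comes down to brevity versus compatibility with JAX versions older than 0.4.25.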