google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Extend capabilities of tree_get, tree_set. #878

Closed: copybara-service[bot] closed this 4 months ago

copybara-service[bot] commented 4 months ago

Extend capabilities of tree_get, tree_set.

  1. Enable tree_get, tree_set to filter on the name of a named tuple in the path to a key (and hence on the name of a state in a chained transformation). This makes it possible to distinguish attributes that have the same name in two different states, where only the names of the states differ.

  2. Enable tree_get, tree_set to fetch or set named tuples by the name of the named tuple. This is handy for fetching a given state from the overall state of a chained optimizer.
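The two capabilities above can be illustrated with a small plain-Python sketch. It is not optax's implementation and does not reproduce the signatures of its `tree_get`/`tree_set`: `tree_get_sketch`, `tree_set_sketch`, the `state_name` parameter and the two stand-in state classes are hypothetical names chosen for illustration. The chained state is modeled as a plain tuple of named tuples, and raising `KeyError` on an ambiguous lookup is an assumption of the sketch.

```python
from typing import Any, NamedTuple, Optional


# Hypothetical stand-ins named after two optax states. Both carry a `count`
# field, so looking up "count" by field name alone is ambiguous.
class ScaleByAdamState(NamedTuple):
    count: int
    mu: float
    nu: float


class ScaleByScheduleState(NamedTuple):
    count: int


def _is_namedtuple(node: Any) -> bool:
    return isinstance(node, tuple) and hasattr(node, "_fields")


def _iter_paths(tree: Any, path: tuple = ()):
    """Yield (path, node) pairs; each path entry is (owner_name, key).

    owner_name is the named tuple's class name when key is one of its fields,
    and None for plain tuple/list indices or dict keys.
    """
    yield path, tree
    if _is_namedtuple(tree):
        owner = type(tree).__name__
        for field in tree._fields:
            yield from _iter_paths(getattr(tree, field), path + ((owner, field),))
    elif isinstance(tree, (tuple, list)):
        for i, child in enumerate(tree):
            yield from _iter_paths(child, path + ((None, i),))
    elif isinstance(tree, dict):
        for k, child in tree.items():
            yield from _iter_paths(child, path + ((None, k),))


def tree_get_sketch(tree: Any, key: str, state_name: Optional[str] = None,
                    default: Any = None) -> Any:
    """Return the single entry matching `key`, or `default` if none matches."""
    matches = []
    for path, node in _iter_paths(tree):
        # Capability 2: `key` names a named-tuple type; return the whole state.
        if _is_namedtuple(node) and type(node).__name__ == key:
            matches.append(node)
            continue
        if not path:
            continue
        owner, field = path[-1]
        # Capability 1: `key` names a field, optionally filtered by owner name.
        if owner is not None and field == key and state_name in (None, owner):
            matches.append(node)
    if not matches:
        return default
    if len(matches) > 1:
        raise KeyError(f"{key!r} is ambiguous: {len(matches)} matches")
    return matches[0]


def tree_set_sketch(tree: Any, key: str, value: Any,
                    state_name: Optional[str] = None) -> Any:
    """Return a copy of `tree` with every matching entry replaced by `value`."""
    if _is_namedtuple(tree):
        name = type(tree).__name__
        if name == key:  # Capability 2: replace a whole state by its type name.
            return value
        updates = {}
        for field in tree._fields:
            child = getattr(tree, field)
            if field == key and state_name in (None, name):  # Capability 1
                updates[field] = value
            else:
                updates[field] = tree_set_sketch(child, key, value, state_name)
        return tree._replace(**updates)
    if isinstance(tree, (tuple, list)):
        return type(tree)(tree_set_sketch(c, key, value, state_name) for c in tree)
    if isinstance(tree, dict):
        return {k: tree_set_sketch(c, key, value, state_name) for k, c in tree.items()}
    return tree


if __name__ == "__main__":
    # A chained state modeled as a tuple of per-transformation states.
    state = (ScaleByAdamState(count=3, mu=0.5, nu=0.25), ScaleByScheduleState(count=7))
    print(tree_get_sketch(state, "count", state_name="ScaleByScheduleState"))  # 7
    print(tree_get_sketch(state, "ScaleByAdamState"))
    print(tree_set_sketch(state, "count", 0, state_name="ScaleByAdamState"))
```

Without `state_name`, `tree_get_sketch(state, "count")` finds two `count` fields and raises, which is exactly the ambiguity the name filter resolves. `tree_set_sketch` never mutates its input; it rebuilds the tuples, using `_replace` for named tuples.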