google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Add filtering option to tree_get and tree_set. #871

Closed: copybara-service[bot] closed this 4 months ago.

copybara-service[bot] commented 4 months ago

Add filtering option to tree_get and tree_set.

Enables passing an optional callable that further filters which values matching the given key are retrieved or updated.
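Here is a rough illustration of what such a filter adds. The following is a pure-Python sketch with no JAX dependency, walking nested dicts, tuples and NamedTuples. The names `tree_get_sketch` and `tree_set_sketch`, the `filtering(path, value) -> bool` predicate shape, and the toy `AdamState`/`ScheduleState` classes are all hypothetical. They are not the optax API, and the actual signatures in `optax.tree_utils` may differ. The example covers the motivating case: in a chained optimizer state the key `count` appears twice, so a plain lookup is ambiguous until a filter selects one of the matches.

```python
from typing import Any, Callable, NamedTuple, Optional

# Hypothetical predicate shape (an assumption, not the optax signature):
# it receives the path to a matching node (a tuple of keys) and its value.
Filter = Callable[[tuple, Any], bool]


def _children(node):
    """Return (key, child) pairs for dicts, NamedTuples, tuples and lists."""
    if isinstance(node, dict):
        return list(node.items())
    if isinstance(node, tuple) and hasattr(node, "_fields"):  # NamedTuple
        return [(f, getattr(node, f)) for f in node._fields]
    if isinstance(node, (tuple, list)):
        return list(enumerate(node))
    return []  # leaf


def tree_get_sketch(tree, key, default=None, filtering: Optional[Filter] = None):
    """Return the unique value stored under `key` that passes `filtering`."""
    matches = []

    def visit(node, path):
        for k, child in _children(node):
            child_path = path + (k,)
            if k == key and (filtering is None or filtering(child_path, child)):
                matches.append(child)
            visit(child, child_path)

    visit(tree, ())
    if not matches:
        return default
    if len(matches) > 1:
        # Design choice of this sketch: ambiguity is an error, not a guess.
        raise KeyError(f"{len(matches)} values found for {key!r}; add a filter.")
    return matches[0]


def tree_set_sketch(tree, filtering: Optional[Filter] = None, **kwargs):
    """Return a copy of `tree` with keys in `kwargs` replaced where `filtering` allows."""

    def replace(k, child, path):
        child_path = path + (k,)
        if isinstance(k, str) and k in kwargs and (
            filtering is None or filtering(child_path, child)
        ):
            return kwargs[k]
        return rebuild(child, child_path)

    def rebuild(node, path):
        if isinstance(node, dict):
            return {k: replace(k, v, path) for k, v in node.items()}
        if isinstance(node, tuple) and hasattr(node, "_fields"):
            return type(node)(**{f: replace(f, getattr(node, f), path) for f in node._fields})
        if isinstance(node, (tuple, list)):
            return type(node)(replace(i, v, path) for i, v in enumerate(node))
        return node  # leaf: unchanged

    return rebuild(tree, ())


# Hypothetical stand-ins for the states of two chained transformations.
class AdamState(NamedTuple):
    count: int
    mu: dict
    nu: dict


class ScheduleState(NamedTuple):
    count: int


state = (AdamState(count=3, mu={"w": 0.1}, nu={"w": 0.01}), ScheduleState(count=7))

# Filter on the path: take `count` from the second element of the chain.
schedule_count = tree_get_sketch(state, "count", filtering=lambda path, _: path[0] == 0 + 1)
# Filter on the value instead.
large_count = tree_get_sketch(state, "count", filtering=lambda _, v: v > 5)
# Update only the first `count`, leaving the other untouched.
new_state = tree_set_sketch(state, filtering=lambda path, _: path[0] == 0, count=10)
```

Without a filter, `tree_get_sketch(state, "count")` raises `KeyError` here because two values match; the filter is what makes the lookup well defined. `tree_set_sketch` returns a new structure and never mutates the input, in keeping with JAX's functional style.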