google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Add optax.tree_utils.tree_random_split. #1063

Closed: carlosgmartin closed this pull request 2 weeks ago.

carlosgmartin commented 2 weeks ago

This exposes the formerly private function _tree_rng_keys_split in optax/tree_utils/_random.py to the public API.

I've found it to be a useful helper for manipulating random trees, and I intend to use it in future PRs.
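The PR does not show the function body. The sketch below assumes the helper takes a single PRNG key and a target pytree, and returns a pytree of independent keys with the same structure (one key per leaf). It uses only plain JAX calls. The name tree_random_split_sketch and the params example are hypothetical illustrations, not the optax API, and the real signature may differ.

```python
import jax
import jax.numpy as jnp


def tree_random_split_sketch(key, target_tree):
    """Hypothetical sketch, not the optax implementation.

    Splits one PRNG key into a pytree of keys that has the same structure
    as `target_tree`, with an independent key at each leaf position.
    """
    treedef = jax.tree_util.tree_structure(target_tree)
    # Draw one fresh key per leaf of the target tree.
    keys = jax.random.split(key, treedef.num_leaves)
    # list(keys) yields one key per leaf, matching the flattened leaf order.
    return jax.tree_util.tree_unflatten(treedef, list(keys))


# Hypothetical parameter tree used only for illustration.
params = {"w": jnp.zeros((3, 2)), "b": jnp.zeros((2,))}
key_tree = tree_random_split_sketch(jax.random.PRNGKey(0), params)

# Use the per-leaf keys to draw Gaussian noise shaped like each parameter.
noise = jax.tree_util.tree_map(
    lambda k, p: jax.random.normal(k, p.shape, p.dtype), key_tree, params
)
```

Returning keys in the same structure as the target tree lets them be zipped leaf-by-leaf with tree_map, which is the kind of random-tree manipulation the comment refers to.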

vroulet commented 2 weeks ago

Thanks for doing that; it's a good idea, and I'm happy to know it can be useful. (Sorry for the incremental review; I didn't mean to do it like that.)