google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0
2.91k stars 231 forks

haiku: fully support JAX typed PRNG keys #736

Closed: copybara-service[bot] closed this 1 year ago

copybara-service[bot] commented 1 year ago

haiku: fully support JAX typed PRNG keys

For more details, see https://github.com/google/jax/pull/17297. Previously, we had imagined a world where the jax_enable_custom_prng flag globally determined the presence of typed keys. This proved untenable for a number of reasons. Going forward, old-style and new-style keys are expected to exist side by side regardless of the value of jax_enable_custom_prng, and that flag will soon be deprecated. Eventually old-style keys will also be deprecated and removed.
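
To illustrate what "side by side" means in practice, here is a minimal sketch using only public JAX APIs. It is not taken from this change. It assumes a recent JAX release with the default PRNG implementation and `jax_enable_custom_prng` left unset. The helper `is_typed_key` is a hypothetical name introduced for illustration.

```python
import jax
import jax.numpy as jnp

# Old-style key: a raw uint32 array (shape (2,) with the default implementation).
old_key = jax.random.PRNGKey(0)

# New-style typed key: a scalar array whose dtype is a dedicated PRNG key dtype.
new_key = jax.random.key(0)


def is_typed_key(k):
    # Hypothetical helper: True for new-style typed keys, False for raw uint32 keys.
    return jax.dtypes.issubdtype(k.dtype, jax.dtypes.prng_key)


# The two styles can be converted into each other explicitly.
raw_from_new = jax.random.key_data(new_key)         # typed key -> raw uint32 data
typed_from_old = jax.random.wrap_key_data(old_key)  # raw uint32 data -> typed key

# Both styles built from the same seed carry the same underlying bits ...
same_bits = bool(jnp.array_equal(raw_from_new, old_key))

# ... and are both accepted by jax.random samplers, giving the same samples.
same_samples = bool(
    jnp.array_equal(
        jax.random.normal(old_key, (3,)),
        jax.random.normal(new_key, (3,)),
    )
)
```

Code that must accept either style can branch on a dtype check like the one above rather than on the global flag. Whether Haiku does exactly this internally is not stated here.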