Closed by copybara-service[bot] 7 months ago.
[dinosaur] Add prng_step to RandomnessState
This makes it easier to avoid the pattern of iteratively splitting the same key, which has poor statistical properties. The recommended pattern for generating a new PRNG key is `jax.random.fold_in(state.prng_key, state.prng_step)`.
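A minimal sketch of the recommended pattern, assuming a simplified stand-in for dinosaur's `RandomnessState`. Only the `prng_key` and `prng_step` fields come from this PR; the dataclass itself and the helpers `next_key` and `advance` are hypothetical and may not match the real implementation. The base key stays fixed and only the step counter changes, so each step's key is derived with `jax.random.fold_in` instead of by splitting the previous key.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical stand-in for dinosaur's RandomnessState; the real class likely
# carries additional fields. Only prng_key and prng_step are taken from the PR.
@dataclasses.dataclass(frozen=True)
class RandomnessState:
  prng_key: jax.Array
  prng_step: int


def next_key(state: RandomnessState) -> jax.Array:
  # Derive a fresh key from the fixed base key and the step counter,
  # rather than repeatedly splitting a key carried through the state.
  return jax.random.fold_in(state.prng_key, state.prng_step)


def advance(state: RandomnessState) -> RandomnessState:
  # Increment only the step; the base key is never modified.
  return dataclasses.replace(state, prng_step=state.prng_step + 1)


state = RandomnessState(prng_key=jax.random.PRNGKey(0), prng_step=0)
samples = []
for _ in range(3):
  key = next_key(state)
  samples.append(jax.random.normal(key, (2,)))
  state = advance(state)
```

Because each key is a pure function of `(prng_key, prng_step)`, the randomness for any step can be reproduced directly, without replaying the earlier steps.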