blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Replace iterative RNG split and carry with `jax.random.fold_in` #656

Closed. junpenglao closed this 6 months ago

junpenglao commented 6 months ago

This follows the new recommendation on RNG usage within a loop.

import jax
key = jax.random.key(42)

Main pattern: Instead of doing

for step in range(10000000):
  key, subkey = jax.random.split(key, 2)
  values = jax.random.normal(subkey, [3, 5])
  # ... use values ...

do

for step in range(10000000):
  subkey = jax.random.fold_in(key, step)
  values = jax.random.normal(subkey, [3, 5])
  # ... use values ...

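The main concern is that split-then-carry within a loop introduces an unnecessary statistical inefficiency. With repeated splitting, each new key is generated by iterating the hash on the previous key with a fixed message, so the resulting sequence might have a shorter period than you expect. With `jax.random.fold_in`, the base key stays fixed and the step index is the message that gets hashed in, so each step's key is derived directly from the base key rather than from the previous step's key.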
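For compiled loops, the same pattern can be sketched with `jax.lax.scan`. This is only an illustration, not the code changed in this PR: `sample_trajectory` and `one_step` are hypothetical names, and the running sum stands in for whatever state a real kernel would carry.

```python
import jax
import jax.numpy as jnp


def sample_trajectory(key, num_steps):
    """Draw one normal value per step, deriving each step's key with fold_in."""

    def one_step(carry, step):
        # The base key is never split or carried; the step index is hashed into it.
        subkey = jax.random.fold_in(key, step)
        value = jax.random.normal(subkey)
        return carry + value, value

    # Scan over the step indices 0..num_steps-1, carrying only the running sum.
    total, values = jax.lax.scan(one_step, jnp.zeros(()), jnp.arange(num_steps))
    return total, values


key = jax.random.key(42)
total, values = sample_trajectory(key, 5)

# The draw for any step can be recomputed from the base key and the index alone,
# without replaying the earlier steps.
value_3 = jax.random.normal(jax.random.fold_in(key, 3))
```

Because the key is no longer part of the loop carry, the carry holds only the actual state, and a single step can be reproduced in isolation when debugging.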

Also changes `jax.tree_map` to `jax.tree.map`.
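As a quick illustration of the rename, here is a minimal sketch. The `state` dictionary is a hypothetical pytree, not a BlackJAX state class, and it assumes a JAX version recent enough to ship the `jax.tree` module.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree standing in for a sampler state.
state = {"position": jnp.ones(3), "momentum": jnp.zeros(3)}

# jax.tree.map applies the function to every leaf and keeps the tree structure.
# It takes the same arguments as the older top-level jax.tree_map alias.
doubled = jax.tree.map(lambda x: 2.0 * x, state)
```

The call signature is unchanged, so the migration is a mechanical rename.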

codecov[bot] commented 6 months ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 98.87%. Comparing base (7cf4f9d) to head (b1ca8a1).

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #656      +/-   ##
==========================================
- Coverage   98.87%   98.87%   -0.01%
==========================================
  Files          59       59
  Lines        2745     2744       -1
==========================================
- Hits         2714     2713       -1
  Misses         31       31
```
