liesel-devs / liesel

A probabilistic programming framework
https://liesel-project.org
MIT License

Allow passing a `jax.random.PRNGKey` to class `EngineBuilder` #157

Closed: wiep closed this issue 8 months ago

wiep commented 8 months ago

Discussed in https://github.com/liesel-devs/liesel/discussions/152

Originally posted by **Seb-Lorek** November 14, 2023

Hey `liesel` devs. I'm currently pondering how to safely use parallel pseudo-random number generation in `JAX` for a simulation study. In a simulation study, it is generally important to have repeatable pseudo-random numbers across multiple processes (local or distributed) such that the streams are independent. In `numpy` this can be guaranteed (with very high probability) by [spawning](https://numpy.org/doc/stable/reference/random/parallel.html#seedsequence-spawning). In `JAX` we can use either `jax.random.split` or `jax.random.fold_in` to safely produce (pseudo-)independent streams (directly, or indirectly via further key derivation). However, both take a `jax.random.PRNGKey` as input, not a **seed**.

Creating an iterable/list of integers and iterating over them is not safe; this corresponds to the following:

```python
import jax

for i in range(n):  # n: number of replications; NOT a safe way to get independent streams
    rng = jax.random.PRNGKey(i)
```

See [#18211](https://github.com/google/jax/discussions/18211) and [here](https://stackoverflow.com/questions/75338838/jax-best-way-to-iterate-rngkeys) for some discussions.

In order to reliably produce independent streams within a simulation study, it would be important to have the option of passing a `jax.random.PRNGKey` directly to the `EngineBuilder`. As far as I can see, this should be quite easy to implement.
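A minimal sketch of the key-derivation pattern described above, assuming a single master seed for the whole study. The helper names `replication_keys_split`, `replication_key_fold_in` and `simulate_replication` are hypothetical illustrations, not part of `liesel` or `JAX`. `jax.random.split` derives all replication keys at once, while `jax.random.fold_in` derives the key for replication `i` on its own. The returned `mcmc_key` only stands in for the key one would hand to `EngineBuilder` if it accepted a `jax.random.PRNGKey`, as requested in this issue; the sketch does not call `liesel`.

```python
import jax
import jax.numpy as jnp


def replication_keys_split(seed: int, n: int) -> jax.Array:
    """Derive n keys, one per replication, from a single master seed via split."""
    master = jax.random.PRNGKey(seed)
    return jax.random.split(master, n)


def replication_key_fold_in(seed: int, i: int) -> jax.Array:
    """Derive the key for replication i alone by folding the index into the master key."""
    master = jax.random.PRNGKey(seed)
    return jax.random.fold_in(master, i)


def simulate_replication(key: jax.Array, n_obs: int = 100):
    """Hypothetical replication: one subkey for the data, another kept for MCMC."""
    data_key, mcmc_key = jax.random.split(key)
    y = jax.random.normal(data_key, (n_obs,))
    # mcmc_key is the key one would pass to EngineBuilder if it accepted a PRNGKey
    # (the feature requested in this issue); liesel is not called here.
    return y, mcmc_key


keys = replication_keys_split(seed=2023, n=4)
for i, key in enumerate(keys):
    y, mcmc_key = simulate_replication(key)
    print(i, float(jnp.mean(y)))
```

With `fold_in`, a worker process only needs the master seed and its replication index `i`, so no keys have to be communicated between processes.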