emilyfertig opened 4 hours ago
Parts of the docs (https://jax.readthedocs.io/en/latest/key-concepts.html#pseudorandom-numbers, https://jax.readthedocs.io/en/latest/random-numbers.html) imply that NumPy's PRNG requires global state. NumPy currently recommends passing around a local PRNG state (a `Generator` object whose state is updated implicitly on each draw), so we should update our docs to make the comparison fairer. (See https://github.com/jax-ml/jax/pull/24917#discussion_r1844596185.)
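For reference, here is a minimal sketch of the two NumPy styles the docs could contrast. The legacy module-level functions (`np.random.seed`, `np.random.normal`) draw from a hidden global state. The recommended style creates a local `Generator` with `np.random.default_rng` and passes it around, and its state still advances implicitly on each call. The helper `add_noise` is hypothetical and exists only for illustration. In JAX, by contrast, the key is an explicit value: reusing a key reproduces the same numbers, and fresh randomness requires `jax.random.split`.

```python
import numpy as np

# Legacy style: module-level functions draw from a hidden global RandomState.
np.random.seed(0)
legacy_draw = np.random.normal(size=3)

# Recommended style: create a local Generator and pass it around explicitly.
# `add_noise` is a hypothetical helper used only for this illustration.
def add_noise(x, rng):
    # rng's internal state advances in place on every call; nothing new is returned.
    return x + rng.normal(size=x.shape)

rng = np.random.default_rng(0)
x = np.zeros(3)
a = add_noise(x, rng)
b = add_noise(x, rng)  # differs from `a` because rng's state was updated implicitly

# Re-creating a Generator with the same seed reproduces the sequence.
rng2 = np.random.default_rng(0)
a2 = add_noise(x, rng2)
```

The local `Generator` removes the global state but keeps implicit mutation, which is the point of difference from JAX's explicit, functional keys.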
I found the original issue: #11026