jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Rework comparisons in the docs of JAX vs. NumPy PRNG #24927

Open emilyfertig opened 4 hours ago

emilyfertig commented 4 hours ago

Parts of the docs (https://jax.readthedocs.io/en/latest/key-concepts.html#pseudorandom-numbers, https://jax.readthedocs.io/en/latest/random-numbers.html) imply that NumPy's PRNG requires global state. NumPy currently recommends passing around a local PRNG state (which is updated implicitly), so we should update our docs to make the comparison fairer (see https://github.com/jax-ml/jax/pull/24917#discussion_r1844596185).
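For reference, here is a minimal sketch of the two NumPy styles this comparison turns on. It assumes NumPy 1.17 or later, where `np.random.default_rng` and `Generator` are available. The helper `sample_noise` is hypothetical and exists only for illustration. The legacy module-level functions draw from a hidden global state. The recommended style creates a local `Generator` and passes it around explicitly, and each draw mutates that object in place. In JAX, by contrast, a key is an immutable value that the caller splits explicitly (e.g. with `jax.random.split`), so no state is updated behind the scenes.

```python
import numpy as np

# Legacy style: module-level functions draw from a hidden global RandomState.
np.random.seed(0)
legacy_draw = np.random.random(3)

# Recommended style: a local Generator passed in explicitly.
# `sample_noise` is a hypothetical helper used only for illustration.
def sample_noise(rng: np.random.Generator, n: int) -> np.ndarray:
    # Each call advances rng's internal state in place (an implicit update).
    return rng.normal(size=n)

rng = np.random.default_rng(seed=0)
first = sample_noise(rng, 3)
second = sample_noise(rng, 3)  # differs from `first`: the Generator was mutated

# A fresh Generator with the same seed replays the same stream.
replay = sample_noise(np.random.default_rng(seed=0), 3)

# Draws from a local Generator do not touch the legacy global state.
np.random.seed(0)
_ = np.random.default_rng(1).normal(size=5)
global_after_local_draws = np.random.random(3)  # same as `legacy_draw`
```

The local `Generator` is not global state, but it is still mutable state that is updated implicitly. That is the distinction a fairer comparison with JAX's explicit keys would need to capture.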

jakevdp commented 4 hours ago

I found the original issue: #11026