maxentile opened this issue 1 year ago

Thanks to @mcwitt for thoughtful comments: migrating from https://github.com/proteneer/timemachine/pull/978#discussion_r1140414877 . Would be good to discuss and adopt project-wide conventions, if possible.

Attempting to map out a decision tree:

Some approaches currently used:

1. Explicitly pass RNG state, as either
   a. a `numpy.random.Generator`, or
   b. a `jax.random.PRNGKey`.
2. Use the global `numpy.random` API.

Some possible trade-offs:

(1) Advantages: (1b) is compatible with JAX transformations (`jit`, etc.). Disadvantages: (1b) opens up a unique class of potential errors: forgetting to split PRNG state.
(2) Advantages: simple; no boilerplate; trivial refactoring. Disadvantages: global state; need to be very careful about side effects, e.g. setting the seed in one place unintentionally affecting downstream results; not compatible (?) with JAX.

See also: https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html
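For concreteness, a minimal sketch of the three conventions (the `sample_velocities_*` names are hypothetical, made up for illustration, not existing timemachine APIs):

```python
import numpy as np
import jax

def sample_velocities_1a(n, rng: np.random.Generator):
    # (1a) explicitly pass a numpy Generator
    return rng.normal(0, 1, n)

def sample_velocities_1b(key, n):
    # (1b) explicitly pass a jax.random key
    return jax.random.normal(key, (n,))

def sample_velocities_2(n):
    # (2) rely on the global numpy.random API
    return np.random.randn(n)

v1 = sample_velocities_1a(10, np.random.default_rng(2023))

key = jax.random.PRNGKey(2023)
key, subkey = jax.random.split(key)  # caller must remember to split
v2 = sample_velocities_1b(subkey, 10)

np.random.seed(2023)  # global side effect, visible to all downstream callers
v3 = sample_velocities_2(10)
```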
mcwitt commented:

Thanks for mapping these out.

By default I lean towards (2) due to familiarity, imposing looser requirements (maybe one class uses cuRAND, a different class only uses numpy, ...), and since it seems harder to make certain kinds of errors (forgetting to split keys, etc.). But (1) does seem cleaner.
An additional practice that may be compatible with all of the above options is to implement a random function `random_f(x)` by composing a deterministic function `f(x, gaussian_noise)` (whose implementation is RNG-agnostic) with random generation: `gaussian_noise = rng.normal(0, 1, x.shape)`, or `gaussian_noise = np.random.randn(*x.shape)`, or ...
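A rough sketch of that composition (function names made up for illustration):

```python
import numpy as np

def f(x, gaussian_noise):
    # Deterministic and RNG-agnostic: perturb x by scaled standard-normal noise.
    return x + 0.1 * gaussian_noise

def random_f(x, rng: np.random.Generator):
    # Option (1a): randomness supplied by an explicitly passed Generator.
    return f(x, rng.normal(0, 1, x.shape))

def random_f_global(x):
    # Option (2): randomness supplied by the global numpy.random API.
    return f(x, np.random.randn(*x.shape))

x = np.zeros(5)
y = random_f(x, np.random.default_rng(0))
```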
maxentile commented:

> and since it seems harder to make certain kinds of errors (forgetting to split keys etc.)

That's a good point that forgetting to split keys would be a class of error unique to (1b) (added to "disadvantages" above).
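To make that error class concrete, a small self-contained example of the failure mode and its fix:

```python
import jax

key = jax.random.PRNGKey(0)

# Bug: reusing the same key produces identical draws.
x1 = jax.random.normal(key, (3,))
x2 = jax.random.normal(key, (3,))  # x2 == x1, probably not what was intended

# Correct: split the key before each independent use.
key, k1, k2 = jax.random.split(key, 3)
y1 = jax.random.normal(k1, (3,))
y2 = jax.random.normal(k2, (3,))  # y1 and y2 are independent
```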
> An additional practice that may be compatible with all of the above options is to implement a random function `random_f(x)` by composing a deterministic function `f(x, gaussian_noise)` (whose implementation is RNG-agnostic) with random generation

This seems similar in spirit to option (1) to me, but does have the benefit of being RNG-agnostic. It seems like it does introduce some additional room for error, though: e.g. the caller must be careful to ensure that the input randomness has the expected distribution, and because generation is decoupled from transformation, the two might be harder to keep in sync.
mcwitt commented:

> It seems like it does introduce some additional room for error, though. E.g. the caller must be careful to ensure that the input randomness has the expected distribution, and because generation is decoupled from transformation, it might be harder to keep in sync.

That's true, and this is also a realistic concern in the context of reweighting, where we might have a deterministic function `f(theta, samples_from_theta_0)` expecting input randomness that may be very complicated, expensive, and failure-prone to generate.
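As a toy sketch of that shape (everything here is hypothetical; a 1-D Gaussian stands in for the expensive sampler):

```python
import numpy as np

def generate_samples(theta_0, rng: np.random.Generator, n):
    # Hypothetical stand-in for a complicated, expensive, failure-prone sampler at theta_0.
    return rng.normal(theta_0, 1.0, n)

def estimate_mean(theta, theta_0, samples_from_theta_0):
    # Deterministic given its inputs: self-normalized importance reweighting
    # from N(theta_0, 1) to N(theta, 1).
    x = samples_from_theta_0
    log_w = -0.5 * (x - theta) ** 2 + 0.5 * (x - theta_0) ** 2
    w = np.exp(log_w - log_w.max())
    return np.sum(w * x) / np.sum(w)

rng = np.random.default_rng(0)
samples = generate_samples(0.0, rng, 100_000)  # expensive step, done once

# Cheap, deterministic re-evaluations at many theta values reuse the same samples;
# a caller passing samples with the wrong distribution would silently bias these.
estimates = [estimate_mean(theta, 0.0, samples) for theta in (0.0, 0.1, 0.2)]
```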
maxentile commented:

So it doesn't get lost, some further observations from @mcwitt in https://github.com/proteneer/timemachine/pull/1128#discussion_r1317571529