proteneer / timemachine

Differentiate all the things!

Consider unifying approach to PRNG state #980

Open maxentile opened 1 year ago

maxentile commented 1 year ago

Thanks to @mcwitt for thoughtful comments: migrating from https://github.com/proteneer/timemachine/pull/978#discussion_r1140414877 . Would be good to discuss and adopt project-wide conventions, if possible.

Some approaches currently used:

Some possible trade-offs:

See also: https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html

mcwitt commented 1 year ago

Attempting to map out a decision tree (a rough code sketch of each option follows the list):

  1. Nondeterministic functions accept PRNG state. This is the approach taken by JAX. Advantages: most explicit; avoids global state / side effects; allows for parallelization and sampling independent chains without creating extra instances. Disadvantages: significant refactoring and some boilerplate code to pass random state arguments around.
    1. Use the numpy API (numpy.random.Generator).
    2. Use the JAX API (jax.random.PRNGKey). Advantages: compatible with JAX transformations (jit, etc.). Disadvantages: opens up a unique class of potential errors: forgetting to split PRNG state.
  2. Instances with nondeterministic methods maintain their own PRNG state. A seed or initial PRNG state is passed to the constructor. Advantages: avoids global state; less refactoring / boilerplate compared with option (1). Disadvantages: awkward / inefficient to parallelize sampling or sample independent chains (requires constructing many instances that differ only in seed, or an additional interface to mutate the random state of an instance); not compatible (?) with JAX.
  3. Use global PRNG state. This is the approach taken by the original numpy.random API. Advantages: simple; no boilerplate; trivial refactoring. Disadvantages: global state; need to be very careful about side effects, e.g. setting the seed in one place unintentionally affecting downstream results; not compatible (?) with JAX.
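
For concreteness, here is a rough sketch of what each option could look like. This is illustrative only: the Maxwell-Boltzmann-style velocity sampler and all names below are hypothetical, not existing timemachine APIs.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Option 1a: the nondeterministic function takes a numpy Generator explicitly.
def sample_velocities_np(masses, kT, rng: np.random.Generator):
    sigma = np.sqrt(kT / masses)[:, None]
    return rng.normal(0.0, sigma, size=(len(masses), 3))

# Option 1b: the nondeterministic function takes a JAX PRNG key explicitly.
def sample_velocities_jax(key, masses, kT):
    sigma = jnp.sqrt(kT / masses)[:, None]
    return sigma * jax.random.normal(key, shape=(len(masses), 3))

# Option 2: the instance owns its PRNG state, seeded in the constructor.
class VelocitySampler:
    def __init__(self, masses, kT, seed: int):
        self.masses = masses
        self.kT = kT
        self.rng = np.random.default_rng(seed)

    def sample(self):
        sigma = np.sqrt(self.kT / self.masses)[:, None]
        return self.rng.normal(0.0, sigma, size=(len(self.masses), 3))

# Option 3: rely on global PRNG state (legacy numpy.random API).
def sample_velocities_global(masses, kT):
    sigma = np.sqrt(kT / masses)[:, None]
    return np.random.normal(0.0, sigma, size=(len(masses), 3))
```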
maxentile commented 1 year ago

Thanks for mapping these out.

By default I lean towards (2) due to familiarity, because it imposes looser requirements (maybe one class uses cuRAND, a different class only uses numpy, ...), and since it seems harder to make certain kinds of errors (forgetting to split keys etc.). But (1) does seem cleaner.

An additional practice that may be compatible with all of the above options is to implement a random function `random_f(x)` by composing a deterministic function `f(x, gaussian_noise)` (whose implementation is RNG-agnostic) with random generation, e.g. `gaussian_noise = rng.normal(0, 1, x.shape)` or `gaussian_noise = np.random.randn(*x.shape)` or ...
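
A minimal sketch of that decomposition (the body of `f` is just a placeholder):

```python
import numpy as np

def f(x, gaussian_noise):
    # Deterministic and RNG-agnostic: all randomness enters through the argument,
    # so this part can be tested with fixed noise.
    return x + 0.1 * gaussian_noise

def random_f(x, rng: np.random.Generator):
    # Thin wrapper pairing the deterministic core with a concrete noise source.
    gaussian_noise = rng.normal(0, 1, x.shape)
    return f(x, gaussian_noise)

rng = np.random.default_rng(2023)
y = random_f(np.zeros(3), rng)
```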

mcwitt commented 1 year ago

> and since it seems harder to make certain kinds of errors (forgetting to split keys etc.)

That's a good point: forgetting to split keys would be a class of error unique to (1b) (added to "disadvantages" above).
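
For reference, a small example of that failure mode versus the intended key-splitting pattern:

```python
import jax

key = jax.random.PRNGKey(0)

# Pitfall: reusing a key silently yields identical "random" draws.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))   # a == b

# Intended usage: split the key so each draw consumes fresh state.
key, subkey1, subkey2 = jax.random.split(key, 3)
a = jax.random.normal(subkey1, (3,))
b = jax.random.normal(subkey2, (3,))  # independent draws
```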

> An additional practice that may be compatible with all of the above options is to implement a random function `random_f(x)` by composing a deterministic function `f(x, gaussian_noise)` (whose implementation is RNG-agnostic) with random generation

This seems similar in spirit to option (1) to me, but does have the benefit of being RNG-agnostic. It seems like it does introduce some additional room for error, though: e.g. the caller must be careful to ensure that the input randomness has the expected distribution, and because generation is decoupled from transformation, it might be harder to keep the two in sync.

maxentile commented 1 year ago

> It seems like it does introduce some additional room for error, though: e.g. the caller must be careful to ensure that the input randomness has the expected distribution, and because generation is decoupled from transformation, it might be harder to keep the two in sync.

That's true, and this is also a realistic concern in the context of reweighting, where we might have a deterministic function `f(theta, samples_from_theta_0)` expecting input randomness that may be very complicated, expensive, or failure-prone to generate.
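
As a hypothetical illustration of that setting, here is a simple self-normalized importance-reweighting estimator; `u_theta`, `u_theta_0`, and `observable` are assumed per-sample callables, not existing timemachine functions. Its correctness hinges entirely on `samples_from_theta_0` actually having been drawn from the theta_0 distribution, which is exactly the caller-side responsibility discussed above.

```python
import numpy as np

def reweighted_mean(observable, u_theta, u_theta_0, samples_from_theta_0):
    # Deterministic given the precomputed samples: estimate the average of
    # `observable` under theta by importance reweighting from theta_0,
    # without regenerating any samples.
    log_w = -(u_theta(samples_from_theta_0) - u_theta_0(samples_from_theta_0))
    log_w -= np.max(log_w)        # shift for numerical stability
    w = np.exp(log_w)
    w /= np.sum(w)                # self-normalized importance weights
    return np.sum(w * observable(samples_from_theta_0))
```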

maxentile commented 1 year ago

So it doesn't get lost: some further observations from @mcwitt in https://github.com/proteneer/timemachine/pull/1128#discussion_r1317571529