jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Linearity (or affineness?) checking for PRNG keys #192

Open j-towns opened 5 years ago

j-towns commented 5 years ago

This is an issue to request (and track progress on) enforcement of linearity (in the sense of linear types) for PRNG keys. That is, enforcement of the condition that each instantiated PRNG key must be consumed (either mapped to pseudo-random numbers or split) exactly once.

This enhancement would significantly improve the usability of JAX's PRNG. IMO it is currently far too easy to unwittingly produce identical/correlated random values, which is likely to be a common source of bugs and confusion for users.
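A minimal illustration of the failure mode described above, using only the public `jax.random` API: calling a sampler twice with the same key silently returns identical values, while consuming the key once by splitting it gives two fresh keys whose draws differ. The variable names are illustrative only.

```python
import jax.numpy as jnp
from jax import random

# Reusing a key: both calls see the same PRNG state,
# so they return exactly the same "random" samples.
key = random.PRNGKey(0)
a = random.normal(key, (3,))
b = random.normal(key, (3,))  # key reused

# Linear use: the key is consumed exactly once (by split),
# and each resulting subkey is consumed exactly once.
key = random.PRNGKey(1)
k1, k2 = random.split(key)
c = random.normal(k1, (3,))
d = random.normal(k2, (3,))

print(bool(jnp.array_equal(a, b)))  # True
print(bool(jnp.array_equal(c, d)))  # False
```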

EDIT: do we want PRNG keys to be consumed exactly once, or is it sufficient to ensure that they are consumed no more than once? I think the latter property is called 'affine'; see https://en.wikipedia.org/wiki/Substructural_type_system#Affine_type_systems.

samuela commented 5 years ago

I accidentally reuse PRNG keys all the time!

samuela commented 5 years ago

It occurs to me that one way to accomplish this would be to add a field like "consumed" to PRNG keys, which is set whenever the key is consumed by a random op. Linearity checking could then be done at run time. It requires some mutation, but it would get the job done.
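A rough sketch of this run-time idea, assuming a hypothetical wrapper class `TrackedKey` (not part of JAX) that holds a raw key plus a mutable `consumed` flag and raises on a second use. It only works in eager Python; as the reply below explains, keeping such mutation consistent under `jit` is the hard part.

```python
from jax import random

class TrackedKey:
  """Hypothetical wrapper (not a JAX API): a PRNG key usable at most once."""

  def __init__(self, raw_key):
    self._raw = raw_key
    self.consumed = False

  def _take(self):
    # Mark this key as consumed, failing if it already was.
    if self.consumed:
      raise RuntimeError("PRNG key reused")
    self.consumed = True
    return self._raw

  def normal(self, shape=()):
    # Mapping the key to random numbers consumes it.
    return random.normal(self._take(), shape)

  def split(self):
    # Splitting consumes this key and returns two fresh tracked keys.
    k1, k2 = random.split(self._take())
    return TrackedKey(k1), TrackedKey(k2)


key = TrackedKey(random.PRNGKey(0))
left, right = key.split()
x = left.normal((2,))
try:
  left.normal((2,))  # second use of `left`
except RuntimeError as err:
  print(err)  # PRNG key reused
```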

mattjj commented 5 years ago

That's the first idea we had too, but it gets pretty tricky if you try to work out the details! The trouble is keeping jit as a semantic no-op.

Don't get me wrong, I'm optimistic it's doable, but it's not going to be easy.

samuela commented 5 years ago

> That's the first idea we had too, but it gets pretty tricky if you try to work out the details! The trouble is keeping jit as a semantic no-op.

I'm not sure that I totally understand what you mean here but I'm guessing it's something like

from jax import jit, random

rng = random.PRNGKey(0)

def foo(x):
  return random.normal(rng) * x

# rng has not yet been consumed...
foo_jitted = jit(foo)
y = foo_jitted(1.0)
# ...but now it has: the first call traced foo, which used rng.

In my experience so far, I've come to the conclusion that jitting a function that closes over an rng is an anti-pattern; I almost always should have written the function as foo(rng, x) instead and jitted that. In some sense, tracing a function that uses external information corresponds to consuming that data: random.normal(rng) evaluates to some constant, which XLA then (optimistically) constant-folds, and that strikes me as consuming rng.
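A small sketch contrasting the two styles described above (the names `foo_closed` and `foo` are illustrative): when the key is closed over, it becomes a constant of the compiled function and every call returns the same sample; when it is an explicit argument, the caller can pass a fresh subkey on each call.

```python
import jax
from jax import random

rng = random.PRNGKey(0)

# Anti-pattern: `rng` is closed over, so it is captured as a constant
# of the compiled function and every call reuses the same sample.
@jax.jit
def foo_closed(x):
  return random.normal(rng) * x

# Preferred: the key is an explicit argument of the jitted function.
@jax.jit
def foo(key, x):
  return random.normal(key) * x

# Use a separate key for the explicit version so `rng` is not reused.
k1, k2 = random.split(random.PRNGKey(1))
print(bool(foo_closed(1.0) == foo_closed(1.0)))  # True: same sample each call
print(bool(foo(k1, 1.0) == foo(k2, 1.0)))        # False: different keys
```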

It occurs to me that ultimately a "solution" to this predicament will require detecting whether or not execution is running in the context of jit, much like the discussion in https://github.com/google/jax/issues/999. But unfortunately functions like random.normal(rng) don't offer any opportunities to test whether or not arguments are concrete or abstract.

j-towns commented 5 years ago

What about a decorator that ensures a function is affine in its argument(s), like

@ensure_affine(argnum=1)
def f(x, rng):
    return rng

ensure_affine could work by tracing f to a jaxpr and analysing the jaxpr to check that rng (and its descendants) are consumed (mapped to pseudo-random numbers) at most once...
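A rough sketch of the jaxpr-based check, using `jax.make_jaxpr` to trace the function. The helpers `count_key_uses` and `check_affine` are hypothetical, and the check is deliberately simplified: it only counts how many top-level equations (plus outputs) read the key argument directly, without following keys derived from it by `split` or descending into nested jaxprs.

```python
import jax
from jax import random

def count_key_uses(f, *example_args, argnum):
  """Hypothetical helper: count direct reads of argument `argnum` in f's jaxpr."""
  jaxpr = jax.make_jaxpr(f)(*example_args).jaxpr
  key_var = jaxpr.invars[argnum]  # assumes each argument is a single array
  # Count equations reading the key; compare by identity so literals are skipped.
  uses = sum(1 for eqn in jaxpr.eqns for v in eqn.invars if v is key_var)
  # Returning the key unchanged also counts as a use (it escapes).
  uses += sum(1 for v in jaxpr.outvars if v is key_var)
  return uses

def check_affine(f, *example_args, argnum):
  # Affine: the key may be used at most once (zero uses is allowed).
  if count_key_uses(f, *example_args, argnum=argnum) > 1:
    raise ValueError(f"argument {argnum} of {f.__name__} is used more than once")

def ok(x, rng):
  return random.normal(rng) * x

def bad(x, rng):
  return random.normal(rng) * x + random.uniform(rng)

key = random.PRNGKey(0)
print(count_key_uses(ok, 1.0, key, argnum=1))   # 1
print(count_key_uses(bad, 1.0, key, argnum=1))  # 2
```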

samuela commented 5 years ago

@j-towns I would certainly use such a tool if it were added, but I also write a lot of training scripts that have to manipulate PRNGKeys in interesting ways. It would be nice to have a solution that could shed some light on those scripts as well. In fact, I find that most of my linearity issues come from these training scripts rather than from the functions I write to support them.
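One common way to keep a training script's key handling linear by hand is to split once per step, consume the subkey, and carry the other half forward. A minimal sketch, where `shuffled_epochs` is a hypothetical stand-in for a training loop that reshuffles the data every epoch:

```python
import jax.numpy as jnp
from jax import random

def shuffled_epochs(key, num_epochs, num_examples):
  """Hypothetical loop skeleton: one fresh subkey per epoch."""
  orders = []
  for _ in range(num_epochs):
    # `subkey` is consumed by this epoch's shuffle; `key` is carried
    # forward to the next iteration, so no key is used twice.
    key, subkey = random.split(key)
    orders.append(random.permutation(subkey, num_examples))
  return jnp.stack(orders)

run1 = shuffled_epochs(random.PRNGKey(0), num_epochs=3, num_examples=8)
run2 = shuffled_epochs(random.PRNGKey(0), num_epochs=3, num_examples=8)
print(run1.shape)                          # (3, 8)
print(bool(jnp.array_equal(run1, run2)))   # True: same seed, same schedule
```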

j-towns commented 4 years ago

Not sure how useful this is, but I discovered a way to 'abstract away' key handling, with something that looks like a global PRNG, by using generator functions. Usage looks like this:

from functools import partial

import jax.numpy as jnp
from jax import random

@randomize
def bernoulli(p):
  b = (yield random.uniform) < p
  return b

@randomize
def binomial(n, p):
  bernoullis = []
  for _ in range(n):
    bernoullis.append((yield partial(bernoulli, p=p)))
  return jnp.sum(jnp.array(bernoullis))

# Can be used like this:
coin_flip = bernoulli(random.PRNGKey(0), 0.5)
sum_of_coin_flips = binomial(random.PRNGKey(2), 20, 0.5)

The definition of randomize is pretty simple:

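from jax import random

def safe_send(generator, s):
  # Send s into the generator. Returns (True, return_value) once it
  # finishes, otherwise (False, next_yielded_value).
  try: n = generator.send(s)
  except StopIteration as e: return True, e.value
  return False, n

def safe_next(generator):
  # Start the generator (sending None is equivalent to next()).
  return safe_send(generator, None)

def randomize(gen_fun):
  def randomized(key, *deterministic_args, **deterministic_kwargs):
    g = gen_fun(*deterministic_args, **deterministic_kwargs)
    done, sampler = safe_next(g)
    while not done:
      # Split off a fresh key for each yielded sampler so no key is reused.
      key, sample_key = random.split(key)
      done, sampler = safe_send(g, sampler(sample_key))
    # Once the generator returns, `sampler` holds its return value.
    return sampler
  return randomized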
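For comparison, here is a hand-threaded version of the bernoulli and binomial examples that should perform the same key splits randomize does: one `random.split` per `yield`, with the second half handed to the sampler. The `_explicit` names are illustrative and not part of the original comment.

```python
import jax.numpy as jnp
from jax import random

def bernoulli_explicit(key, p):
  # One yield in `bernoulli` -> one split; the sample key feeds random.uniform.
  _, sample_key = random.split(key)
  return random.uniform(sample_key) < p

def binomial_explicit(key, n, p):
  draws = []
  for _ in range(n):
    # One split per yielded sub-sampler, mirroring the loop in `randomize`.
    key, sample_key = random.split(key)
    draws.append(bernoulli_explicit(sample_key, p))
  return jnp.sum(jnp.array(draws))

total = binomial_explicit(random.PRNGKey(2), 20, 0.5)
print(0 <= int(total) <= 20)  # True
```

Writing it out this way makes one property visible: the key left over after the final split is never consumed, so randomize is affine rather than linear in its key, in the sense of the EDIT in the opening post.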

jakevdp commented 6 months ago

2024 update: we now have a debug mode for this! You can enable it with the jax_debug_key_reuse flag:

In [1]: import jax

In [2]: with jax.debug_key_reuse(True):
   ...:     key = jax.random.key(0)
   ...:     vals1 = jax.random.uniform(key)
   ...:     vals2 = jax.random.uniform(key)  # reused!
   ...: 
---------------------------------------------------------------------------
KeyReuseError                             Traceback (most recent call last)
Cell In[2], line 3
      2 key = jax.random.key(0)
----> 3 vals1 = jax.random.uniform(key)
      4 vals2 = jax.random.uniform(key)  # reused!

File ~/github/google/jax/jax/_src/random.py:398, in uniform()
    397 shape = core.as_named_shape(shape)
--> 398 return _uniform(key, shape, dtype, minval, maxval)

KeyReuseError: PRNG key first used at the above location was subsequently reused at the following location:
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.KeyReuseError

The above exception was the direct cause of the following exception:

KeyReuseError                             Traceback (most recent call last)
Cell In[2], line 4
      2 key = jax.random.key(0)
      3 vals1 = jax.random.uniform(key)
----> 4 vals2 = jax.random.uniform(key)  # reused!

File ~/github/google/jax/jax/_src/random.py:398, in uniform(key, shape, dtype, minval, maxval)
    396 dtype = dtypes.canonicalize_dtype(dtype)
    397 shape = core.as_named_shape(shape)
--> 398 return _uniform(key, shape, dtype, minval, maxval)

    [... skipping hidden 7 frame]

File ~/github/google/jax/jax/experimental/key_reuse/_core.py:177, in KeyReuseSignature.check_signature(self, funcname, context, *args)
    175 if context:
    176   msg += " {context}"
--> 177 raise key_reuse_error_with_source_traceback(
    178     msg, key._source_info and key._source_info.traceback)

KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.KeyReuseError

You can enable it globally by running this at the top of your script: jax.config.update('jax_debug_key_reuse', True). We're hoping to enable this mode by default, but we need to do some optimization of its dispatch path before we can do so.
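A short sketch of the global setting in use, relying only on the flag and error class named above: the canonical split-then-use pattern passes, while a second use of the same subkey is expected to raise `jax.errors.KeyReuseError`, as in the traceback above. The variable names are illustrative.

```python
import jax

# Enable key-reuse checking for the whole program.
jax.config.update('jax_debug_key_reuse', True)

key = jax.random.key(0)
k1, k2 = jax.random.split(key)  # `key` is consumed here...
a = jax.random.uniform(k1)      # ...and each subkey exactly once.
b = jax.random.uniform(k2)

reuse_detected = False
try:
  jax.random.uniform(k1)        # second use of k1
except jax.errors.KeyReuseError:
  reuse_detected = True
print(reuse_detected)  # True
```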

More info at https://jax.readthedocs.io/en/latest/jax.experimental.key_reuse.html