Open j-towns opened 5 years ago
I accidentally reuse PRNG keys all the time!
It occurs to me that one way to accomplish this would be to add a field on PRNG keys, like consumed, that is marked any time the key is consumed by a random op. Then linearity checking can be done at run-time. It requires some mutation, but it would get the job done.
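To make the suggestion above concrete, here is a minimal pure-Python sketch of a run-time "consumed" flag. The names OneShotKey, ConsumedKeyError and take are hypothetical and are not part of JAX; the wrapped value is a plain tuple standing in for a real key. It only covers eager execution and says nothing about how such a flag would behave under tracing.

```python
class ConsumedKeyError(RuntimeError):
    """Raised when a one-shot key is used a second time (hypothetical)."""


class OneShotKey:
    """Hypothetical wrapper that hands out the wrapped key at most once."""

    def __init__(self, raw_key):
        self._raw_key = raw_key
        self.consumed = False  # the mutable flag suggested in the comment above

    def take(self):
        # Refuse to hand the key out again once it has been used.
        if self.consumed:
            raise ConsumedKeyError("PRNG key was already consumed")
        self.consumed = True
        return self._raw_key


# A plain tuple stands in for a real key such as random.PRNGKey(42).
key = OneShotKey((0, 42))
first = key.take()  # first use succeeds

try:
    key.take()  # second use is rejected at run-time
    reused_silently = True
except ConsumedKeyError:
    reused_silently = False
```

With a real key, the idea would be to pass key.take() to the random op, so that an accidental second use fails loudly instead of silently repeating the same samples.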
That's the first idea we had too, but it gets pretty tricky if you try to work out the details! The trouble is keeping jit as a semantic no-op.
Don't get me wrong, I'm optimistic it's doable, but it's not going to be easy.
> That's the first idea we had too, but it gets pretty tricky if you try to work out the details! The trouble is keeping jit as a semantic no-op.
I'm not sure that I totally understand what you mean here but I'm guessing it's something like
from jax import jit, random

rng = random.PRNGKey(0)

def foo(x):
    return random.normal(rng) * x

# rng has not yet been consumed...
foo_jitted = jit(foo)
# rng is consumed once foo is traced, i.e. on the first call to foo_jitted.
In my experience thus far I've come to the conclusion that jitting a function that closes over an rng is an anti-pattern; I almost always should have written the function as foo(rng, x) instead and jitted that. In some sense, tracing functions that use external information corresponds to consuming that data, i.e. random.normal(rng) evaluates to some constant which XLA then constant-folds (optimistically), which strikes me as consuming rng.
It occurs to me that ultimately a "solution" to this predicament will require detecting whether or not execution is running in the context of jit, much like the discussion in https://github.com/google/jax/issues/999. But unfortunately functions like random.normal(rng) don't offer any opportunities to test whether or not arguments are concrete or abstract.
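For comparison, the foo(rng, x) form mentioned above might look like the following sketch. The key is an explicit argument of the jitted function, so it is traced like any other input, and each call receives a fresh subkey from random.split. The variable names key_a and key_b are illustrative only.

```python
import jax
from jax import random


@jax.jit
def foo(rng, x):
    # The key is an explicit argument, so it is a traced input
    # rather than a constant captured from the enclosing scope.
    return random.normal(rng) * x


rng = random.PRNGKey(0)

# Split off a fresh subkey for every call instead of reusing rng directly.
rng, key_a = random.split(rng)
rng, key_b = random.split(rng)

a = foo(key_a, 2.0)
b = foo(key_b, 2.0)
```

Because the two calls consume different subkeys, a and b are independent draws, and the same compiled function is reused for both.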
What about a decorator that ensures a function is affine in its argument(s), like
@ensure_affine(argnum=1)
def f(x, rng):
    return rng
ensure_affine could work by tracing f to a jaxpr and analysing the jaxpr to check that rng (and its descendants) are consumed (mapped to pseudo-random numbers) at most once...
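A very rough sketch of what such a decorator could look like is below. It is not an existing JAX API: ensure_affine here is a hypothetical helper that traces the function with jax.make_jaxpr and counts how many top-level equations take the key argument as an input. It assumes every positional argument is a single array (so argument positions line up with the jaxpr inputs), it treats any use of the key as a consumption, and it does not follow descendants such as subkeys produced by random.split.

```python
import functools

import jax
from jax import random


def ensure_affine(argnum):
    """Hypothetical decorator: reject functions that use one key more than once."""

    def decorator(f):
        @functools.wraps(f)
        def wrapped(*args):
            # Trace f to a jaxpr using the concrete arguments as examples.
            jaxpr = jax.make_jaxpr(f)(*args).jaxpr
            # Assumption: each positional argument is one array, so
            # invars[argnum] is the variable bound to the key.
            key_var = jaxpr.invars[argnum]
            # Count the top-level equations that read the key variable.
            uses = sum(v is key_var for eqn in jaxpr.eqns for v in eqn.invars)
            if uses > 1:
                raise ValueError(f"argument {argnum} is consumed {uses} times")
            return f(*args)

        return wrapped

    return decorator


@ensure_affine(argnum=1)
def ok(x, rng):
    return random.normal(rng) * x


@ensure_affine(argnum=1)
def bad(x, rng):
    # The same key feeds two random ops.
    return random.normal(rng) * x + random.uniform(rng)


key = random.PRNGKey(0)
ok_value = ok(2.0, key)

try:
    bad(2.0, key)
    bad_raised = False
except ValueError:
    bad_raised = True
```

A fuller version would need to walk nested jaxprs and track keys derived by splitting, which is where most of the real work would be.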
@j-towns I would certainly use such a tool if it was added, but I also write a bunch of training scripts that end up having to manipulate PRNGKeys in interesting ways. It would be nice to have a solution that could also shed some light on these scripts as well. In fact, I find that most of the linearity issues end up coming from these training scripts and not the functions that I write to support them.
Not sure how useful this is, but I discovered a way to 'abstract away' key handling, with something that looks like a global PRNG, by using generator functions. Usage looks like this:
from functools import partial

import jax.numpy as jnp
from jax import random

@randomize
def bernoulli(p):
    b = (yield random.uniform) < p
    return b

@randomize
def binomial(n, p):
    bernoullis = []
    for _ in range(n):
        bernoullis.append((yield partial(bernoulli, p=p)))
    return jnp.sum(jnp.array(bernoullis))

# Can be used like this:
coin_flip = bernoulli(random.PRNGKey(0), 0.5)
sum_of_coin_flips = binomial(random.PRNGKey(2), 20, 0.5)
The definition of randomize is pretty simple:
def safe_send(generator, s):
    try:
        n = generator.send(s)
    except StopIteration as e:
        return True, e.value
    return False, n

def safe_next(generator):
    return safe_send(generator, None)

def randomize(gen_fun):
    def randomized(key, *deterministic_args, **deterministic_kwargs):
        g = gen_fun(*deterministic_args, **deterministic_kwargs)
        done, sampler = safe_next(g)
        while not done:
            key, sample_key = random.split(key)
            done, sampler = safe_send(g, sampler(sample_key))
        return sampler
    return randomized
2024 update: we now have a debug mode for this! You can enable it with the jax_debug_key_reuse flag:
In [1]: import jax
In [2]: with jax.debug_key_reuse(True):
   ...:     key = jax.random.key(0)
   ...:     vals1 = jax.random.uniform(key)
   ...:     vals2 = jax.random.uniform(key)  # reused!
   ...:
---------------------------------------------------------------------------
KeyReuseError Traceback (most recent call last)
Cell In[2], line 3
2 key = jax.random.key(0)
----> 3 vals1 = jax.random.uniform(key)
4 vals2 = jax.random.uniform(key) # reused!
File ~/github/google/jax/jax/_src/random.py:398, in uniform()
397 shape = core.as_named_shape(shape)
--> 398 return _uniform(key, shape, dtype, minval, maxval)
KeyReuseError: PRNG key first used at the above location was subsequently reused at the following location:
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.KeyReuseError
The above exception was the direct cause of the following exception:
KeyReuseError Traceback (most recent call last)
Cell In[2], line 4
2 key = jax.random.key(0)
3 vals1 = jax.random.uniform(key)
----> 4 vals2 = jax.random.uniform(key) # reused!
File ~/github/google/jax/jax/_src/random.py:398, in uniform(key, shape, dtype, minval, maxval)
396 dtype = dtypes.canonicalize_dtype(dtype)
397 shape = core.as_named_shape(shape)
--> 398 return _uniform(key, shape, dtype, minval, maxval)
[... skipping hidden 7 frame]
File ~/github/google/jax/jax/experimental/key_reuse/_core.py:177, in KeyReuseSignature.check_signature(self, funcname, context, *args)
175 if context:
176 msg += " {context}"
--> 177 raise key_reuse_error_with_source_traceback(
178 msg, key._source_info and key._source_info.traceback)
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.KeyReuseError
You can enable it globally by running this at the top of your script: jax.config.update('jax_debug_key_reuse', True). We're hoping to enable this mode by default, but we need to do some optimization of its dispatch path before we can do so.
More info at https://jax.readthedocs.io/en/latest/jax.experimental.key_reuse.html
This is an issue to request (and track progress on) enforcement of linearity (in the sense of linear types) for PRNG keys. That is, enforcement of the condition that each instantiated PRNG key must be consumed (either mapped to pseudo-random numbers or split) exactly once.
This enhancement would significantly improve the usability of JAX's PRNG. IMO it is currently far too easy to unwittingly produce identical/correlated random values, which is likely to be a common source of bugs and confusion for users.
EDIT: do we want PRNGs to be consumed exactly once or is it sufficient to ensure that they are consumed no more than once? I think this property is called 'affine', see https://en.wikipedia.org/wiki/Substructural_type_system#Affine_type_systems.