jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Making PRNG output elements consistent for different shapes #2155

Open j-towns opened 4 years ago

j-towns commented 4 years ago
In [1]: from jax import random

In [2]: k = random.PRNGKey(0)

For implementing masking we'd like the elements of

In [3]: random.uniform(k, shape=(2,))
Out[3]: DeviceArray([0.21629536, 0.8041241 ], dtype=float32)

to match the first two elements of

In [4]: random.uniform(k, shape=(4,))
Out[4]: DeviceArray([0.9653214 , 0.22515893, 0.63302994, 0.29638183], dtype=float32)

and more generally we'd like, for any shape_1, shape_2 with len(shape_1) == len(shape_2), and any index idx that is in bounds for both:

random.uniform(k, shape_1)[idx] == random.uniform(k, shape_2)[idx]

At the moment the threefry function (upon which the PRNGs are based) takes in a count argument, which is an array with unique elements to which threefry essentially applies an element-wise hash (it doesn't quite work like that now but it could, and assuming that it does will simplify my explanation).

One way to get the behaviour that we want is to ensure that the elements of count are consistent, regardless of the shape. I.e. we want a function to generate the count array, call it make_unique_els, such that all of the elements of make_unique_els(shape) are distinct, and

make_unique_els(shape_1)[idx] == make_unique_els(shape_2)[idx]

A very crude way to do this would be something like

In [9]: def make_unique_els(shape):
   ...:     grid = np.mgrid[tuple(slice(s) for s in shape)]
   ...:     out = np.zeros(shape)
   ...:     for dim, g in enumerate(grid):
   ...:         out = out + 32 ** dim * g
   ...:     return out
   ...:

which will generate things like:

In [10]: make_unique_els((2, 3))
Out[10]:
array([[ 0., 32., 64.],
       [ 1., 33., 65.]])

In [11]: make_unique_els((5, 4))
Out[11]:
array([[  0.,  32.,  64.,  96.],
       [  1.,  33.,  65.,  97.],
       [  2.,  34.,  66.,  98.],
       [  3.,  35.,  67.,  99.],
       [  4.,  36.,  68., 100.]])

The first output is contained in the second, just as we want. However, this has a limit of size 32 for each of the dimensions before elements start to clash, which is rubbish for our random number generation.
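To make the clash concrete, here is a small self-contained check of the crude scheme. It re-declares make_unique_els exactly as above (only the NumPy import is added), confirms the containment property on the two example shapes, and then shows one collision once a dimension exceeds 32. The shape (33, 2) is just an illustrative choice.

```python
import numpy as np

def make_unique_els(shape):
    # Crude scheme from above: the index along dimension `dim`
    # is weighted by 32 ** dim and the weighted indices are summed.
    grid = np.mgrid[tuple(slice(s) for s in shape)]
    out = np.zeros(shape)
    for dim, g in enumerate(grid):
        out = out + 32 ** dim * g
    return out

small = make_unique_els((2, 3))
large = make_unique_els((5, 4))
# Consistency: the smaller array is the leading block of the larger one.
assert np.array_equal(small, large[:2, :3])

# Clash: with 33 rows, index (32, 0) and index (0, 1) both map to 32.
clash = make_unique_els((33, 2))
assert clash[32, 0] == clash[0, 1] == 32
assert np.unique(clash).size < clash.size
```

So the elements stay consistent across shapes, but they stop being distinct as soon as an index along the first dimension reaches 32.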

One way around this would be to 'hash along the first dimension' of the grid generated by mgrid, something like

In [9]: def make_unique_els(shape):
   ...:     grid = np.mgrid[tuple(slice(s) for s in shape)]
   ...:     return reduce(binary_hash_function, grid)
   ...:

That would work, and you could actually use a slight variant of the existing threefry function as the hash. There may well be a better way to do all of this, though; it would be great to know if anyone has any suggestions.
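A minimal runnable sketch of the 'hash along the first dimension' idea, in plain NumPy. The binary_hash_function here is a made-up multiply/xor-shift mixer chosen only for illustration; it is not threefry and has no claimed statistical quality, and like any 32-bit hash it does not guarantee distinct outputs. What the sketch does show is that each output element depends only on its own index, so the result is consistent across shapes.

```python
from functools import reduce
import numpy as np

def binary_hash_function(a, b):
    # Hypothetical stand-in mixer (NOT threefry): combines two uint32
    # arrays element-wise; array arithmetic wraps modulo 2**32.
    a = np.asarray(a, dtype=np.uint32)
    b = np.asarray(b, dtype=np.uint32)
    h = (a * np.uint32(0x9E3779B1)) ^ (b + np.uint32(0x85EBCA6B))
    h = h ^ (h >> np.uint32(15))
    h = h * np.uint32(0x2C1B3C6D)
    return h ^ (h >> np.uint32(13))

def make_unique_els(shape):
    # grid[d] holds the index along dimension d at every position.
    grid = np.mgrid[tuple(slice(s) for s in shape)].astype(np.uint32)
    # Fold the per-dimension index arrays together with the binary hash.
    return reduce(binary_hash_function, grid)

small = make_unique_els((2, 3))
large = make_unique_els((5, 4))
# Each element is a function of its index only, so shapes agree on overlap.
assert np.array_equal(small, large[:2, :3])
```

Note that for a one-dimensional shape the reduce has nothing to fold, so the result is just the index array itself; a real implementation would presumably hash it at least once.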

j-towns commented 4 years ago

OK another, probably more efficient, approach is to 'grow' the random output, dimension by dimension, by doing something similar to what split does now, along each axis.

shoyer commented 4 years ago

How do you want this to generalize for shapes with more dimensions? e.g., should the output from shape=(4, 4) be consistent with both shape=(4, 2) and shape=(2, 4)? That seems hard to ensure with the "growing" strategy without doing lots of unnecessary computation.

j-towns commented 4 years ago

I think it is possible, by doing something like this (imagining that keys are single scalars instead of pairs):

def grow(key, shape):
    """
    Assume key is a scalar, which we are going to 'grow' into an output array.
    """
    for n in shape:
        count = lax.iota(np.uint32, n)  # lax.iota takes (dtype, size)
        # Assume threefry is vmapped to do a kind of outer product
        # between key and count, so
        #      threefry(key, count).shape = key.shape + count.shape
        key = threefry(key, count)
    return key

EDIT: This does do some unnecessary computation. Is there a way to get around that?
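For anyone who wants to try the idea, here is a runnable NumPy sketch of grow under the same simplifying assumption (scalar keys). toy_hash is a hypothetical stand-in for the vmapped threefry described in the comments, not the real function, and the key value 7 is arbitrary. The checks at the end confirm that a (4, 4) output is consistent with both (4, 2) and (2, 4).

```python
import numpy as np

def toy_hash(key, count):
    # Hypothetical stand-in for a vmapped threefry (NOT the real one).
    # Output shape is key.shape + count.shape for a 1-D count.
    key = np.asarray(key, dtype=np.uint32)[..., None]
    count = np.asarray(count, dtype=np.uint32)
    h = (key * np.uint32(0x9E3779B1)) ^ (count + np.uint32(0x85EBCA6B))
    h = h ^ (h >> np.uint32(15))
    h = h * np.uint32(0x2C1B3C6D)
    return h ^ (h >> np.uint32(13))

def grow(key, shape):
    # Start from a scalar key and add one axis per entry of `shape`.
    key = np.asarray(key, dtype=np.uint32)
    for n in shape:
        key = toy_hash(key, np.arange(n, dtype=np.uint32))
    return key

full = grow(7, (4, 4))
assert full.shape == (4, 4)
# Element (i, j) is toy_hash(toy_hash(key, i), j), so it does not
# depend on the requested shape.
assert np.array_equal(grow(7, (4, 2)), full[:, :2])
assert np.array_equal(grow(7, (2, 4)), full[:2, :])
```

In this sketch the extra work is visible directly: a (4, 4) output costs 4 hash evaluations for the first axis plus 16 for the second, rather than 16 in total.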