j-towns opened this issue 4 years ago
OK, another (probably more efficient) approach is to 'grow' the random output dimension by dimension, doing something similar to what `split` does now along each axis.
How do you want this to generalize for shapes with more dimensions? E.g., should the output from `shape=(4, 4)` be consistent with both `shape=(4, 2)` and `shape=(2, 4)`? That seems hard to ensure with the "growing" strategy without doing lots of unnecessary computation.
I think it is possible, by doing something like this (imagining that keys are single scalars instead of pairs):
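```python
from jax import lax
import jax.numpy as jnp

def grow(key, shape):
    """
    Assume key is a scalar, which we are going to 'grow' into an output array.
    """
    for n in shape:
        count = lax.iota(jnp.uint32, n)  # lax.iota takes (dtype, size)
        # Assume threefry is vmapped to do a kind of outer product
        # between key and count, so
        # threefry(key, count).shape == key.shape + count.shape
        # (threefry here is a hypothetical helper, not defined in this snippet)
        key = threefry(key, count)
    return key
```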
EDIT: This does do some unnecessary computation. Is there a way to get around that?
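A minimal NumPy sketch of the `grow` idea, to check the consistency question concretely. It is not JAX's implementation: the key is a single `uint64` scalar rather than a key pair, and the hypothetical helpers `mix64` (the splitmix64 finalizer) and `hash_outer` stand in for a vmapped threefry. Since element `[i, j]` always comes out as `h(h(key, i), j)`, the `(4, 4)` output contains both the `(4, 2)` and `(2, 4)` outputs, and the hash counter shows the extra work mentioned in the EDIT: every intermediate key array gets hashed too.

```python
import numpy as np


def mix64(x):
    """splitmix64 finalizer: a bijective 64-bit scrambler (stand-in for threefry)."""
    x = x ^ (x >> np.uint64(30))
    x = x * np.uint64(0xBF58476D1CE4E5B9)
    x = x ^ (x >> np.uint64(27))
    x = x * np.uint64(0x94D049BB133111EB)
    return x ^ (x >> np.uint64(31))


def hash_outer(key, count):
    """Hash every key with every counter: result.shape == key.shape + count.shape."""
    key = np.asarray(key, dtype=np.uint64)[..., None]
    return mix64(mix64(key) + count)


def grow(key, shape):
    """Grow a scalar key into an array of random bits, one axis at a time.

    Returns the bits and the total number of hash evaluations performed.
    """
    key = np.asarray(key, dtype=np.uint64)
    n_hashes = 0
    for n in shape:
        key = hash_outer(key, np.arange(n, dtype=np.uint64))
        n_hashes += key.size
    return key, n_hashes


full, cost = grow(0, (4, 4))
tall, _ = grow(0, (4, 2))
wide, _ = grow(0, (2, 4))
# The (4, 2) and (2, 4) outputs are the top-left blocks of the (4, 4) output.
print(np.array_equal(full[:, :2], tall), np.array_equal(full[:2, :], wide), cost)  # → True True 20
```

For `shape=(4, 4)` this takes 4 + 16 = 20 hash evaluations for 16 outputs; the overhead is the total size of the intermediate key arrays.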
For implementing masking we'd like the elements of
to match the first two elements of
and more generally we'd like, for any `shape_1`, `shape_2` with `len(shape_1) == len(shape_2)`:
At the moment the threefry function (upon which the PRNGs are based) takes in a `count` argument, which is an array with unique elements to which threefry essentially applies an element-wise hash (it doesn't quite work like that now, but it could, and assuming that it does will simplify my explanation).
One way to get the behaviour that we want is to ensure that the elements of `count` are consistent, regardless of the shape. I.e., we want a function to generate the `count` array, call it `make_unique_els`, such that all of the elements of `make_unique_els(shape)` are distinct, and
A very crude way to do this would be something like
which will generate things like:
The first output is contained in the second, just like we want. However, this has a limit of size 32 for each dimension before elements start to clash, which is rubbish for our random number generation.
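One scheme with exactly this 32-per-dimension limit (a guess, not necessarily the code meant above) packs each index into its own 5-bit field of a `uint32`. The hypothetical `make_unique_els_crude` below builds the index grid with `np.indices`, which gives the same grid as `np.mgrid` for these shapes. Element `(i, j)` becomes `32 * i + j` whatever the shape, so smaller outputs are contained in larger ones, but index 32 on the last axis spills into the neighbouring field.

```python
import numpy as np

BITS_PER_DIM = 5  # 2**5 == 32 distinct indices per axis before fields overlap


def make_unique_els_crude(shape):
    """Hypothetical crude counter: concatenate each axis index as a 5-bit field."""
    grid = np.indices(shape, dtype=np.uint32)  # shape (len(shape), *shape), like np.mgrid
    count = np.zeros(shape, dtype=np.uint32)
    for axis_index in grid:
        count = (count << np.uint32(BITS_PER_DIM)) | axis_index
    return count


small = make_unique_els_crude((2, 3))
big = make_unique_els_crude((4, 4))
print(np.array_equal(big[:2, :3], small))  # the smaller grid is contained in the larger one
clash = make_unique_els_crude((2, 33))
print(clash[0, 32] == clash[1, 0])  # index 32 spills into the next field: a clash
```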
One way around this would be to 'hash along the first dimension' of the grid generated by `mgrid`, something like
That would work, and you could actually use a slight variant of the existing threefry function as the hash, but there may well be a better way to do all of this. It would be great to know if anyone has any suggestions.
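A rough sketch of the 'hash along the first dimension' idea, again in NumPy and again with the splitmix64 finalizer (hypothetical `mix64`) standing in for a threefry variant. The hypothetical `make_unique_els` folds a hash over the coordinate axis of the index grid, so each element's counter depends only on its index tuple and not on the shape; the hypothetical `random_bits` then applies an element-wise keyed hash, as in the `count` description above. Because the fold is a hash rather than a packing, counters are distinct only with high probability (64-bit collisions only become likely around 2**32 elements), but there is no per-axis size limit.

```python
import numpy as np


def mix64(x):
    """splitmix64 finalizer: a bijective 64-bit scrambler (stand-in for threefry)."""
    x = x ^ (x >> np.uint64(30))
    x = x * np.uint64(0xBF58476D1CE4E5B9)
    x = x ^ (x >> np.uint64(27))
    x = x * np.uint64(0x94D049BB133111EB)
    return x ^ (x >> np.uint64(31))


def make_unique_els(shape):
    """Hypothetical: fold a hash along the first (coordinate) axis of the index grid.

    Element (i, j) gets mix64(mix64(i) + j), which depends only on the index
    tuple, so equal-rank shapes agree wherever they overlap.
    """
    grid = np.indices(shape, dtype=np.uint64)  # shape (len(shape), *shape), like np.mgrid
    count = np.zeros(shape, dtype=np.uint64)
    for axis_index in grid:
        count = mix64(count + axis_index)
    return count


def random_bits(key, shape):
    """Element-wise keyed hash of the shape-independent counters."""
    key_mixed = mix64(np.asarray([key], dtype=np.uint64))  # shape (1,), broadcasts
    return mix64(make_unique_els(shape) ^ key_mixed)


full = random_bits(42, (4, 4))
print(np.array_equal(full[:, :2], random_bits(42, (4, 2))))  # → True
print(np.array_equal(full[:2, :], random_bits(42, (2, 4))))  # → True
```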