jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

using JaxprTracer as index of list in lax.fori_loop #2962

Open tvieijra opened 4 years ago

tvieijra commented 4 years ago

I am working on an application where I have a list of numpy tensors with different shapes. I want to loop over this list and contract the tensors together into a single scalar, possibly also transforming each tensor in some way at every step of the loop. I can compile this code with jit and it works fine, but the compilation time becomes very long for large lists. I want to reduce the compilation time using lax.fori_loop, but indexing the list inside the loop does not seem to be possible.

Here is a piece of toy code:

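# onp is plain NumPy; np is jax.numpy (the naming used throughout this post)
import numpy as onp
import jax.numpy as np
from jax import jit, lax, random

@jit
def trace(list_of_tensors):
    L = len(list_of_tensors)
    edge = np.ones((1,))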

    for i in range(L):
        edge = np.einsum('i,ij->j', edge, list_of_tensors[i])

    return edge

for i in range(5):
    key = random.PRNGKey(onp.random.randint(1,1000000))
    list_of_tensors = [random.normal(key, (1,2)), random.normal(key, (2,4)), random.normal(key, (4,8)), random.normal(key, (8,1))]

    print(trace(list_of_tensors))

I want to replace the explicit for loop in trace with a lax.fori_loop, like this:

@jit
def trace(list_of_tensors):
    L = len(list_of_tensors)
    edge = np.ones((1))

    list_of_tensors, edge = lax.fori_loop(0, L, loop_body, (list_of_tensors, edge))

    return edge

def loop_body(i, args):
    L = len(args[0])
    edge = np.einsum('i,ij->j', args[1], args[0][i])

    return (args[0], edge)

for i in range(5):
    key = random.PRNGKey(onp.random.randint(1,1000000))
    list_of_tensors = [random.normal(key, (1,2)), random.normal(key, (2,4)), random.normal(key, (4,8)), random.normal(key, (8,1))]

    print(trace(list_of_tensors))

When running this code, I get the error TypeError: list indices must be integers or slices, not JaxprTracer, because I use the loop index to index a Python list. I tried casting the loop index to a jax.ops.index or to an integer, but neither works. Is there another way to use the loop index as an index into the list?

skye commented 4 years ago

To slightly sidestep your question, check out lax.scan. I think you should be able to express your original for-loop more easily using scan than fori_loop.

To more directly answer your question, there's no way to use the index from a fori_loop to index a list. Basically, the body of a fori_loop needs to be expressible as a single traced jax computation, and since jax tracers don't "see" the indexing of a regular Python list, it can't trace the list access properly. It would work if you were indexing a jax array instead of a Python list.
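To illustrate the last point, here is a minimal sketch of indexing a jax array (rather than a Python list) with the fori_loop index. It assumes all tensors share one shape so they can be stacked, which is not the case in the original question; the names chain_product and mats are hypothetical and chosen only for this example.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def chain_product(stacked, edge):
    # stacked has shape (n, d, d); edge has shape (d,).
    def body(i, e):
        # i is a traced value here, but indexing a jax array with it is fine.
        return jnp.einsum('i,ij->j', e, stacked[i])
    # The loop bounds come from the static shape, so they are Python ints.
    return lax.fori_loop(0, stacked.shape[0], body, edge)

# Four copies of 2 * identity, stacked along a new leading axis.
mats = jnp.stack([jnp.eye(3) * 2.0 for _ in range(4)])
result = chain_product(mats, jnp.ones(3))
print(result)  # every entry is 2**4 = 16
```

Because the loop body is traced once instead of being unrolled, the size of the traced program does not grow with the number of tensors.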

mattjj commented 4 years ago

Interestingly, this isn't just a tracing issue: in XLA HLO there's no way to use a dynamic value (like a loop iteration count) to index into a tuple (i.e. the only product type in HLO), in part because tuples can have elements with different shapes and the shape of every intermediate must be decidable in the type system. In other words, there's no dynamic version of GetTupleElement.

Unfortunately that means even if we could trace this computation effectively in Python (which I'm sure we could work out how to do) we don't have a way to lower it to a single compiled computation.

The only option now seems to be to pad and mask things into one shape so that you can stack things into an array. We're working on a transformation jax.mask to handle that automatically for you, and this is a good example use case that wasn't on our radar, but it's still a prototype and not ready for use at the moment.
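A manual version of the pad-and-stack approach might look like the sketch below. It does not use jax.mask; it zero-pads every tensor to a common square shape by hand and runs the contraction with lax.scan, as suggested earlier in the thread. The helper names pad_to and trace_padded and the padded size of 8 are assumptions made for this example. Zero padding is enough here only because the contraction is linear; a per-step transformation that does not map zeros to zeros would need an explicit mask.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

def pad_to(t, size):
    # Zero-pad a 2D tensor at the bottom and right to shape (size, size).
    rows, cols = t.shape
    return jnp.pad(t, ((0, size - rows), (0, size - cols)))

@jax.jit
def trace_padded(stacked):
    size = stacked.shape[-1]
    # Padded version of ones((1,)): a single 1 followed by zeros.
    edge = jnp.zeros(size).at[0].set(1.0)

    def step(e, t):
        # The padded entries are zero, so they never contribute to the result.
        return jnp.einsum('i,ij->j', e, t), None

    edge, _ = lax.scan(step, edge, stacked)
    # The last tensor has one real column, so the scalar sits at index 0.
    return edge[0]

keys = random.split(random.PRNGKey(0), 4)
shapes = [(1, 2), (2, 4), (4, 8), (8, 1)]
tensors = [random.normal(k, s) for k, s in zip(keys, shapes)]
stacked = jnp.stack([pad_to(t, 8) for t in tensors])
padded_result = trace_padded(stacked)

# Reference: the unrolled Python loop from the original question.
reference = jnp.ones(1)
for t in tensors:
    reference = jnp.einsum('i,ij->j', reference, t)
print(padded_result, reference[0])
```

The trade-off is wasted work on the padded entries: every step costs as much as the largest tensor, so this pays off mainly when the shapes are not too different and compilation time dominates.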

I'll self-assign this because I'm working on jax.mask (with several others too), but I likely won't be able to report progress for several weeks or more.