jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Long compile times for tensordot(a, b, N) on a TPU #5346

Open Kasekopf opened 3 years ago

Kasekopf commented 3 years ago

When using JAX on a TPU, I am seeing long compile times for tensordot when its inputs have many dimensions.

For example, consider the following reimplementation of tensordot(a, b, N) using reshape and matmul:

import jax.numpy
import numpy as onp  # original (host) NumPy

def tensordot2(a, b, N):
    """
    Reimplementation of jax.numpy.tensordot(a, b, N): collapse the N
    contracted dimensions into one and use a single matmul.
    """
    k = onp.prod(a.shape[-N:])
    a_r = jax.numpy.reshape(a, (-1, k))
    b_r = jax.numpy.reshape(b, (k, -1))
    res = jax.numpy.matmul(a_r, b_r)
    return jax.numpy.reshape(res, a.shape[:-N] + b.shape[N:])
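As a quick sanity check (sketched in plain NumPy, since jax.numpy mirrors the same reshape/matmul API), the trick agrees with tensordot on small inputs:

```python
import numpy as onp

def tensordot2_np(a, b, N):
    """NumPy analogue of the tensordot2 workaround above:
    collapse the N contracted dimensions into one and matmul."""
    k = onp.prod(a.shape[-N:])
    a_r = onp.reshape(a, (-1, k))
    b_r = onp.reshape(b, (k, -1))
    res = onp.matmul(a_r, b_r)
    return onp.reshape(res, a.shape[:-N] + b.shape[N:])

rng = onp.random.default_rng(0)
a = rng.standard_normal((2,) * 6)  # 6-dimensional tensors, for brevity
b = rng.standard_normal((2,) * 6)

expected = onp.tensordot(a, b, 3)  # contract last 3 axes of a with first 3 of b
actual = tensordot2_np(a, b, 3)
print(onp.allclose(expected, actual))  # → True
```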

On tensors with 12 dimensions, this tensordot2 compiles 10x faster than jax.numpy.tensordot but has very similar performance after compilation:

$ python run.py --tpu="grpc://10.220.228.98:8470"
JAX version: 0.2.7
Connecting to: grpc://10.220.228.98:8470

-- Using alternative tensordot(a,b,N) implementation--
First time: 1.0770306587219238
Output: [ 0.07692753  0.00059127 -0.02753184  0.0518856   0.08529685  0.00665864
  0.00136994 -0.03543025]
Average of next 100 trials: 0.01755237817764282

-- Using jax.numpy.tensordot(a,b,N) --
First time: 11.581943273544312
Output: [ 0.07692759  0.00059129 -0.02753186  0.05188563  0.08529684  0.00665865
  0.00136995 -0.03543027]
Average of next 100 trials: 0.01803591012954712
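The run.py attachment isn't reproduced here, but a measurement of this shape can be sketched roughly as follows (plain NumPy stand-in; with JAX, one should call .block_until_ready() on the result before stopping the clock, since dispatch is asynchronous, and the first call additionally pays the compile cost):

```python
import time
import numpy as onp

def timed(f, *args):
    """Return (result, elapsed seconds) for one call of f."""
    t0 = time.perf_counter()
    out = f(*args)
    # For a JAX computation, insert out.block_until_ready() here.
    return out, time.perf_counter() - t0

a = onp.ones((2,) * 6)
b = onp.ones((2,) * 6)

_, first = timed(onp.tensordot, a, b, 3)  # first call: in JAX this includes tracing/compilation
trials = [timed(onp.tensordot, a, b, 3)[1] for _ in range(100)]
avg = sum(trials) / len(trials)
print(f"First time: {first:.6f}s, average of next 100 trials: {avg:.6f}s")
```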

See the attached Python file run.py, which performs this measurement.

Some additional notes:

Is there anything I can do to improve the compile time of jax.numpy.tensordot? Or is it best to stick with my own reimplementation for now?

run.zip

hawkinsp commented 3 years ago

XLA isn't well tuned for arrays with large numbers of dimensions, so this isn't particularly surprising. We've reported similar issues to the compiler team in the past, but it isn't a high priority for them at the moment.

Out of curiosity: is this a tensor network? I think https://github.com/google/TensorNetwork implements optimizations for efficient contractions on top of JAX that you could either use or learn from.
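A lighter-weight option in the same spirit as tensor-network libraries is einsum's built-in contraction-order planning, available in both numpy and jax.numpy. A sketch with made-up shapes:

```python
import numpy as onp

# einsum with optimize=True plans a pairwise contraction order instead of
# contracting all operands at once; jax.numpy.einsum accepts the same flag.
rng = onp.random.default_rng(0)
a = rng.standard_normal((2, 3, 4))
b = rng.standard_normal((4, 3, 5))
c = rng.standard_normal((5, 2))

# Full contraction to a scalar; equivalent to chained tensordots,
# but einsum picks the cheapest pairwise order.
out = onp.einsum('ijk,kjl,li->', a, b, c, optimize=True)

# The chosen contraction path can be inspected up front:
path, info = onp.einsum_path('ijk,kjl,li->', a, b, c, optimize='optimal')
print(path[0])  # → 'einsum_path'
```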

Failing that, if your workaround works, you should use it!

Does that answer the question sufficiently?