Kasekopf opened this issue 3 years ago
XLA isn't very well tuned for arrays with large numbers of dimensions. So this isn't particularly surprising to see. We've reported similar issues to the compiler team in the past, but it isn't a high priority for them at the moment.
Out of curiosity: is this a tensor network? I think https://github.com/google/TensorNetwork implements optimizations for efficient contractions on top of JAX that you could either use or learn from.
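For example, something along these lines (a minimal sketch, assuming tensornetwork's `ncon` API with its JAX backend; the shapes here are made up):

```python
import jax.numpy as jnp
import tensornetwork as tn

tn.set_default_backend("jax")  # contractions execute via JAX/XLA

a = jnp.ones((2, 3, 4))
b = jnp.ones((4, 3, 5))

# ncon convention: positive labels are summed over between tensors,
# negative labels become output axes (ordered -1, -2, ...).
c = tn.ncon([a, b], [[-1, 1, 2], [2, 1, -2]])  # result shape (2, 5)
```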
Failing that, if your workaround works, you should use it!
Does that answer the question sufficiently?
When using JAX on a TPU, I am seeing long compile times for tensordot when its inputs have many dimensions.
For example, consider the following reimplementation of tensordot(a, b, N) using reshape and matmul:
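A minimal sketch of one such reimplementation (this assumes jnp.tensordot's integer-axes convention, i.e. the last N axes of `a` contract against the first N axes of `b`; the code attached in run.zip may differ in details):

```python
import math
import jax.numpy as jnp

def tensordot2(a, b, n):
    # Contract the last n axes of `a` against the first n axes of `b`,
    # matching jnp.tensordot(a, b, n), but lowered as a single rank-2
    # matmul between reshaped operands.
    lead = a.shape[:a.ndim - n]           # kept axes of a
    trail = b.shape[n:]                   # kept axes of b
    k = math.prod(a.shape[a.ndim - n:])   # flattened contraction size
    out = jnp.matmul(a.reshape(-1, k), b.reshape(k, -1))
    return out.reshape(lead + trail)
```

Collapsing everything into a single rank-2 matmul presumably gives XLA far fewer dimensions to reason about at compile time.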
On tensors with 12 dimensions, this tensordot2 compiles 10x faster than jax.numpy.tensordot but has very similar performance after compilation. See the provided simple Python file run.py, which performs this measurement.
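Roughly, the measurement looks like this (the shapes are hypothetical stand-ins for what run.py actually uses; `tensordot2` is the sketch above):

```python
import time
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (2,) * 12)  # two 12-dimensional operands
b = jax.random.normal(key, (2,) * 12)

for name, f in [("jnp.tensordot", lambda x, y: jnp.tensordot(x, y, 6)),
                ("tensordot2",    lambda x, y: tensordot2(x, y, 6))]:
    jitted = jax.jit(f)
    t0 = time.perf_counter()
    jitted(a, b).block_until_ready()  # first call includes compilation
    t1 = time.perf_counter()
    jitted(a, b).block_until_ready()  # cached executable, no recompile
    t2 = time.perf_counter()
    print(f"{name}: compile+first run {t1 - t0:.2f}s, "
          f"steady state {t2 - t1:.4f}s")
```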
Some additional notes:
Is there anything I can do to improve the compile time of jax.numpy.tensordot? Or is it best to stick with my own reimplementation for now?
run.zip