jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

np.einsum support #37

Closed lukasheinrich closed 5 years ago

lukasheinrich commented 5 years ago

Support for generic tensor contractions would cover a large class of computations and also provide a foundation for higher-order operations. Perhaps jax could then also be added as an opt_einsum backend?

mattjj commented 5 years ago

Yes, definitely! And you're right that opt_einsum should be very useful here.

Broadly speaking I think we would want to use opt_einsum to translate an einsum call into a sequence of calls into lax.dot_general, which models the XLA DotGeneral HLO. But to support all of np.einsum's functionality we'd need to handle a few more cases that aren't exactly tensor contractions, like np.einsum('ii->i', A) for extracting a matrix diagonal, np.einsum('ii->', A) for computing a trace, and np.einsum('ij->', A) for a sum over axes (contracting against an implicit all-ones array, you could say).
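For concreteness, here's what those special cases compute in plain NumPy (the equivalent NumPy functions are noted in the comments):

import numpy as onp

A = onp.arange(9.0).reshape(3, 3)
onp.einsum('ii->i', A)  # matrix diagonal, same as onp.diagonal(A)
onp.einsum('ii->', A)   # trace, same as onp.trace(A)
onp.einsum('ij->', A)   # sum over all axes, same as A.sum()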

EDIT: There are also a few annoying constraints in the current DotGeneral HLO, like the requirement that batch dimensions precede everything else, and the requirement that there is only one contracting dimension per call, which we'd have to work around in our lax_numpy.py implementation. I believe the XLA folks are planning to make DotGeneral more convenient to use in the near future, but I'm not sure when.
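To sketch what a single dot_general call looks like (using the dimension_numbers convention of today's jax.lax API, which may be less constrained than the HLO described above), here is an ordinary matrix multiply, contracting the last axis of the left operand against the first axis of the right, with no batch dimensions:

import jax.numpy as np
from jax import lax

x = np.ones((2, 3))
y = np.ones((3, 4))
# dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
out = lax.dot_general(x, y, (((1,), (0,)), ((), ())))
out.shape  # (2, 4), the same result as np.einsum('ij,jk->ik', x, y)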

Let's leave this issue open to track this feature. We can probably also implement np.tensordot support at the same time.
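For what it's worth, np.tensordot is itself expressible as an einsum, so the two can share most of their machinery; for example, axes=1 on two matrices is just an ordinary matrix product:

import numpy as onp

a = onp.ones((2, 3))
b = onp.ones((3, 4))
# tensordot with axes=1 contracts a's last axis against b's first,
# which is exactly onp.einsum('ij,jk->ik', a, b)
assert onp.allclose(onp.tensordot(a, b, axes=1), onp.einsum('ij,jk->ik', a, b))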

mattjj commented 5 years ago

Please reopen if you notice any missing cases!

lukasheinrich commented 5 years ago

thanks @mattjj

I tested with some existing code for which I'd like to use jax:

import jax.numpy as np
import numpy as onp

second = onp.random.uniform(size=(2, 3, 4))
first = onp.random.uniform(size=(2, 1))
result = np.einsum('sa,shb->shab', first, second)
result.shape  # expected shape: (2, 3, 1, 4)

This throws a StopIteration error, while changing import jax.numpy as np to import numpy as np gives the expected shape.

mattjj commented 5 years ago

Sounds like a bug!

mattjj commented 5 years ago

I made a fix in 997c9c5 that I believe handles the issue. I'll merge it into master after the CI tests pass. The 'tensor product' side of the code, which your example exercises, was (and is) under-tested, in part because I assumed it would be the easiest to get right (d'oh!).

If you've got more good examples like this, it would be awesome to add them to the einsum test file! I think the only way to be confident we've covered all the edge cases is to keep adding tests. An even better strategy would be to generate random einsum calls, like a special case of the JAX generated-function tests, though that would take some time to put together.
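As a minimal sketch of what such a check could look like (this isn't the actual JAX test harness; the cases are just the ones mentioned in this thread), one can compare jax.numpy.einsum against numpy.einsum over a list of specs:

import numpy as onp
import jax.numpy as np

# (spec, operand shapes) pairs; JAX results should match NumPy's
cases = [
    ('ii->i', [(3, 3)]),                    # matrix diagonal
    ('ii->', [(3, 3)]),                     # trace
    ('ij->', [(3, 4)]),                     # sum over all axes
    ('sa,shb->shab', [(2, 1), (2, 3, 4)]),  # the case reported above
]
for spec, shapes in cases:
    operands = [onp.random.uniform(size=s) for s in shapes]
    expected = onp.einsum(spec, *operands)
    actual = np.einsum(spec, *operands)
    assert onp.allclose(actual, expected, atol=1e-6), spec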

sjperkins commented 5 years ago

Might be worth looking at the test suite added to https://github.com/dask/dask/pull/3412 for further cases to test.

mattjj commented 5 years ago

Wow, great idea. Thanks!

mattjj commented 5 years ago

I started some work in #147, including test cases based on the dask test cases. Those tests are really great, and uncovered several issues! Thanks, @sjperkins.

mattjj commented 5 years ago

I just merged #147, which handles a lot more cases, but a couple still fail. They fail in interesting ways, though. In one case the failure manifests nondeterministically and only in Python 3, and based on an assert check I suspect it could even be a bug in opt_einsum. In the other cases we get an abort from XLA, and while there might also be issues in the einsum code for those cases, the fact that they manifest as aborts suggests there may be XLA bugs lurking there.

The einsum support seems pretty good now, but we've got to patch up those remaining cases! Thoughts welcome.

mattjj commented 5 years ago

This one failed in XLA, gotta check into it later:

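# 'iji' repeats an index within a single operand (a diagonal), and the i and k
# axes are then contracted across operands, so this hits several paths at once: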
np.einsum('iji,ki,kj->j', np.ones((3, 4, 3)), np.ones((2, 3)), np.ones((2, 4)))

mattjj commented 5 years ago

Looks like the jaxlib update to 0.1.7 fixed the issue in my previous comment.