Yes, definitely! And you're right that opt_einsum should be very useful here.
Broadly speaking I think we would want to use `opt_einsum` to translate an einsum call into a sequence of calls to `lax.dot_general`, which models the XLA DotGeneral HLO. But to support all of np.einsum's functionality we'd need to handle a few more cases that aren't exactly tensor contractions, like `np.einsum('ii->i', A)` for extracting a matrix diagonal, `np.einsum('ii->', A)` for computing a trace, and `np.einsum('ij->', A)` for a sum over axes (contracting against an implicit all-ones array, you could say).
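For reference, here are those non-contraction cases next to their direct equivalents, written in plain NumPy just to illustrate the semantics:

```python
import numpy as np

A = np.arange(9.0).reshape(3, 3)

assert np.array_equal(np.einsum('ii->i', A), np.diagonal(A))  # matrix diagonal
assert np.einsum('ii->', A) == np.trace(A)                    # trace
assert np.einsum('ij->', A) == A.sum()                        # sum over all axes
```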
EDIT: There are also a few annoying constraints in the current DotGeneral HLO, like the requirement that batch dimensions precede everything else, and the requirement that there is only one contracting dimension per call, which we'd have to work around in our lax_numpy.py implementation. I believe the XLA folks are planning to make DotGeneral more convenient to use in the near future, but I'm not sure when.
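For concreteness, here's a hand-written lowering of one einsum to `lax.dot_general` (a sketch of the `dimension_numbers` encoding against the current Python API; the constraints mentioned above concern the underlying HLO):

```python
import numpy as onp
from jax import lax

x = onp.ones((5, 2, 3), onp.float32)  # indices 'bij'
y = onp.ones((5, 3, 4), onp.float32)  # indices 'bjk'

# np.einsum('bij,bjk->bik', x, y) as a single DotGeneral: contract lhs axis 2
# with rhs axis 1, and batch over axis 0 of both operands.
out = lax.dot_general(x, y, dimension_numbers=(((2,), (1,)), ((0,), (0,))))
assert out.shape == (5, 2, 4)
```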
Let's leave this issue open to track this feature. We can probably also implement `np.tensordot` support at the same time.
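For context, tensordot is just a special case of einsum (shown here in plain NumPy), which is why the two can share one lowering:

```python
import numpy as onp

a = onp.random.rand(2, 3)
b = onp.random.rand(3, 4)
assert onp.allclose(onp.tensordot(a, b, axes=([1], [0])),
                    onp.einsum('ij,jk->ik', a, b))
```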
Please reopen if you notice any missing cases!
Thanks @mattjj!
I tested with some existing code for which I'd like to use jax:
```python
import jax.numpy as np
import numpy as onp

second = onp.random.uniform(size=(2, 3, 4))
first = onp.random.uniform(size=(2, 1))
result = np.einsum('sa,shb->shab', first, second)
result.shape  # expected shape: (2, 3, 1, 4)
```
which throws a `StopIteration` error, while changing `import jax.numpy as np` to `import numpy as np` gives the expected shape.
Sounds like a bug!
I made a fix in 997c9c5 that I believe handles the issue. I'll merge it into master after the CI tests pass. The 'tensor product' side of the code, which your example exercises, was (and is) under-tested, in part because I assumed it would be the easiest to get right (d'oh!).
If you've got more good examples like this, it would be awesome to add them to the einsum test file! I think the only way to be confident we've covered all the edge cases is to keep adding tests. An even better strategy would be to generate random einsum calls, like a special case of the JAX generated function tests, though that would take some time to put together.
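A hypothetical sketch of what that randomized testing could look like — `random_einsum_case` below is made up for illustration, not part of the JAX test suite:

```python
import string
import numpy as onp
import jax.numpy as np

def random_einsum_case(rng, n_operands=2, max_ndim=3, max_size=4):
    # Give each index letter a fixed size so shared letters agree across operands.
    letters = list(string.ascii_lowercase[:6])
    sizes = {c: rng.randint(1, max_size + 1) for c in letters}
    in_specs, operands = [], []
    for _ in range(n_operands):
        ndim = rng.randint(1, max_ndim + 1)
        idx = ''.join(rng.choice(letters, size=ndim, replace=False))
        in_specs.append(idx)
        operands.append(rng.rand(*[sizes[c] for c in idx]))
    # Keep a random subset of the input indices in the output; the rest are summed.
    used = sorted(set(''.join(in_specs)))
    out = ''.join(c for c in used if rng.rand() < 0.5)
    return ','.join(in_specs) + '->' + out, operands

rng = onp.random.RandomState(0)
for _ in range(20):
    spec, operands = random_einsum_case(rng)
    assert onp.allclose(np.einsum(spec, *operands),
                        onp.einsum(spec, *operands)), spec
```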
Might be worth looking at the test suite added to https://github.com/dask/dask/pull/3412 for further cases to test.
Wow, great idea. Thanks!
I started some work in #147, including with test cases based on the dask test cases. Those tests are really great, and uncovered several issues! Thanks, @sjperkins.
I just merged #147, which handles a lot more cases, but there are a couple remaining that fail. They fail in interesting ways, though. In one case the failure manifests nondeterministically and only in Python 3, and based on an assert check I suspect it could even be a bug in opt_einsum. In the other cases we get an abort from XLA; while there might also be issues in the einsum code for those cases, the fact that they manifest as aborts indicates that there may be XLA bugs lurking there.
The einsum support seems pretty good now, but we've got to patch up those remaining cases! Thoughts welcome.
This one failed in XLA, gotta check into it later:
```python
np.einsum('iji,ki,kj->j', np.ones((3, 4, 3)), np.ones((2, 3)), np.ones((2, 4)))
```
Looks like that jaxlib update to 0.1.7 fixed the issue in my previous comment.
Support for generic tensor contractions would cover a large class of computations and also provide a foundation for higher-order operations. Perhaps jax could then also be added as an `opt_einsum` backend?
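As a sketch of the planning side that a backend integration would build on: `opt_einsum.contract_path` chooses a pairwise contraction order without committing to a backend, and a hypothetical jax backend would then execute each step with something like `lax.dot_general`:

```python
import opt_einsum
import numpy as onp

x = onp.random.rand(8, 16)
y = onp.random.rand(16, 4)
z = onp.random.rand(4, 32)

# contract_path only plans the contraction; it returns the pairwise order
# plus cost/scaling information about the chosen intermediates.
path, info = opt_einsum.contract_path('ij,jk,kl->il', x, y, z)
print(path)  # e.g. [(0, 1), (0, 1)]: which operand pairs to contract, in order
print(info)
```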