dgasmith / opt_einsum

⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.
https://dgasmith.github.io/opt_einsum/
MIT License

Cannot provide custom `tensordot` method #101

Closed · srush closed this 3 years ago

srush commented 5 years ago

As far as I can tell there is no way to provide a custom tensordot method. I tried pretty hard, but the code path seems to fail on examples like `contract("abc,ab->ac")`.

jcmgray commented 5 years ago

Maybe you could expand a bit on what you have tried? The equation you give as an example cannot be framed as a simple tensordot, since the index 'a' appears 3 times.
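To make that concrete, here is a minimal NumPy sketch (the arrays and shapes are just illustrative): 'a' appears in both inputs and in the output, so it is a batch dimension, while tensordot can only sum shared axes away entirely.

```python
import numpy as np

x = np.random.rand(2, 3, 4)  # indices: abc
y = np.random.rand(2, 3)     # indices: ab

# einsum handles the batch index 'a' directly.
out = np.einsum("abc,ab->ac", x, y)  # shape (2, 4)

# The closest tensordot call also sums 'a' away, giving shape (4,)
# instead of the desired (2, 4).
wrong = np.tensordot(x, y, axes=[(0, 1), (0, 1)])
```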

srush commented 5 years ago

Sure. The docs seem to imply that if you provide a tensordot and a transpose, opt_einsum does the rest. In practice, though, I ran into two issues when I tried this.

1) The code path seems to be broken: it complains that my library has no einsum function. I think there is a bug here: https://github.com/dgasmith/opt_einsum/blob/master/opt_einsum/contract.py#L541

2) When I change that line and give it the above expression, it tries to call my tensordot function and then fails with an enigmatic error (it seems to get confused by the repeated index).

I love the idea that opt_einsum could be used with general operators, but perhaps the docs should be updated to make it clear that a full einsum implementation is required under the hood?

jcmgray commented 5 years ago

I would say the main role of opt_einsum is to find and perform pairwise contraction orderings.

With regard to point 1, does your backend provide an einsum function? If it doesn't, then this isn't a bug; it's just that certain pairwise contractions, such as the one you give, cannot be done with tensordot alone. I think the docs are reasonably clear on when an einsum implementation is required:

> For a contraction to be possible without using a backend einsum, it must satisfy the following rule: in the full expression (including output indices) each index must appear twice. In other words, each dimension must be contracted with one other dimension, or left alone.
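As a quick illustration of that rule (a sketch, not part of opt_einsum's API), one can simply count how often each index appears across the full expression, inputs plus output:

```python
from collections import Counter

def tensordot_expressible(eq):
    """Check the 'each index appears exactly twice' rule (sketch)."""
    lhs, rhs = eq.split("->")
    counts = Counter(lhs.replace(",", "") + rhs)
    return all(c == 2 for c in counts.values())

print(tensordot_expressible("abc,cd->abd"))  # True
print(tensordot_expressible("abc,ab->ac"))   # False: 'a' appears 3 times
```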

There are simple cases (such as here, where 'a' is just a broadcast dimension) where one could imagine implementing things manually, as sketched below, but einsum has such a myriad of syntax uses that handling them all gets very complex, let alone performant.
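For this particular broadcast pattern, a hand-rolled NumPy version (a sketch for this one case only) would look like:

```python
import numpy as np

x = np.random.rand(2, 3, 4)  # indices: abc
y = np.random.rand(2, 3)     # indices: ab

# "abc,ab->ac" by hand: broadcast-multiply over the batch index 'a',
# then sum out the contracted index 'b'.
out = (x * y[:, :, None]).sum(axis=1)

assert np.allclose(out, np.einsum("abc,ab->ac", x, y))
```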

srush commented 5 years ago

I see. I guess I got tripped up by this statement: "In fact, any library that provides a tensordot() and transpose() implementation can perform most normal contractions." Keeping axes around felt pretty normal to me (for instance, any sort of batching).

Writing an einsum implementation is really hard; it requires a parser and all sorts of special cases. I like the idea that opt_einsum could be used as a front-end for the non-standard contractions I want to do. Any idea what the true minimal set of functions is that a backend could implement so that opt_einsum could be used as a front-end?

dgasmith commented 5 years ago

This sounds a bit like we need a register_backend function that would work something like the following:

```python
oe.register_backend("mymodule",
                    einsum=np.einsum,
                    tensordot=mymodule.tensordot,
                    transpose=mymodule.transpose)

oe.contract(..., backend="mymodule")
```

We should only need einsum, tensordot, and transpose for full functionality.

This is another good time to look into NEP 18.
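For context, NEP 18 is the `__array_function__` protocol, which lets a backend's array type intercept NumPy-level function calls itself. A rough sketch with a hypothetical `MyArray` wrapper (illustration only, not opt_einsum API; requires NumPy >= 1.17):

```python
import numpy as np

class MyArray:
    """Hypothetical array wrapper implementing the NEP 18 protocol."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        # Intercept calls like np.tensordot(MyArray, MyArray, ...) and
        # delegate to a custom implementation.
        if func is np.tensordot:
            unwrapped = [a.data if isinstance(a, MyArray) else a
                         for a in args]
            return MyArray(np.tensordot(*unwrapped, **kwargs))
        return NotImplemented
```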

jcmgray commented 5 years ago

@dgasmith I think the issue is just the type of contraction requiring a non-existent einsum.

> I see. I guess I got tripped up by this statement: "In fact, any library that provides a tensordot() and transpose() implementation can perform most normal contractions." Keeping axes around felt pretty normal to me (for instance, any sort of batching).

Yes, I guess 'normal' is pretty subjective! I suspect I was coming from a background of physics and Einstein notation when I wrote that.

> Any idea what the true minimal set of functions is that a backend could implement so that opt_einsum could be used as a front-end?

I think einsum can do things no other function can, but some extras that come to mind are:

But once these start getting mixed together, e.g. (iij->i), it all gets a bit more complicated, and complete support seems pretty impossible. Trying to support specific, common cases such as batch dimensions would probably be the way to proceed (it might actually be possible with #95), but even then, there is no standard 'vectorize' function in the same way that tensordot is 'standard', so one would have to use Python looping and concatenation or something.
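A sketch of that loop-and-concatenate fallback for the batch-dimension example above (plain NumPy, illustration only):

```python
import numpy as np

x = np.random.rand(2, 3, 4)  # indices: abc
y = np.random.rand(2, 3)     # indices: ab

# Handle the batch index 'a' with a Python loop: one tensordot per
# batch slice, then stack the results back together.
out = np.stack([np.tensordot(x[a], y[a], axes=[(0,), (0,)])
                for a in range(x.shape[0])])

assert np.allclose(out, np.einsum("abc,ab->ac", x, y))
```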

dgasmith commented 3 years ago

I lean toward agreeing with @jcmgray that this functionality could be complex to implement. At the moment, you can hack around it with the following:

```python
import opt_einsum as oe

oe.backends.dispatch._cached_funcs["tensordot", "numpy"] = custom_func
```
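For example, with a hypothetical `custom_func` that wraps `np.tensordot` (a sketch; whether a given contraction actually routes through tensordot depends on the expression and the BLAS settings):

```python
import numpy as np
import opt_einsum as oe

# Hypothetical custom tensordot that logs its calls, then defers to numpy.
def custom_func(a, b, axes=2):
    print("custom tensordot called with axes:", axes)
    return np.tensordot(a, b, axes=axes)

oe.backends.dispatch._cached_funcs["tensordot", "numpy"] = custom_func

x, y = np.random.rand(3, 4), np.random.rand(4, 5)
print(oe.contract("ij,jk->ik", x, y, backend="numpy"))
```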

Happy to take a PR which is well documented with edge cases.