dgasmith / opt_einsum

⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.
https://dgasmith.github.io/opt_einsum/
MIT License
819 stars 67 forks

Suboptimal order when einsum contains non-repeating indices #112

Open yaroslavvb opened 4 years ago

yaroslavvb commented 4 years ago

Is there a way to get an optimized path for an expression in which some tensors are repeated?

For instance, einsum('nlp,nlq->l', B, B) can be done in O(n^3) time, but opt_einsum gives a schedule that takes O(n^4) time.

einsum_string = "nlp,nlq->l"
views = oe.helpers.build_views(einsum_string, {'n': 100, 'l': 100, 'p': 100, 'q': 100})
path, path_info = oe.contract_path(einsum_string, *views)
print(path_info)

Complete contraction:  nlp,nlq->l
         Naive scaling:  4
     Optimized scaling:  4
      Naive FLOP count:  2.000e+8
  Optimized FLOP count:  2.000e+8
   Theoretical speedup:  1.000
  Largest intermediate:  1.000e+2 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   4              0             nlq,nlp->l                                  l->l

A faster way to do this:

np.square(np.einsum('nlp->nl', B)).sum(axis=0)

jcmgray commented 4 years ago

At the moment there is no way to do these non-linear types of contraction, only those that can be broken up into pairwise tensordot, transpose and einsum itself!

yaroslavvb commented 4 years ago

Why doesn't opt_einsum discover the following schedule? It's 1000x faster (colab)

x = np.einsum('nlp->nl', B)
np.einsum('ni,ni->i', x, x)

BTW, this einsum comes up when trying to extend https://arxiv.org/abs/1510.01799 to conv2d layers
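A minimal NumPy sketch of the equivalence discussed above. The array sizes and random seed are arbitrary assumptions chosen only to keep the check small; it compares the single-call einsum against the two-step schedule.

```python
import numpy as np

rng = np.random.default_rng(0)
n, l, p = 4, 5, 6  # arbitrary small sizes, for illustration only
B = rng.standard_normal((n, l, p))

# Direct form: sums over n, p and q in a single einsum call.
direct = np.einsum('nlp,nlq->l', B, B)

# Two-step form: sum out p first, then contract the result with itself over n.
x = np.einsum('nlp->nl', B)
two_step = np.einsum('ni,ni->i', x, x)

print(np.allclose(direct, two_step))
```

The two results agree because p and q each appear on only one operand, so they can be summed out before the operands are multiplied.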

dgasmith commented 4 years ago

Currently opt_einsum only concerns itself with contraction order and not the optimization of individual pair contractions. There are some notions that we could do this kind of optimization here, but the technology becomes less general quickly. Consider the contraction aij,ajk->aik, which when a is large is best written as a looped GEMM. However, on a CPU the threshold for the size of a is much smaller than on a GPU, where the cost of instantiating different routines is higher unless batched GEMMs are called.

Happy to chat about adding this, but at the moment I do lean a bit towards downstream technologies implementing this kind of optimization themselves. I think we can do a decent job in NumPy, but less so with GPU backends.
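To make the aij,ajk->aik example concrete, here is a small NumPy sketch showing three ways to write the same contraction: einsum, an explicit loop of matrix products over a, and a batched matmul. The shapes are arbitrary assumptions, and nothing here measures which variant is faster on a given backend.

```python
import numpy as np

rng = np.random.default_rng(0)
a, i, j, k = 8, 3, 4, 5  # arbitrary sizes, for illustration only
A = rng.standard_normal((a, i, j))
B = rng.standard_normal((a, j, k))

# Single einsum call over the batch index a.
via_einsum = np.einsum('aij,ajk->aik', A, B)

# "Looped GEMM": one ordinary matrix product per value of a.
looped = np.stack([A[n] @ B[n] for n in range(a)])

# Batched matrix product: matmul treats the leading axis as a batch.
batched = np.matmul(A, B)

print(np.allclose(via_einsum, looped), np.allclose(via_einsum, batched))
```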

jcmgray commented 4 years ago

I was just about to comment pretty much the same as @dgasmith. The only thing I'll add is that in this specific case, it might be worth considering these kinds of trivial axis reductions (i.e. p and q only appear on a single tensor) among others during some kind of preprocessing step. Then the path

'nlq->nl'
'nlp->nl'
'nl,nl->l'

would only be ~twice as slow as the 'fast' version (since with the numpy backend at least it still wouldn't know B and B are the same).

dgasmith commented 4 years ago

@jcmgray Good point there, this could fall under the "things we should always check for" like the Hadamard product issue that is coming up.

This would be cheap collapse = set('nlq') - set(all_other_extant_indices); if collapse: ....

Hey in 3-4 years we can use the walrus operator here :)

jcmgray commented 4 years ago

> @jcmgray Good point there, this could fall under the "things we should always check for" like the Hadamard product issue that is coming up.

Yes exactly, might be worth compiling a list of such steps in another issue:

dgasmith commented 4 years ago

Yup, let's split this out into an issue and see how hard it would be to add in a uniform manner.

yaroslavvb commented 4 years ago

I'm wondering if this can be handled in a general fashion by trying to minimize the scaling order. Without knowledge of the underlying backend, an O(n^3) schedule should be preferable to an O(n^4) one.

Some examples:

A similar problem comes up in the graphical models literature and is typically handled with a two-stage approach (the Junction Tree algorithm). Make a graph with each index corresponding to a vertex, with indices connected if they co-occur in the same factor, then:

Step 1: triangulate the graph using a greedy heuristic
Step 2: find minimum weight spanning tree of the clique graph

Each clique corresponds to an intermediate term while spanning clique-tree gives a reduction schedule.

An example of doing this for the partition function of a 3x3 grid Ising model: 12,23,45,56,78,89,14,25,36,47,58,69-> (image)
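A rough sketch of the first stage described above, using min-fill as the greedy heuristic (one common choice, picked here as an assumption). The function names are hypothetical and this is not what opt_einsum does internally. It builds the index graph from the einsum terms and eliminates vertices greedily, reporting the largest neighbourhood seen at elimination time.

```python
from itertools import combinations

def index_graph(terms):
    """Adjacency sets: indices are vertices, joined if they share a term."""
    adj = {}
    for term in terms:
        for ix in term:
            adj.setdefault(ix, set())
        for u, v in combinations(set(term), 2):
            adj[u].add(v)
            adj[v].add(u)
    return adj

def fill_in(adj, v):
    """Number of edges that eliminating v would add among its neighbours."""
    return sum(1 for a, b in combinations(adj[v], 2) if b not in adj[a])

def min_fill_order(terms):
    """Greedy min-fill elimination; returns (order, induced width)."""
    adj = index_graph(terms)
    order, width = [], 0
    while adj:
        # Pick the vertex whose elimination adds the fewest new edges.
        v = min(sorted(adj), key=lambda u: fill_in(adj, u))
        nbrs = adj.pop(v)
        width = max(width, len(nbrs))
        for a in nbrs:
            adj[a].discard(v)
            adj[a] |= nbrs - {a}  # connect the remaining neighbours
        order.append(v)
    return order, width

# The 3x3 grid example: each two-character term is a factor over two indices.
grid = '12,23,45,56,78,89,14,25,36,47,58,69'.split(',')
order, width = min_fill_order(grid)
print(order, width)
```

Each eliminated vertex together with its neighbours at that moment forms a clique of the triangulated graph, which is where the intermediate terms of the reduction schedule come from.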

jcmgray commented 4 years ago

In both those cases I think it might just be as simple as performing the individual reductions in any order? e.g. in the first case

'a->'
'b->'
',->'  # i.e. scalar multiplication

or in the second case

'abcd->a'
'aefg->a'
'a,a->a'

as all the indices fall under the category of 'appear on a single input and not the output'. Maybe there is a more general example?

yaroslavvb commented 4 years ago

A slightly more general example is ab,ac->

There's a choice over which indices to reduce last. Choosing a to be the last index to reduce splits einsum into two independent problems

np.einsum('a,a->', np.einsum('ab->a', A), np.einsum('ac->a', B))

jcmgray commented 4 years ago

So I think the 'single index reduction' preprocessing step I'm imagining would handle that fine:

  1. Find all indices that appear exactly once and sum those axes (using einsum or actual backend.sum).
  2. Then give these preprocessed inputs to the actual path finders (which would then find path = [(0, 1)]).

If there are more terms, I'm pretty sure opt_einsum will never itself leave indices appearing only once so it should just be a matter of processing the inputs individually.

yaroslavvb commented 4 years ago

Summing out leaf indices may create new leaf indices, so this preprocessing step may need to be repeated to convergence. A more general example is a binary tree which is doable with O(n^2) scaling but gets O(n^3) currently.

import opt_einsum as oe
import numpy as np

def binary_tree_einsum(depth):
    edges = []

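    def tc(num):
        # Map a node number to a single-character index label, avoiding
        # characters that have special meaning in einsum strings.
        return chr(num+100000) if chr(num) in ' ,->.' else chr(num)

    def rec(parent, child, depth):
        # Add the edge parent-child, then recurse into both children.
        edges.append(tc(parent)+tc(child))
        if depth > 0:
            rec(child, 2*child, depth-1)
            rec(child, 2*child+1, depth-1)

    rec(0, 1, depth)
    print(oe.contract_path(','.join(edges)+'->', *[np.ones((2,2))]*len(edges)))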
binary_tree_einsum(3)

#         Naive scaling:  16
#     Optimized scaling:  3

jcmgray commented 4 years ago

Ah yes but again (and not trying to be contrary - it's good to think about these edge cases!) if you just perform the single axis reductions first you get the n^2 scaling again:

import opt_einsum as oe
import numpy as np
from collections import Counter
from itertools import chain

def binary_tree_einsum(depth):
    edges = []

    def tc(num): 
        return chr(num+100000) if chr(num) in ' ,->.' else chr(num + 100)

    def rec(parent, child, depth):
        edges.append(tc(parent)+tc(child))
        if depth > 0:
            rec(child, 2*child, depth-1)
            rec(child, 2*child+1, depth-1)
    rec(0, 1, depth)

    views = [np.ones((2,2))]*len(edges)

    # explicitly reduce 
    freqs = Counter(chain(*edges))
    new_terms = []
    new_views = []
    for term, view in zip(edges, views):
        new_term = "".join(ix for ix in term if freqs[ix] != 1)
        new_terms.append(new_term)
        new_views.append(np.einsum(term + '->' + new_term, view) )

    eq = ','.join(new_terms) + '->'

    print(oe.contract_path(eq, *new_views))

binary_tree_einsum(3)
# Optimized scaling:  2

I'm pretty certain opt_einsum will never create new leaves. You can imagine what it does as an embedding of the hypergraph describing the original graph into a tree, such that once all edges have 'met' they annihilate. Here's a pic I drew for a paper:

(image: hyper-congestion)

(sorry may not be totally clear, on the left the dashed lines are the contraction order, on the right the graph is deformed into the tree, with the light grey nodes the intermediates).

The problem at the moment is that the path finders assume that all edges appear more than once, so they don't 'annihilate' any leaf indices at the beginning, which they should. But after that, if ever the final two indices meet they are indeed contracted so there will never be a leftover singleton.

yaroslavvb commented 4 years ago

Ah nice! That solution seems to work.

PS: I was curious to check if the optimizer will perform well for graphs with bounded treewidth, but unbounded pathwidth, but being limited to ascii makes it a bit hard to generate such graphs programmatically

jcmgray commented 4 years ago

Yes I'm not really sure what that might look like to be honest! There is a general result by Markov & Shi linking the asymptotic scaling to the treewidth of the line graph, and whilst that is an optimal result, practically speaking when this is bounded the graphs are also easier with heuristic methods.

I might mention that there are actually several different optimisers, with the default auto mode selecting one based on size of the contraction. The dynamic-programming optimizer (optimize='dp') is virtually optimal and searches through connected subgraphs, so it can address pretty large contractions (~ many 10s) when the underlying graph is structured - e.g. planar or tree-like. These are also the types of graphs that the 'random-greedy' optimizer performs well on, and that has no size limit really. Finally, I'm actually working on some very high quality contractors for large complex graphs that will be compatible with opt_einsum but these are not public quite yet...

yaroslavvb commented 4 years ago

BTW, duplicating each factor seemed like an easier workaround; however, it doesn't always fix the scaling problem.

It recovers n^2 scaling for a tree with 8 elements, but still gives me n^3 scaling for a tree with 16 elements. Is this an issue of suboptimal greedy optimizer kicking in?

import opt_einsum as oe
import numpy as np

def binary_tree_einsum(depth):
    edges = []

    def tc(num):
        return chr(num+100000) if chr(num) in ' ,->.' else chr(num)

    def rec(parent, child, depth):
        # Each edge is appended twice, i.e. every factor is duplicated.
        edges.append(tc(parent)+tc(child))
        edges.append(tc(parent)+tc(child))
        if depth > 0:
            rec(child, 2*child, depth-1)
            rec(child, 2*child+1, depth-1)

    rec(0, 1, depth)
    print(oe.contract_path(','.join(edges)+'->', *[np.ones((2,2))]*len(edges)))
binary_tree_einsum(2) # O(n^2)
binary_tree_einsum(3) # O(n^3)

Tree-like structures should be easy to discover -- greedy triangulation with minfill heuristic should just work.

If I add even more edges, it drops back to O(n^2) scaling. ❓

import opt_einsum as oe
import numpy as np

def binary_tree_einsum(depth):
    edges = []

    def tc(num):
        return chr(num+100000) if chr(num) in ' ,->.' else chr(num)

    def rec(parent, child, depth):
        # Duplicated edge, plus a duplicated edge to an extra leaf index.
        edges.append(tc(parent)+tc(child))
        edges.append(tc(parent)+tc(child))
        edges.append(tc(parent)+tc(child+10000))
        edges.append(tc(parent)+tc(child+10000))
        if depth > 0:
            rec(child, 2*child, depth-1)
            rec(child, 2*child+1, depth-1)

    rec(0, 1, depth)
    print(oe.contract_path(','.join(edges)+'->', *[np.ones((2,2))]*len(edges)))
binary_tree_einsum(3) # O(n^2)

BTW, results like Markov & Shi's seem to come up in many places. Small treewidth is the most basic condition to guarantee fast computation. A more general condition is for the problem to reduce to a computation on a "minor-excluded" class of graphs (Ch. 17 of Grohe's "Descriptive Complexity, Canonisation, and Definable Graph Structure Theory" book). For instance, counting perfect matchings is an einsum which is computable in polynomial time for planar graphs using the FKT algorithm.

Personally I'm interested in ways of computing large einsums approximately, since things are already inexact due to measurement noise and floating point round-off. For a minor-excluded class with bounded degree there's a polynomial time approximation algorithm by Jung, Shah. A related simpler heuristic is the Generalized distributive law. Basically you reformulate the einsum in terms of equations which give the exact result after k updates when the factor graph is a tree. When it is not a tree, you update n*k times, and get a good result for small n when edge interactions are not "too strong". This lets you deal with problems that have high treewidth either due to the original einsum structure or due to large factors. The latter case can be handled by approximating large factors as products of smaller factors.

In statistical physics, "Generalized distributive law" comes up in approximating Ising free energy. When factors are restricted to be pairwise edge potentials, this algorithm gives what is known as the Bethe-Peierls approximation. When nearby vertices are merged into larger factors, this algorithm gives higher quality Kikuchi approximation.

jcmgray commented 4 years ago

> Is this an issue of suboptimal greedy optimizer kicking in?

Yes, if you try the optimize='random-greedy' or 'auto-hq' option for example, you get the n^2 scaling again. Note that scaling might not be the best indicator of performance, and at least in this case the n^2 and n^3 paths have pretty similar estimated FLOP counts.

Interestingly, the essentially optimal 'dp' optimizer can also be used here for pretty large sizes (binary_tree_einsum(7) with 510 terms takes ~ 8sec), but only if you supply it with large dimensions like (1000, 1000) - which I guess allows it to home in on the asymptotic case.

To be honest I am really not that familiar with the wider graph theory literature and classical applications of what I might call hyper tensor networks. In many-body quantum, the approximate contraction is always based on how much 'entanglement' is in the network - essentially whether the tensors are approximately low-rank across certain partitions. And for quantum circuit simulation, essentially nothing is low-rank so the contractions are performed exactly.

In that exact case, the two following statements seem to be most practically relevant:

Anyway, I dunno if you are planning on working on any of the things you mention, but I'd certainly be interested to see any results, and especially whether there are other approximation schemes that are relevant to classical quantum simulation.

yaroslavvb commented 4 years ago

Aha, optimize='dp' seems appropriate here. I've been using Carl Woll's "TensorSimplify" package to come up with tractable formulas for various neural network-related quantities, but opt_einsum seems like a more flexible tool.

Note that using optimize='dp' fixes order for binary tree, but it's still suboptimal for 'abcd,aefg->a'

einsum_string = 'abcd,aefg->a'
views = oe.helpers.build_views(einsum_string, {'a': 10, 'b':10, 'c': 10, 'd': 10, 'e':  10, 'f': 10, 'g': 10})
path, path_info = oe.contract_path(einsum_string, *views, optimize='dp')
print(path_info) # O(n^7)

However, adding an "all ones" 1-d factor for every dimension seems to recover good scaling in all the instances I found to be suboptimal: a,b,c,d,e,f,g,abcd,aefg->a
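As a sanity check on the workaround above, this NumPy sketch confirms that padding the expression with all-ones 1-d factors leaves the result unchanged. The dimension size is an arbitrary assumption, and the sketch only checks values; it does not show what path any optimizer would pick.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 3  # arbitrary small dimension, for illustration only
X = rng.standard_normal((d,) * 4)  # operand with indices 'abcd'
Y = rng.standard_normal((d,) * 4)  # operand with indices 'aefg'

# One all-ones vector per index a..g.
ones = [np.ones(d) for _ in 'abcdefg']

plain = np.einsum('abcd,aefg->a', X, Y)
padded = np.einsum('a,b,c,d,e,f,g,abcd,aefg->a', *ones, X, Y)

print(np.allclose(plain, padded))
```

Multiplying by ones changes nothing numerically, but it turns a two-operand expression into a nine-operand one, so the path finders get a chance to run.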

dgasmith commented 4 years ago

Can we think of edge cases here that #114 will not fix the scaling for?

yaroslavvb commented 4 years ago

It seems to work for all my examples... but I'm curious why preprocessing is even needed when using an exact algorithm; shouldn't optimize='dp' handle these cases automatically?

jcmgray commented 4 years ago

So 'dp' indeed does this already, but if you only give opt_einsum two arguments - like 'abcd,aefg->a' - it assumes there is no optimization to do, and simply returns the path [(0, 1)] without calling any optimizer. Also, the other optimizers don't do this already, so might be useful to add it to them.

yaroslavvb commented 2 years ago

@jcmgray

> @jcmgray Good point there, this could fall under the "things we should always check for" like the Hadamard product issue that is coming up.
>
> This would be cheap collapse = set('nlq') - set(all_other_extant_indices); if collapse: ....
>
> Hey in 3-4 years we can use the walrus operator here :)

BTW, it is now 2022; Python 3.8 has walrus operator support and is moderately available :)
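For what it's worth, here is a hedged sketch of how the quoted check might look with the walrus operator (Python 3.8+). The function name `collapsible_indices` is hypothetical and not part of opt_einsum; it only detects the indices that could be summed out of each term on its own, without performing any contraction.

```python
def collapsible_indices(inputs, output):
    """Map each term position to the indices it could sum out on its own.

    An index qualifies if it appears in that term only: not in any other
    input term and not in the output.
    """
    found = {}
    for pos, term in enumerate(inputs):
        # Every index used by the output or by any other input term.
        others = set(output).union(*inputs[:pos], *inputs[pos + 1:])
        if collapse := set(term) - others:
            found[pos] = collapse
    return found

print(collapsible_indices(['nlp', 'nlq'], 'l'))
```

For the expression that started this thread, this flags p on the first term and q on the second, which is exactly the pair of reductions 'nlp->nl' and 'nlq->nl' suggested earlier.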