dgasmith / opt_einsum

⚡️Optimizing einsum functions in NumPy, TensorFlow, Dask, and more with contraction order optimization.
https://dgasmith.github.io/opt_einsum/
MIT License

`Optimal` path not optimal? #99

Closed: Bonnevie closed this issue 4 years ago

Bonnevie commented 5 years ago

First, thanks for the great package!

I was running optimal vs. auto on a 10-element random expression, and while I know optimal is not recommended for larger networks, I found it odd that it gives worse scaling and a larger intermediate size. Some of the other threads mention memory as a factor in the optimization, but the auto solution seems better on memory as well? All the tensors are 3x3. Is this intentional behaviour?
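Roughly how I ran the comparison (a sketch; the arrays are freshly random here, but the path and FLOP counts depend only on the shapes):

```python
import numpy as np
import opt_einsum

# The 10-tensor expression from the outputs below; every operand is 3x3.
eq = "db,cc,fe,fe,aa,ff,fe,cb,ea,ac->d"
arrays = [np.random.rand(3, 3) for _ in range(10)]

for opt in ("auto", "optimal"):
    path, info = opt_einsum.contract_path(eq, *arrays, optimize=opt)
    print(f"--- {opt} path ---")
    print(path, info)
```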

auto path:

([(2, 3), (4, 8), (0, 4), (0, 4), (4, 5), (0, 2), (2, 3), (1, 2), (0, 1)],
   Complete contraction:  db,cc,fe,fe,aa,ff,fe,cb,ea,ac->d
          Naive scaling:  6
      Optimized scaling:  3
       Naive FLOP count:  7.290e+3
   Optimized FLOP count:  2.700e+2
    Theoretical speedup:  27.000
   Largest intermediate:  9.000e+0 elements
 --------------------------------------------------------------------------------
 scaling        BLAS                current                             remaining
 --------------------------------------------------------------------------------
    2              0              fe,fe->fe         db,cc,aa,ff,fe,cb,ea,ac,fe->d
    2              0              fe,fe->fe            db,cc,aa,ff,cb,ea,ac,fe->d
    3           GEMM              cb,db->cd               cc,aa,ff,ea,ac,fe,cd->d
    2              0              ac,cc->ac                  aa,ff,ea,fe,cd,ac->d
    3           GEMM              ac,cd->ad                     aa,ff,ea,fe,ad->d
    2              0              ea,aa->ea                        ff,fe,ad,ea->d
    3           GEMM              ea,ad->ed                           ff,fe,ed->d
    3           GEMM              ed,fe->df                              ff,df->d

optimal path:

([(1, 7), (0, 8), (0, 2), (0, 1), (0, 5), (0, 3), (0, 3), (0, 2), (0, 1)],
   Complete contraction:  db,cc,fe,fe,aa,ff,fe,cb,ea,ac->d
          Naive scaling:  6
      Optimized scaling:  4
       Naive FLOP count:  7.290e+3
   Optimized FLOP count:  5.130e+2
    Theoretical speedup:  14.211
   Largest intermediate:  2.700e+1 elements
 --------------------------------------------------------------------------------
 scaling        BLAS                current                             remaining
 --------------------------------------------------------------------------------
    2              0              cb,cc->cb         db,fe,fe,aa,ff,fe,ea,ac,cb->d
    3           GEMM              cb,db->cd            fe,fe,aa,ff,fe,ea,ac,cd->d
    3              0             aa,fe->afe              fe,ff,fe,ea,ac,cd,afe->d
    2              0              ff,fe->fe                 fe,ea,ac,cd,afe,fe->d
    2              0              fe,fe->fe                    ea,ac,cd,afe,fe->d
    3              0            afe,ea->afe                       ac,cd,fe,afe->d
    4           GEMM            afe,ac->fec                          cd,fe,fec->d
    4           GEMM            fec,cd->fed                             fe,fed->d
    3           GEMM              fed,fe->d                                  d->d) 
jcmgray commented 5 years ago

Yes, this seems like a bug. The contraction has lots of indices repeated more than twice, and I wonder if something funny is going on when calculating the FLOP cost of each contraction (by the way, total FLOP cost is the default target, not scaling or intermediate size).

Some extra info: 'auto' for 10 tensors uses 'branch-2'; if you try 'branch-all' or 'random-greedy', you find an even better contraction.
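E.g. with the same expression (a quick sketch):

```python
import numpy as np
import opt_einsum

eq = "db,cc,fe,fe,aa,ff,fe,cb,ea,ac->d"
arrays = [np.random.rand(3, 3) for _ in range(10)]

# Compare the total FLOP cost found by each path finder.
for opt in ("branch-2", "branch-all", "random-greedy"):
    _, info = opt_einsum.contract_path(eq, *arrays, optimize=opt)
    print(f"{opt:>14}: {info.opt_cost:.4g} flops")
```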

Bonnevie commented 5 years ago

Yeah, I realize this is not a particularly realistic contraction, but I thought I would bring it up in case it's a sign of some deeper bug.

The reported FLOP counts are also lower for "auto"/"branch-2", so does it make sense that the bug would be in the FLOP calculation? You'd think that "optimal" would have come across the 270-FLOP solution during its search.

jcmgray commented 5 years ago

I think I have found the issue: the path finders cache the output indices and flops based on the input indices of the two tensors. This contraction contains tensors that are not uniquely defined by their indices, which seems to confuse things. If you turn off this caching, the optimal search finds the correct path.
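To illustrate, a toy sketch of the failure mode (not the actual opt_einsum code):

```python
# The correct output indices of a pairwise contraction depend on which
# indices still appear in the *remaining* tensors, but this cache key
# only encodes the pair itself.
cache = {}

def pair_output(idx_i, idx_j, remaining):
    key = (idx_i, idx_j)  # context-free key
    if key not in cache:
        keep = {c for c in idx_i + idx_j
                if any(c in r for r in remaining)}
        cache[key] = "".join(sorted(keep))
    return cache[key]

# Another "fe" tensor remains, so both f and e must be kept:
print(pair_output("fe", "fe", remaining=["fe", "db"]))  # -> "ef"
# Now nothing else carries f or e; the right answer is "" (a full
# reduction), but the stale cached entry is returned instead:
print(pair_output("fe", "fe", remaining=["db"]))        # -> "ef"
```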

Need to have a little think about how to resolve this, as I think the caching is quite helpful speed-wise.

dgasmith commented 5 years ago

I think `key = (inputs[i], inputs[j], inputs)` should patch this? We could also consider performing all Hadamard products first in 'optimal'. I am not sure I could come up with a formal proof, but that shouldn't change the FLOP cost.

It's also a good time to start thinking about running a quick greedy path to set an upper FLOP cost limit. For this problem, the optimal search becomes 20% faster using greedy FLOP sieving.
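A rough sketch of the sieve idea using the public API (the `shapes=True` form passes shape tuples instead of arrays):

```python
import opt_einsum

eq = "db,cc,fe,fe,aa,ff,fe,cb,ea,ac->d"
shapes = [(3, 3)] * 10

# A cheap greedy pass yields a FLOP count that the exhaustive search
# could then use as an upper bound to prune branches early.
_, greedy_info = opt_einsum.contract_path(eq, *shapes, shapes=True,
                                          optimize="greedy")
flop_bound = greedy_info.opt_cost  # prune partial paths exceeding this
```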

jcmgray commented 5 years ago

Performing all Hadamard products first might be a nice solution; it certainly seems to be usually optimal, but it might not be in cases like "cc,cc->c"?

For the greedy sieve, my feeling is that this is best left to 'branch-all', which is essentially optimal but starts from the greedy solution and iterates off that. For small contractions, where optimal is most useful and most likely to be used, an explicit greedy call adds considerable overhead.

dgasmith commented 5 years ago

"All Hadamard products that are not the final contraction"?

This goes back to regimes: a greedy call for optimal at n=4 will probably cost quite a bit, but at n=6+ the cost is likely negligible. For this n=10 contraction, the greedy path costs 0.1 ms but saves 2.6 seconds.

jcmgray commented 5 years ago

Yes, I need to think a little more about Hadamards.

Regarding the sieve, my thinking is roughly this:

  1. For small contractions, optimal shouldn't use a sieve, implying there would need to be a different optimal path finder with a sieve (or some automatic switching, which might be OK).
  2. For large contractions, the sieve is helpful, but nowhere near as helpful as moving to 'branch-all' (which for this contraction cuts the time from 28s to 3s) and which is essentially always optimal anyway. Does there really need to be another path finder in between these two?

Maybe an 'auto-hq' option that targets longer times and thus sticks with 'branch-all' for longer might be nice.

dgasmith commented 5 years ago

Yeah, optimal should likely be parameterized for different problem sizes and have an obvious button for "no, really, just test everything without heuristics": "true-optimal" (ugh)? I would still argue that an initial greedy sieve is useful for optimal if someone goes down that path.

I forget, but is there an understanding of how branch-all and optimal connect at their limits? It was my understanding that they did, but your comment seems to suggest otherwise.

I like the auto-hq idea as well.

jcmgray commented 5 years ago

Just returning to the issue at hand: I think performing all de-duplications is the answer here. I haven't been able to find any contraction where this isn't part of the optimal contraction, and adding extra bits to the key that caches the contraction results incurs a really heavy performance hit.
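A sketch of the de-duplication pre-pass I have in mind: group inputs by their exact index string, so that repeated tensors (the three "fe"s here) get merged by elementwise Hadamard products before the search proper:

```python
from collections import OrderedDict

def dedupe_groups(inputs):
    # Map each index string to the positions that carry it, keeping
    # only the strings that appear more than once.
    groups = OrderedDict()
    for pos, idx in enumerate(inputs):
        groups.setdefault(idx, []).append(pos)
    return {idx: ps for idx, ps in groups.items() if len(ps) > 1}

print(dedupe_groups("db,cc,fe,fe,aa,ff,fe,cb,ea,ac".split(",")))
# -> {'fe': [2, 3, 6]}
```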

dgasmith commented 5 years ago

The extra serialization info could be quite heavy for long inputs. Another solution would be to consider something like `key = (inputs[i], inputs[j], id_i, id_j)`. Generating those ids might be costly, however.

shoyer commented 5 years ago

I've encountered a similar situation with different inputs. Here the "greedy" path has fewer FLOPs than the "optimal" path:

In [10]: n = 100

In [11]: opt_einsum.contract_path('abc,da,eb,fc,def,gd,he,if->ghi', (n, n, n), (n, n), (n, n), (n, n), (n, n, n), (n, n), (n, n), (n, n), shapes=True, optimize='optimal')
Out[11]:
([(0, 1), (0, 6), (0, 5), (0, 1), (2, 3), (0, 2), (0, 1)],
   Complete contraction:  abc,da,eb,fc,def,gd,he,if->ghi
          Naive scaling:  9
      Optimized scaling:  4
       Naive FLOP count:  8.000e+18
   Optimized FLOP count:  1.300e+9
    Theoretical speedup:  6153846153.846
   Largest intermediate:  1.000e+8 elements
 --------------------------------------------------------------------------------
 scaling        BLAS                current                             remaining
 --------------------------------------------------------------------------------
    4           GEMM            da,abc->dbc           eb,fc,def,gd,he,if,dbc->ghi
    4           TDOT            dbc,eb->dce              fc,def,gd,he,if,dce->ghi
    4           TDOT            dce,fc->def                 def,gd,he,if,def->ghi
    4              0           gd,def->gdef                   he,if,def,gdef->ghi
    4              0          gdef,def->gef                        he,if,gef->ghi
    4           TDOT            gef,he->gfh                           if,gfh->ghi
    4           TDOT            gfh,if->ghi                              ghi->ghi)

In [12]: opt_einsum.contract_path('abc,da,eb,fc,def,gd,he,if->ghi', (n, n, n), (n, n), (n, n), (n, n), (n, n, n), (n, n), (n, n), (n, n), shapes=True, optimize='greedy')
Out[12]:
([(0, 1), (0, 6), (0, 5), (0, 4), (0, 3), (0, 2), (0, 1)],
   Complete contraction:  abc,da,eb,fc,def,gd,he,if->ghi
          Naive scaling:  9
      Optimized scaling:  4
       Naive FLOP count:  8.000e+18
   Optimized FLOP count:  1.201e+9
    Theoretical speedup:  6661115736.886
   Largest intermediate:  1.000e+6 elements
 --------------------------------------------------------------------------------
 scaling        BLAS                current                             remaining
 --------------------------------------------------------------------------------
    4           GEMM            da,abc->dbc           eb,fc,def,gd,he,if,dbc->ghi
    4           TDOT            dbc,eb->dce              fc,def,gd,he,if,dce->ghi
    4           TDOT            dce,fc->def                 def,gd,he,if,def->ghi
    3              0           def,def->def                     gd,he,if,def->ghi
    4           GEMM            def,gd->efg                        he,if,efg->ghi
    4           GEMM            efg,he->fgh                           if,fgh->ghi
    4           GEMM            fgh,if->ghi                              ghi->ghi)
dgasmith commented 5 years ago

Thanks for the additional example here. It's interesting that optimal avoids this Hadamard product and builds a larger intermediate because of it. This could be related to the above issue, and it strengthens the argument that we should greedily perform Hadamard products in the optimal algorithms.

@jcmgray Did you settle on which option you prefer: performing all Hadamards greedily during the optimal paths, or increasing the key lengths?

jcmgray commented 5 years ago

Yes, here again it looks like the problem is that an intermediate has the same indices as an initial tensor.

A couple of options I think:

  1. Cache based on the frozenset union of the input positions, rather than the indices.
  2. Remove the current optimal approach and point it at the 'dp' optimizer (see the sketch below)! This would require some tweaks to make 'dp' as fast for small sizes, and maybe a switch allowing consideration of outer products for an absolutely optimal path, but it might be preferable overall (it's certainly a more efficient algorithm).
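For reference, a quick usage sketch of option 2 on shoyer's example (assuming the 'dp' preset is hooked up to `contract_path`):

```python
import opt_einsum

n = 100
eq = "abc,da,eb,fc,def,gd,he,if->ghi"
shapes = [(n, n, n), (n, n), (n, n), (n, n),
          (n, n, n), (n, n), (n, n), (n, n)]

# Dynamic-programming path finder instead of the exhaustive search.
path, info = opt_einsum.contract_path(eq, *shapes, shapes=True,
                                      optimize="dp")
print(info.opt_cost, info.largest_intermediate)
```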