dgasmith / opt_einsum

⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.
https://dgasmith.github.io/opt_einsum/
MIT License

A failure of raw FLOP count optimization #103

Open dgasmith opened 5 years ago

dgasmith commented 5 years ago

A nice example of where optimal performs worse than greedy in practice: optimal (9s), greedy (1.1s), optimal with a 10**7 memory limit (0.8s).

import numpy as np

x = np.random.randn(100, 300, 75, 10)
w = np.random.randn(100, 300, 3)
At = np.random.randn(8, 10)
G = np.random.randn(10, 3)
Bt = np.random.randn(10, 10)

np.einsum("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="optimal")
np.einsum("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="greedy")
np.einsum("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize=("optimal", 10**7))

Copied my response here:

import opt_einsum as oe
oe.contract_path("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="optimal")[1]

 Complete contraction:  mn,nr,fcr,nh,octh->oftm
         Naive scaling:  8
     Optimized scaling:  5
      Naive FLOP count:  2.700e+12
  Optimized FLOP count:  5.072e+9
   Theoretical speedup:  532.355
  Largest intermediate:  2.250e+7 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   4           GEMM            fcr,nr->fcn                  mn,nh,octh,fcn->oftm
   5           GEMM          octh,nh->octn                     mn,fcn,octn->oftm
   5              0         octn,fcn->otnf                         mn,otnf->oftm
   5           TDOT          otnf,mn->oftm                            oftm->oftm
oe.contract_path("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="greedy")[1]

  Complete contraction:  mn,nr,fcr,nh,octh->oftm
         Naive scaling:  8
     Optimized scaling:  6
      Naive FLOP count:  2.700e+12
  Optimized FLOP count:  1.412e+10
   Theoretical speedup:  191.286
  Largest intermediate:  2.250e+7 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   6           TDOT        octh,fcr->othfr                  mn,nr,nh,othfr->oftm
   6           TDOT        othfr,nh->otfrn                     mn,nr,otfrn->oftm
   5              0         otfrn,nr->otfn                         mn,otfn->oftm
   5           GEMM          otfn,mn->oftm                            oftm->oftm
oe.contract_path("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="optimal", memory_limit=int(10**7))[1]

  Complete contraction:  mn,nr,fcr,nh,octh->oftm
         Naive scaling:  8
     Optimized scaling:  6
      Naive FLOP count:  2.700e+12
  Optimized FLOP count:  3.601e+10
   Theoretical speedup:  74.970
  Largest intermediate:  6.000e+6 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   3              0             nr,mn->nrm                 fcr,nh,octh,nrm->oftm
   4           GEMM            nrm,nh->rmh                    fcr,octh,rmh->oftm
   5           GEMM          rmh,fcr->mhfc                       octh,mhfc->oftm
   6           TDOT        mhfc,octh->oftm                            oftm->oftm

Any step where the BLAS column reads 0 falls back to np.einsum (for a NumPy einsum backend).

With optimal and no memory limit, your most expensive contraction (by far), the one involving the oc indices, falls back to an einsum operation. For every other contraction path, the most expensive operations are handled by GEMM (best, no tensor copies) or TDOT (requires tensor copies).
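
To check which steps fall back to einsum programmatically, you can inspect the contraction list attached to the returned path info; a rough sketch, assuming the tuple layout used by recent opt_einsum versions, where the last entry of each contraction tuple is the BLAS flag:

import opt_einsum as oe

path, info = oe.contract_path("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="optimal")
for contraction in info.contraction_list:
    einsum_str, blas_flag = contraction[2], contraction[-1]
    # blas_flag is e.g. "GEMM" or "TDOT", or falsy when the step is a plain einsum call
    print(einsum_str, "->", blas_flag or "einsum")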

We have looked at adding additional logic like so:

This gets a bit tricky as the heuristics become very blurry for several reasons:

Our original pass at smarter heuristics for einsum was not met with much enthusiasm, as the use cases for einsum are incredibly diverse. If we could define a 99% case we could likely optimize for it, but so far we haven't been successful in describing the bounds of those cases.

Original issue here: https://github.com/numpy/numpy/issues/14332

shoyer commented 5 years ago

I wonder if the right way to handle this for now is to add a low level API for indicating costs, either in terms of a multiple of the original cost or with custom logic for particular shapes.

Google TPUs also have dedicated matrix multiplication units (like NVIDIA's tensor cores), and I suspect we will see even more hardware like this in the future. The logic gets pretty specialized to particular platforms, so I think it would be difficult to handle all the options ahead of time.

dgasmith commented 5 years ago

Would it be better to allow a scaling factor for different operational types? I think only the calling function would know if the operation could fit specific TPU requirements.

costs = {"EINSUM": 1.25,
         "GEMM": 0.5 / num_threads,
         "TDOT: 0.5 / num_threads,
         "TDOT_TRANSPOSE": 0.75}
oe.contract_path(einsum_string, *views, costs=costs)

This would allow the cost function to be simple: total_cost += costs.get(ctype, 1.0) * contraction_cost
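
Purely as a sketch of that scoring idea (the numbers and num_threads below are made up; nothing here is an existing opt_einsum API):

def weighted_cost(steps, costs, default=1.0):
    # steps: iterable of (kernel_type, raw_flop_count) pairs for one contraction path
    return sum(costs.get(ctype, default) * flops for ctype, flops in steps)

num_threads = 4  # illustrative
costs = {"EINSUM": 1.25, "GEMM": 0.5 / num_threads, "TDOT": 0.5 / num_threads}

# e.g. compare a GEMM-heavy path against one that falls back to einsum
print(weighted_cost([("GEMM", 2.0e9), ("GEMM", 3.0e9)], costs))
print(weighted_cost([("EINSUM", 2.0e9), ("TDOT", 3.0e9)], costs))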

jcmgray commented 5 years ago

My preference would be to make the default 'cost' calculated as simple as possible and just the number of operations, i.e.:

sum(
    compute_size_by_dict(indices_involved, size_dict)
    for indices_involved in contractions
)

and thus, at least as a baseline, ignore the current modifiers in the current flop_count for whether it is an inner product or how many terms are involved. This is the cost that the 'dp' optimizer minimizes and is used elsewhere in tensor network literature. Also just in practice I find it is the best estimator!
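
For example, on the contraction above (assuming an opt_einsum version that ships the 'dp' optimizer):

path, info = oe.contract_path("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="dp")
print(info)  # prints the same summary table as above, including the scaling and BLAS columns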

Another motivation is that for e.g. complex data types, addition is 2 FLOPs whilst multiplication is 6, and there is likely other instruction-set-optimized stuff going on, so just doubling the count for an inner product isn't necessarily natural!

Maybe one could then have a separate, more advanced FLOP/cost estimator that takes into account the nature of each contraction and other customizable factors like you mention. This would only really help to understand the cost of a contraction once it is found; supporting a custom cost in all the current path finders might otherwise be a lot of work.

dgasmith commented 5 years ago

It's a good point on the inner product these days. When this project was first started and the FLOP code was written, FMA was pretty bleeding edge and not generally available. The first Skylake Intel CPUs came out that year (AMD had it in 2013, but it wasn't very common), and FMA/AVX has since propagated to most hardware, so that choice is now fairly wrong.

Long way of saying that this does need an overhaul. The other question to answer first is "does this matter"? We can think about several regimes:

All of these use cases require scaling to be injected into the path finding itself, and the logic overhead would slow the algorithms down quite a bit as well. It may be worth hacking in something exploratory to check whether the above matters; this could be as simple as being able to supply your own FLOP cost function, along the lines of the sketch below.
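
A hypothetical version of that hook, just to make the idea concrete (no such cost_fn argument exists today; the name and signature are illustrative only):

def my_cost(input_index_sets, output_indices, size_dict):
    # Hypothetical callback: score one candidate contraction by the product of
    # the sizes of every index involved, mirroring the simple count above.
    involved = set(output_indices).union(*input_index_sets)
    cost = 1
    for idx in involved:
        cost *= size_dict[idx]
    return cost

# Hypothetical usage:
# oe.contract_path(eq, *arrays, optimize="optimal", cost_fn=my_cost)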

If we rename FLOPs to OPs, I think we become less clear, as an OP could refer to a SIMD (or similar) instruction. Is there a better way to phrase the current "FLOP" categories?