dgasmith / opt_einsum

⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.
https://dgasmith.github.io/opt_einsum/
MIT License

optimal contraction path #6

Closed: jjren closed this issue 7 years ago

jjren commented 7 years ago

I tried to use opt_einsum to calculate the following tensor contraction, but the contraction path seems non-optimal.

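import opt_einsum
# A, B, C, D, E, F are the input arrays (their construction isn't shown here)
res1 = opt_einsum.contract("abc,bdef,fghj,cem,mhk,ljk->adgl", A, B, C, D, E, F, path='optimal')
pathinfo = opt_einsum.contract_path("abc,bdef,fghj,cem,mhk,ljk->adgl", A, B, C, D, E, F, path="optimal")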
The printed "optimal" path is:

[(3, 4), (0, 1, 2, 3, 4)]
Complete contraction:  abc,bdef,fghj,cem,mhk,ljk->adgl
Naive scaling:  12
Optimized scaling:  11
Naive FLOP count:  3.075e+09
Optimized FLOP count:  2.050e+08
Theoretical speedup:  15.000
Largest intermediate:  6.250e+02 elements
--------------------------------------------------------------------------------
scaling   BLAS                  current                                remaining
--------------------------------------------------------------------------------
   5      GEMM            mhk,cem->ckeh             abc,bdef,fghj,ljk,ckeh->adgl
  11     False ckeh,ljk,fghj,bdef,abc->adgl                               adgl->adgl

My question is why the second contraction isn't done pairwise. I also contracted it pairwise with tensordot, and that is much faster. My guess is that, although the total FLOP count of the pairwise tensordot approach is larger, it runs much faster.
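The pairwise tensordot code itself isn't shown in the thread. As a minimal sketch of the idea, assuming NumPy and small random tensors in which every index has dimension 3, the contraction can be split into two-operand steps (the first with np.tensordot, the rest with two-operand np.einsum calls) and checked against a single six-operand np.einsum call. The order used here is one valid pairwise order, the same one the fixed optimal path reports later in the thread.

```python
import numpy as np

# Assumption: small random tensors, every index has dimension 3
rng = np.random.default_rng(0)
d = 3
A = rng.random((d,) * 3)  # abc
B = rng.random((d,) * 4)  # bdef
C = rng.random((d,) * 4)  # fghj
D = rng.random((d,) * 3)  # cem
E = rng.random((d,) * 3)  # mhk
F = rng.random((d,) * 3)  # ljk

# One-shot contraction: a single loop over all twelve indices
full = np.einsum("abc,bdef,fghj,cem,mhk,ljk->adgl", A, B, C, D, E, F)

# Pairwise contraction: each step touches only two tensors
abem = np.tensordot(A, D, axes=([2], [0]))        # sum over c -> axes (a, b, e, m)
afdm = np.einsum("abem,bdef->afdm", abem, B)       # sum over b, e
glfkh = np.einsum("ljk,fghj->glfkh", F, C)         # sum over j
glfm = np.einsum("glfkh,mhk->glfm", glfkh, E)      # sum over h, k
pairwise = np.einsum("glfm,afdm->adgl", glfm, afdm)  # sum over f, m

print(np.allclose(full, pairwise))  # True
```

Each step only materializes a tensor with at most five indices (glfkh), instead of looping over all twelve indices at once.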

dgasmith commented 7 years ago

What are the sizes of the indices? At first glance, it looks like it's hitting an outer product that it doesn't want to build because of the memory overhead. The algorithms are tuned to keep the memory overhead within the size of the incoming tensors.

jjren commented 7 years ago

I just noticed the memory limitation. The index sizes are roughly 10 to 20. Can I unset this memory limitation? Memory is not a big problem in my case.

dgasmith commented 7 years ago

Yes. It looks like I forgot to document memory_limit: it is the maximum size of any temporary (intermediate) tensor that will be built.
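To make the memory_limit idea concrete, here is a minimal pure-Python sketch, not opt_einsum's code, that computes the size of the tensor produced by one pairwise step and checks it against a cap. The helpers intermediate_size and fits are hypothetical, as is the convention they follow (None caps at the largest input tensor, -1 removes the cap), which is assumed from this thread; every index is also assumed to have dimension 10.

```python
from math import prod

def intermediate_size(a, b, remaining, output, size):
    """Hypothetical helper: elements in the tensor produced by contracting terms a and b.

    An index survives only if a remaining term or the final output still uses it.
    """
    needed = set("".join(remaining)) | set(output)
    kept = (set(a) | set(b)) & needed
    return prod(size[c] for c in kept)

def fits(a, b, remaining, output, size, inputs, memory_limit=None):
    """Hypothetical check. Assumed convention: None caps at the largest input, -1 means no cap."""
    if memory_limit == -1:
        return True
    if memory_limit is None:
        memory_limit = max(prod(size[c] for c in term) for term in inputs)
    return intermediate_size(a, b, remaining, output, size) <= memory_limit

inputs = ["abc", "bdef", "fghj", "cem", "mhk", "ljk"]
size = {c: 10 for c in "abcdefghjklm"}  # assumption: every index has dimension 10

print(fits("cem", "mhk", ["abc", "bdef", "fghj", "ljk"], "adgl", size, inputs))  # True
print(fits("ljk", "fghj", ["mhk", "afdm"], "adgl", size, inputs))  # False
print(fits("ljk", "fghj", ["mhk", "afdm"], "adgl", size, inputs, memory_limit=-1))  # True
```

Under these assumptions, cem with mhk yields a four-index tensor no larger than the biggest input, while ljk with fghj yields the five-index glfkh, which only passes once the cap is lifted. With index sizes of 20, such an intermediate has 20**5 = 3,200,000 elements, about 25.6 MB in float64.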

jjren commented 7 years ago

Is there any way to make memory_limit unlimited?

jjren commented 7 years ago

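It seems that in the contract_path function, path_type = kwargs.pop('path_type', 'greedy') should be replaced by path_type = kwargs.pop('path', 'greedy'). As written, 'path' passes the keyword check but is never read, so path_type always falls back to 'greedy':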

# Make sure all keywords are valid
valid_contract_kwargs = ['path', 'memory_limit', 'einsum_call', 'use_blas']
unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_contract_kwargs]
if len(unknown_kwargs):
    raise TypeError("einsum_path: Did not understand the following kwargs: %s" % unknown_kwargs)

path_type = kwargs.pop('path_type', 'greedy')  # bug: 'path' is validated above but never read here
memory_limit = kwargs.pop('memory_limit', None)
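A minimal sketch of the proposed fix, written as a hypothetical standalone helper (parse_path_kwargs is not part of opt_einsum), validates the keywords against the same list as the snippet and then reads the key that was actually validated:

```python
VALID_CONTRACT_KWARGS = ["path", "memory_limit", "einsum_call", "use_blas"]

def parse_path_kwargs(**kwargs):
    """Hypothetical helper: validate kwargs, then read 'path' (not 'path_type')."""
    unknown_kwargs = [k for k in kwargs if k not in VALID_CONTRACT_KWARGS]
    if unknown_kwargs:
        raise TypeError("einsum_path: Did not understand the following kwargs: %s" % unknown_kwargs)
    path_type = kwargs.pop("path", "greedy")  # same key the check above accepted
    memory_limit = kwargs.pop("memory_limit", None)
    return path_type, memory_limit

print(parse_path_kwargs(path="optimal", memory_limit=-1))  # ('optimal', -1)
print(parse_path_kwargs())  # ('greedy', None)
```

With the original pop('path_type', ...) line, the first call would have returned 'greedy' even though 'optimal' was requested.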
dgasmith commented 7 years ago

Hmm, thanks for bringing this up. Fixed in the master branch:

>>> import opt_einsum
>>> expression = "abc,bdef,fghj,cem,mhk,ljk->adgl"
>>> views = opt_einsum.helpers.build_views(expression)
>>> pathinfo = opt_einsum.contract_path(expression, *views, path="optimal", memory_limit=-1)
>>> print(pathinfo[0])
[(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)]

>>> print(pathinfo[1])
  Complete contraction:  abc,bdef,fghj,cem,mhk,ljk->adgl
         Naive scaling:  12
     Optimized scaling:  6
      Naive FLOP count:  2.074e+07
  Optimized FLOP count:  9.648e+03
   Theoretical speedup:  2149.254
  Largest intermediate:  2.160e+02 elements
--------------------------------------------------------------------------------
scaling   BLAS                  current                                remaining
--------------------------------------------------------------------------------
   5      GEMM            cem,abc->abem             bdef,fghj,mhk,ljk,abem->adgl
   6      TDOT          abem,bdef->afdm                  fghj,mhk,ljk,afdm->adgl
   6      TDOT          ljk,fghj->glfkh                     mhk,afdm,glfkh->adgl
   6      TDOT          glfkh,mhk->glfm                          afdm,glfm->adgl
   6      TDOT          glfm,afdm->adgl                               adgl->adgl
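The path lists in this thread are sequences of position tuples into a shrinking operand list. A minimal pure-Python sketch (walk_path is a hypothetical helper, not opt_einsum's implementation) replays both paths, assuming, as the printouts suggest, that the contracted operands are removed and their result is appended at the end. For each step it reports the indices kept, the number of distinct indices touched (the "scaling" column), and a crude cost equal to the product of those dimensions, with every index assumed to have dimension 10; this cost is not opt_einsum's exact FLOP formula.

```python
from math import prod

def walk_path(inputs, output, path, size):
    """Hypothetical helper: replay a path of position tuples over a list of index strings."""
    terms = list(inputs)
    steps = []
    for positions in path:
        picked = [terms[i] for i in positions]
        for i in sorted(positions, reverse=True):  # pop from the back so positions stay valid
            terms.pop(i)
        involved = set("".join(picked))
        # Keep an index only if a remaining term or the output still needs it
        needed = set("".join(terms)) | set(output)
        kept = "".join(sorted(involved & needed))
        cost = prod(size[c] for c in involved)  # crude cost: product of touched dimensions
        steps.append((picked, kept, len(involved), cost))
        terms.append(kept)  # the new intermediate goes to the end of the list
    return steps

inputs = "abc,bdef,fghj,cem,mhk,ljk".split(",")
size = {c: 10 for c in "abcdefghjklm"}  # assumption: every index has dimension 10

for name, path in [("memory-limited", [(3, 4), (0, 1, 2, 3, 4)]),
                   ("unlimited", [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)])]:
    steps = walk_path(inputs, "adgl", path, size)
    print(name, "max scaling:", max(s[2] for s in steps), "cost:", sum(s[3] for s in steps))
# memory-limited max scaling: 11 cost: 100000100000
# unlimited max scaling: 6 cost: 4100000
```

The replay reproduces the optimized scalings printed in the thread (11 for the memory-limited path, 6 once the limit is lifted) and the same intermediates up to index order, since the sketch sorts the kept indices alphabetically.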
jjren commented 7 years ago

That's great, thanks!