dgasmith / opt_einsum

⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.
https://dgasmith.github.io/opt_einsum/
MIT License

Could `opt_einsum` understand repeated inputs? #194

Open romanngg opened 2 years ago

romanngg commented 2 years ago

Hello and thank you for the great library!

I'm curious whether `opt_einsum` could be generalized to let the user specify which inputs are identical, and then use that information to produce a more efficient contraction.

Example:

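```python
import numpy as np
import opt_einsum

A = np.random.normal(size=(3, 2))
B = np.random.normal(size=(2, 2))

def f(A, B):
    # Two matrix products: AB is computed once and reused.
    AB = A @ B
    return AB @ AB.T

def f_einsum(A, B):
    # Same result as f, written as a single four-operand einsum.
    return np.einsum('ij,jk,lk,zl->iz', A, B, B, A, optimize='optimal')

# Returns a (path, PathInfo) tuple; its display is shown below.
opt_einsum.contract_path('ij,jk,lk,zl->iz', A, B, B, A, optimize='optimal')
```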

gives

```
([(1, 2), (0, 2), (0, 1)],   Complete contraction:  ij,jk,lk,zl->iz
         Naive scaling:  5
     Optimized scaling:  3
      Naive FLOP count:  2.880e+2
  Optimized FLOP count:  7.600e+1
   Theoretical speedup:  3.789e+0
  Largest intermediate:  9.000e+0 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   3           GEMM              lk,jk->lj                          ij,zl,lj->iz
   3           GEMM              lj,ij->li                             zl,li->iz
   3           GEMM              li,zl->iz                                iz->iz)
```

That is, the optimizer performs three contractions, whereas `f` needs only two because it computes `A @ B` once and reuses it (and in this case evaluating `f` is indeed faster than evaluating `f_einsum`). Would it be feasible to accept a list of input identifiers (here `[0, 1, 1, 0]`) and use it to compute the contraction faster? Thanks for considering this!
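A rough FLOP comparison of the two strategies, as a sketch: it assumes a simple per-step cost model (size of the combined index space of a pairwise contraction, doubled when at least one index is summed), which happens to reproduce the optimized FLOP count of 7.600e+1 reported above. The helper `pairwise_cost` is made up for illustration and is not part of opt_einsum.

```python
from math import prod

# Index sizes for A (3 x 2) and B (2 x 2) in 'ij,jk,lk,zl->iz'.
sizes = {'i': 3, 'j': 2, 'k': 2, 'l': 2, 'z': 3}


def pairwise_cost(lhs, rhs, out):
    """Assumed cost model: size of the combined index space, doubled when
    at least one index is summed away (a multiply plus an add)."""
    indices = set(lhs) | set(rhs)
    summed = bool(indices - set(out))
    return prod(sizes[c] for c in indices) * (2 if summed else 1)


# The three steps of the path opt_einsum reported.
optimal_steps = [('lk', 'jk', 'lj'), ('lj', 'ij', 'li'), ('li', 'zl', 'iz')]
# Two steps that reuse AB = A @ B, as f does (AB[z, k] is the second copy).
reuse_steps = [('ij', 'jk', 'ik'), ('ik', 'zk', 'iz')]

optimal_total = sum(pairwise_cost(*step) for step in optimal_steps)
reuse_total = sum(pairwise_cost(*step) for step in reuse_steps)
print(optimal_total, reuse_total)  # 76 60
```

Under this assumed model, reusing `AB` saves the cost of computing the product a second time (60 versus 76 FLOPs), which is consistent with `f` being faster here.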
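One possible way an identifier list like `[0, 1, 1, 0]` could be used, sketched under stated assumptions: canonicalize every pairwise contraction by the identifiers of its two inputs plus a relabeling of its indices (in order of first appearance), then group pairs whose keys match, since those produce the same intermediate. The helpers `pair_key` and `find_shared_pairs` are hypothetical, not opt_einsum API; the sketch assumes explicit `->` subscripts without ellipses and does not check whether pairs in a group overlap.

```python
from itertools import combinations


def pair_key(ids, terms, output, p, q):
    """Hypothetical helper: a key identifying the intermediate produced by
    contracting operand p with operand q, independent of index letters."""
    # Indices still needed later: those in the output or in any other operand.
    needed = set(output)
    for r, term in enumerate(terms):
        if r not in (p, q):
            needed.update(term)
    # Relabel indices in order of first appearance, so that 'ij,jk' and
    # 'zl,lk' both become 'ab,bc'.
    relabel = {}
    for c in terms[p] + terms[q]:
        relabel.setdefault(c, chr(ord('a') + len(relabel)))
    kept = ''.join(sorted(relabel[c] for c in relabel if c in needed))
    return (ids[p], ids[q],
            ''.join(relabel[c] for c in terms[p]),
            ''.join(relabel[c] for c in terms[q]),
            kept)


def find_shared_pairs(subscripts, ids):
    """Hypothetical helper: group operand pairs whose pairwise contraction
    yields the same intermediate, given which inputs are identical."""
    inputs, output = subscripts.split('->')
    terms = inputs.split(',')
    groups = {}
    for p, q in combinations(range(len(terms)), 2):
        # Try both operand orders and keep the smaller key, so that an
        # (A, B) pair matches a (B, A) pair.
        key = min(pair_key(ids, terms, output, p, q),
                  pair_key(ids, terms, output, q, p))
        groups.setdefault(key, []).append((p, q))
    return [pairs for pairs in groups.values() if len(pairs) > 1]


shared = find_shared_pairs('ij,jk,lk,zl->iz', [0, 1, 1, 0])
print(shared)  # [[(0, 1), (2, 3)], [(0, 2), (1, 3)]]
```

The first group is the shared `A @ B` product (operands 0 and 1, and operands 3 and 2). The second is an outer product of `A` and `B`, a genuine but unprofitable match, so a practical version would still need the path optimizer's cost model to decide which shared intermediates are worth computing once.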