dgasmith / opt_einsum

⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.
https://dgasmith.github.io/opt_einsum/
MIT License

Zero-dimension tensors (treated as zero entries) make opt_einsum slower than einsum #198

Open Geositta2000 opened 2 years ago

Geositta2000 commented 2 years ago

In contractions between tensors of various shapes, a dimension is sometimes zero, and I treat such a tensor as contributing zero entries. This also seems to be what einsum does (in the code below, the resulting scalar is 0.0). For this type of contraction, opt_einsum seems to be slower than einsum, even though I use a precomputed contraction expression so that path-finding time is excluded. For example:

import time

import numpy as np
import opt_einsum as oe

# Index sizes; c, d and e are zero-length dimensions.
na = nb = 2000
nc = nd = ne = 0

A = np.random.random((na, nb, nc))  # shape (2000, 2000, 0): no elements
B = np.random.random((nc, nd))      # shape (0, 0)
C = np.random.random((nd, ne))      # shape (0, 0)
D = np.random.random((ne, nb, na))  # shape (0, 2000, 2000)

# Precompute the contraction path once so path finding is not timed.
my_expr = oe.contract_expression('abc,cd,de,eba->', A.shape, B.shape, C.shape, D.shape, optimize='optimal')
# print(my_expr)

n_iter = 50

# Average time of plain np.einsum.
t_total = 0.
for i in range(n_iter):
    start = time.perf_counter()
    E = np.einsum('abc,cd,de,eba->', A, B, C, D)
    t_total += time.perf_counter() - start
print('einsum time', t_total / n_iter)
print(E)

# Average time of the precomputed opt_einsum expression.
t_total = 0.
for i in range(n_iter):
    start = time.perf_counter()
    E = my_expr(A, B, C, D)
    t_total += time.perf_counter() - start
print('oe time', t_total / n_iter)
print(E)

The result is

einsum time 6.322860717773437e-06
0.0
oe time 0.01558232307434082
0.0

so einsum is faster than opt_einsum here (about 6.3 µs versus 15.6 ms per call). Why is einsum faster in this case, and is there a solution? I could write code to detect zero-length dimensions myself, but I hope there is a more robust way.
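To illustrate why a zero-length dimension behaves like a zero entry: when a contracted index has length 0, every output element is an empty sum, which is 0. If the zero-length index survives in the output instead, the result is simply an empty array. A short NumPy check (the shapes are arbitrary, chosen only for illustration):

```python
import numpy as np

# Contracted index j has length 0: each output element is an empty sum, i.e. 0.
A = np.ones((2, 0))
B = np.ones((0, 3))
M = np.einsum('ij,jk->ik', A, B)
print(M.shape)  # (2, 3), all entries zero

# Full contraction to a scalar, as in the question: also an empty sum.
s = np.einsum('abc,cd->', np.ones((4, 5, 0)), np.ones((0, 0)))
print(s)  # 0.0

# If a zero-length index appears in the output, the result has no elements.
empty = np.einsum('ij,jk->ij', np.ones((2, 0)), np.ones((0, 3)))
print(empty.shape)  # (2, 0)
```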

dgasmith commented 2 years ago

I would naively respond that the Python overhead and repeated einsum calls are the root problem; however, the overall time on the opt_einsum route is larger than expected. I wonder if einsum itself has a pathway that returns zero when an index has zero dimension. We could add a similar pathway: check whether any index dimension is zero and, if so, return zero as the result.
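A minimal sketch of such a zero-check pathway, written as a user-side wrapper in NumPy. `zero_dim_shortcut` is a hypothetical helper, not part of opt_einsum; it assumes explicit `->` output subscripts, no ellipses or spaces, and that each index has the same length in every operand (no size-1 broadcasting). `contract_fn` is whatever does the real work, for example the precomputed `my_expr` from the question.

```python
import numpy as np


def zero_dim_shortcut(subscripts, operands, contract_fn):
    """Hypothetical helper: return zeros early if any index has length 0.

    Assumes explicit output subscripts (e.g. 'abc,cd,de,eba->'), no
    ellipses, and consistent index sizes across operands.
    """
    inputs, output = subscripts.split('->')
    # Map each index letter to its length, read from the operand shapes.
    sizes = {}
    for term, op in zip(inputs.split(','), operands):
        for index, dim in zip(term, op.shape):
            sizes[index] = dim
    if any(dim == 0 for dim in sizes.values()):
        # A zero-length summed index gives empty sums (0); a zero-length
        # output index gives an empty result. np.zeros covers both cases.
        out_shape = tuple(sizes[index] for index in output)
        return np.zeros(out_shape, dtype=np.result_type(*operands))
    # No empty index: fall through to the real contraction.
    return contract_fn(*operands)


# Usage with the shapes from the question; np.einsum stands in for my_expr.
expr = 'abc,cd,de,eba->'
A = np.random.random((2000, 2000, 0))
B = np.random.random((0, 0))
C = np.random.random((0, 0))
D = np.random.random((0, 2000, 2000))
E = zero_dim_shortcut(expr, (A, B, C, D), lambda *ops: np.einsum(expr, *ops))
print(E)  # 0.0

# Non-empty operands take the normal path.
X = np.arange(6.0).reshape(2, 3)
Y = np.arange(12.0).reshape(3, 4)
F = zero_dim_shortcut('ij,jk->ik', (X, Y), lambda *ops: np.einsum('ij,jk->ik', *ops))
print(np.allclose(F, X @ Y))  # True
```

With opt_einsum, one would pass the precomputed expression instead, e.g. `zero_dim_shortcut(expr, (A, B, C, D), my_expr)`, so the shape check costs only a dictionary pass over the subscripts before any contraction is attempted.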