dgasmith / opt_einsum

⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.
https://dgasmith.github.io/opt_einsum/
MIT License

When I use opt_einsum to optimize torch.einsum, the running time after optimization increases. #202

Open edwin-zft opened 1 year ago

edwin-zft commented 1 year ago
import numpy as np
import time
import torch
from opt_einsum import contract

dim = 4

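x = torch.randn(6, 4, 4, 4)     # indices b, k, x, y
w1 = torch.randn(1, 4, dim)     # indices i, k, j
w2 = torch.randn(dim, 4, dim)   # indices j, x, m
w3 = torch.randn(dim, 4, dim)   # indices m, y, f
w4 = torch.randn(dim, 8, dim)   # indices f, p, l
w5 = torch.randn(dim, 8, dim)   # indices l, q, z
w6 = torch.randn(dim, 4, 1)     # indices z, r, i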

def naive(x, w1, w2, w3, w4, w5, w6):
    return torch.einsum('bkxy,ikj,jxm,myf,fpl,lqz,zri -> bpqr', x, w1, w2, w3, w4, w5, w6)

def optimized(x, w1, w2, w3, w4, w5, w6):
    return contract('bkxy,ikj,jxm,myf,fpl,lqz,zri -> bpqr', x, w1, w2, w3, w4, w5, w6)

The respective running times (in seconds):

naive
0.0005145072937011719
optimized
0.876018762588501

I want to know what caused this. Thanks!

jcmgray commented 1 year ago

Hey @edwin-zft, I get:

%%timeit
y = naive(x, w1, w2, w3, w4, w5, w6)
# 536 µs ± 4.06 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

vs.

%%timeit
y = optimized(x, w1, w2, w3, w4, w5, w6)
# 470 µs ± 2.07 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

and as a bonus:

from opt_einsum import contract_expression

expr = contract_expression(
    'bkxy,ikj,jxm,myf,fpl,lqz,zri->bpqr',
    x.shape, w1.shape, w2.shape, w3.shape, w4.shape, w5.shape, w6.shape,
    optimize='dp',
)

%%timeit
y = expr(x, w1, w2, w3, w4, w5, w6)
# 72.2 µs ± 758 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
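The gap between `contract` (470 µs) and the prebuilt `expr` (72.2 µs) is consistent with `contract` searching for a contraction path on every call, while `contract_expression` does that search once, up front, from the shapes. A minimal sketch that measures the two parts separately, assuming NumPy arrays with the issue's shapes stand in for the torch tensors; `per_call` is a hypothetical helper, and absolute timings will differ by machine and backend:

```python
import timeit

import numpy as np
import opt_einsum as oe

eq = "bkxy,ikj,jxm,myf,fpl,lqz,zri->bpqr"
dim = 4
shapes = [(6, 4, 4, 4), (1, 4, dim), (dim, 4, dim), (dim, 4, dim),
          (dim, 8, dim), (dim, 8, dim), (dim, 4, 1)]
rng = np.random.default_rng(0)
ops = [rng.standard_normal(s) for s in shapes]  # NumPy stand-ins for the torch tensors


def per_call(fn, number=50, repeat=3):
    """Hypothetical helper: best-of-`repeat` average seconds per call of fn()."""
    return min(timeit.repeat(fn, number=number, repeat=repeat)) / number


# Path search alone: this is work that contract() redoes on every call.
search_s = per_call(lambda: oe.contract_path(eq, *ops, optimize="dp"))

# Path search done once here; each call then only runs the pairwise contractions.
expr = oe.contract_expression(eq, *shapes, optimize="dp")
run_s = per_call(lambda: expr(*ops))

print(f"path search: {search_s * 1e6:.1f} us, prebuilt expression: {run_s * 1e6:.1f} us")
```

Note that `optimized` in this thread uses the default `optimize` setting while `expr` uses `'dp'`, so the chosen path may also differ between the two.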

So maybe it's just a warm-up issue for you. Are you using timeit?
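Outside IPython, a warm-up-aware measurement can be sketched with the standard `timeit` module: a few untimed calls first, then the best of several timed batches. `time_call` is a hypothetical helper name, and the demo times `np.einsum` only so the block runs without torch:

```python
import timeit

import numpy as np


def time_call(fn, *args, warmup=3, repeat=5):
    """Hypothetical helper: best per-call time in seconds for fn(*args).

    The `warmup` untimed calls run first so that one-time costs (lazy imports,
    caches, first-touch allocation) are not charged to a single timed call.
    """
    for _ in range(warmup):
        fn(*args)
    timer = timeit.Timer(lambda: fn(*args))
    number, _ = timer.autorange()  # loop count whose total runtime is at least 0.2 s
    return min(timer.repeat(repeat=repeat, number=number)) / number


a = np.random.default_rng(0).standard_normal((64, 64))
per_call = time_call(np.einsum, "ij,jk->ik", a, a)
print(f"np.einsum: {per_call * 1e6:.1f} us per call")
```

With the issue's definitions in scope, `time_call(naive, x, w1, w2, w3, w4, w5, w6)` and `time_call(optimized, x, w1, w2, w3, w4, w5, w6)` compare the two on equal footing, whereas a single `time.time()` difference around one cold call can be dominated by first-call overhead.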

edwin-zft commented 1 year ago

Thank you for your reply! I re-tested with timeit:

%%timeit
y = naive(x, w1, w2, w3, w4, w5, w6)
#0.00007126200944185257s (10000 loops)

vs.

%%timeit
y = optimized(x, w1, w2, w3, w4, w5, w6)
#0.00006402703002095222s (10000 loops)

The improvement of running speed after optimization is not obvious. I guess it is due to the particular structure of this expression. Moreover, I tried contract_expression, but it didn't reduce the time. I want to know why.

import timeit
from opt_einsum import contract_expression

expr = contract_expression(
    'bkxy,ikj,jxm,myf,fpl,lqz,zri->bpqr',
    x.shape, w1.shape, w2.shape, w3.shape, w4.shape, w5.shape, w6.shape,
    optimize='dp',
)

%%timeit
y = expr(x, w1, w2, w3, w4, w5, w6)
print(timeit.timeit('y', setup="from __main__ import y", number=10000))
#0.00006920704618096352s (10000 loops)

Finally, thank you very much for your answers and your work!

jcmgray commented 1 year ago

The improvement of running speed after optimization is not obvious.

Some recent PRs and issues in torch suggest that they may have added path optimization themselves, possibly including a version of opt_einsum.

If I increase dim to 4000, the timings are still similar despite a theoretical speedup of 1.828e+14 (compared to doing a single einsum), which would be hard to miss. So it seems torch.einsum at least uses pairwise contractions now.
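The theoretical speedup is the ratio of the naive single-einsum FLOP count to the FLOP count of the optimized pairwise path, as reported by `opt_einsum.contract_path`. A sketch of how such a figure can be reproduced without allocating the large `dim=4000` arrays, assuming opt_einsum 3.3 or newer for the `shapes=True` keyword; the exact value depends on the optimizer, so the `'dp'` result need not match 1.828e+14 exactly:

```python
import opt_einsum as oe

eq = "bkxy,ikj,jxm,myf,fpl,lqz,zri->bpqr"
dim = 4000
shapes = [(6, 4, 4, 4), (1, 4, dim), (dim, 4, dim), (dim, 4, dim),
          (dim, 8, dim), (dim, 8, dim), (dim, 4, 1)]

# shapes=True tells contract_path that the operands are shape tuples, not arrays,
# so nothing of size dim**2 is ever allocated.
path, info = oe.contract_path(eq, *shapes, shapes=True, optimize="dp")

print(info)  # per-step table plus naive vs optimized FLOP counts
print(f"theoretical speedup: {float(info.speedup):.3e}")
```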

but it didn't reduce the time. I want to know why.

I don't know the intricacies of timeit, but I guess it's running the path optimization to produce expr each time, despite the setup.
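One more thing worth checking in edwin-zft's snippet: the printed number comes from `timeit.timeit('y', ...)`, which times evaluating the already-computed name `y`, not the call to `expr`. 0.0000692 s over 10,000 loops is roughly 7 ns per loop, far too short for a tensor contraction and closer to a bare variable lookup, which may be why all three reported timings look alike. A sketch contrasting the two statements, assuming NumPy arrays with the issue's shapes stand in for the torch tensors:

```python
import timeit

import numpy as np
from opt_einsum import contract_expression

eq = "bkxy,ikj,jxm,myf,fpl,lqz,zri->bpqr"
dim = 4
shapes = [(6, 4, 4, 4), (1, 4, dim), (dim, 4, dim), (dim, 4, dim),
          (dim, 8, dim), (dim, 8, dim), (dim, 4, 1)]
rng = np.random.default_rng(0)
ops = [rng.standard_normal(s) for s in shapes]  # NumPy stand-ins for the torch tensors

expr = contract_expression(eq, *shapes, optimize="dp")  # path search happens once, here
y = expr(*ops)

n = 10_000
# The statement 'y' only loads the existing result: no contraction is executed.
lookup_total = timeit.timeit("y", globals={"y": y}, number=n)
# A callable statement runs the contraction on every one of the n loops.
call_total = timeit.timeit(lambda: expr(*ops), number=n)

print(f"'y' lookup: {lookup_total / n * 1e9:.1f} ns per loop")
print(f"expr call:  {call_total / n * 1e6:.1f} us per loop")
```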

janeyx99 commented 1 year ago

FYI torch indeed does default to using opt_einsum if it's found in the environment.
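Assuming a recent PyTorch release that exposes `torch.backends.opt_einsum`, this behaviour can be inspected and toggled directly. Per the PyTorch documentation, when `enabled` is False or opt_einsum is not installed, `torch.einsum` falls back to contracting the operands pairwise from left to right, which fits the observation above that even the "naive" call was not paying single-einsum cost. A sketch using the issue's shapes, in float64 so that the two contraction orders agree closely:

```python
import torch

eq = "bkxy,ikj,jxm,myf,fpl,lqz,zri->bpqr"
dim = 4
shapes = [(6, 4, 4, 4), (1, 4, dim), (dim, 4, dim), (dim, 4, dim),
          (dim, 8, dim), (dim, 8, dim), (dim, 4, 1)]
ops = [torch.randn(*s, dtype=torch.float64) for s in shapes]

oe_backend = torch.backends.opt_einsum
print("opt_einsum found:", oe_backend.is_available())
print("enabled:", oe_backend.enabled, "strategy:", oe_backend.strategy)

with_default = torch.einsum(eq, *ops)  # uses opt_einsum's path if it is available and enabled

previous = oe_backend.enabled
oe_backend.enabled = False  # force torch's left-to-right pairwise order
try:
    left_to_right = torch.einsum(eq, *ops)
finally:
    oe_backend.enabled = previous  # restore the original setting

# Both orders compute the same tensor, up to floating-point rounding.
print(torch.allclose(with_default, left_to_right))
```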

dgasmith commented 1 year ago

FYI torch indeed does default to using opt_einsum if it's found in the environment.

Super cool!