april-tools / cirkit

a python framework to build, learn and reason about probabilistic circuits and tensor networks
https://cirkit-docs.readthedocs.io/en/latest/
GNU General Public License v3.0
71 stars 1 forks source link

Benchmarks for axes order #91

Open lkct opened 1 year ago

lkct commented 1 year ago

This issue is the tracker for benchmarks on pytorch ops with different axes order.

All times are in $\mu s$.

Notations:


Conclusions:

lkct commented 1 year ago

Book-keeping -- concatenation and indexing

1-D indexing for simple folding; 2-D indexing for mixing (the index includes an extra dimension for the inputs being mixed).

code

import itertools
import subprocess
import sys
import time
from typing import List

import torch
from torch import Tensor

from benchmark.utils.gpu_benchmark import timer

# Device on which every benchmark tensor is allocated.
device = "cuda"

# Axis sizes. Each letter in an axis-order string (e.g. "BKF") is looked up
# in `sizes` below to build tensor shapes.
B = 128
K = 512
F = 784
f = 288  # also the exclusive upper bound of the random indices drawn in bench()

sizes = {
    "B": B,
    "K": K,
    "F": F,
    "f": f,
    "2": 2,  # leading axis of 2-D index shapes such as "2f" / "f2"
}

# Number of timed repetitions averaged by run() (after its warm-up loop).
N = 200

def bench(inputs: List[str], idx_shape: str, dim: int) -> float:
    """Time one concatenate-then-index operation and return the measured time.

    Args:
        inputs: One axis-order string per input tensor (e.g. ``"BKF"``). Each
            letter is mapped to a size through ``sizes``. All inputs are
            expected to have the same rank.
        idx_shape: Axis-order string for the integer index tensor, e.g.
            ``"f"`` (1-D) or ``"2f"`` (2-D). Index values are drawn uniformly
            from ``[0, f)``.
        dim: Axis along which the inputs are concatenated and then indexed.

    Returns:
        The time reported by ``timer`` for a single call of the operation.
    """
    input_data = [torch.rand([sizes[x] for x in y], device=device) for y in inputs]
    idx_data = torch.randint(f, [sizes[x] for x in idx_shape], device=device)
    # Advanced-index only along `dim` and take every other axis whole. Build an
    # explicit tuple: indexing with a *list* relies on PyTorch's legacy
    # "treat this sequence as a tuple" heuristic (removed from NumPy).
    idx = tuple(idx_data if i == dim else slice(None) for i in range(len(inputs[0])))

    def proc() -> Tensor:
        result = torch.cat(input_data, dim=dim)
        return result[idx]

    _, t = timer(proc)
    return t

def run(inputs: str, idx_shape: str) -> None:
    """Benchmark one axis order and print its mean time (scaled by 1000).

    If ``inputs`` contains an ``F`` axis, that single tensor is indexed along
    it. Otherwise the ``f`` axis is used, and a second tensor (``f`` swapped
    for ``F``) is concatenated onto the first along that axis before indexing.
    The result is printed with no trailing newline.
    """
    if "F" not in inputs:
        cat_dim = inputs.index("f")
        bench_args = ([inputs, inputs.replace("f", "F")], idx_shape, cat_dim)
    else:
        cat_dim = inputs.index("F")
        bench_args = ([inputs], idx_shape, cat_dim)

    # Discarded warm-up calls; the author flagged a long warm-up as possibly
    # essential for stable numbers.
    for _warmup in range(100):
        bench(*bench_args)

    total = sum((bench(*bench_args) for _ in range(N)), 0.0)
    print(f"{total / N * 1000:9.3f}", end="")

def main() -> None:
    """Benchmark every axis order in its own subprocess and print a ranking.

    Each configuration is measured by re-invoking this script with
    command-line arguments, which dispatch to ``run()``. Each configuration's
    result is echoed to stderr as it arrives. The final list, sorted by time
    with the fastest first, goes to stdout.

    Raises:
        RuntimeError: If a benchmark subprocess exits with a non-zero status.
    """
    ans = []

    for perm in itertools.permutations("BKF"):
        inputs = "".join(perm)
        for idx_shape in ("f",):
            # Use the running interpreter and this file's real path instead of
            # a hard-coded "python bench.py", which breaks under a different
            # file name or when "python" is not the active interpreter.
            result = subprocess.run(
                [sys.executable, __file__, inputs, idx_shape],
                capture_output=True,
                check=False,
                text=True,
            )
            # Explicit raise instead of `assert`, which `python -O` strips.
            if result.returncode:
                raise RuntimeError(
                    f"benchmark {inputs} {idx_shape} failed:\n{result.stderr}"
                )
            ans.append(["i", inputs, "idx", idx_shape, result.stdout])
            time.sleep(1)  # pause between runs (presumably a GPU cool-down)
            print(ans[-1], file=sys.stderr)

    # The timing is captured as a padded string. Compare it numerically so the
    # ranking does not depend on string width.
    ans.sort(key=lambda x: float(x[-1]))
    for item in ans:
        print(*item)

if __name__ == "__main__":
    # No arguments: run the whole sweep via main(), which spawns one
    # subprocess per configuration.
    # With arguments (as passed by those subprocesses): benchmark a single
    # configuration via run(inputs, idx_shape).
    if len(sys.argv) == 1:
        main()
    else:
        run(*sys.argv[1:])

1-concat, 1-D index

input time
BFK 793.532
KFB 794.392
FKB 831.200
FBK 831.399
KBF 844.904
BKF 845.616

2-concat, 1-D index

code differs by

    for inputs in itertools.permutations("BKf"):
input time
KFB 1041.293
BFK 1043.517
KBF 1090.914
BKF 1091.461
FKB 1104.790
FBK 1109.793

0-concat, 2-D index

code differs by

    def proc() -> Tensor:
        """Index the first input with ``idx`` (no concatenation in this variant)."""
        # Explicit tuple: indexing with a list relies on PyTorch's legacy
        # sequence-as-tuple heuristic.
        return input_data[0][tuple(idx)]

...

        for idx_shape in ("2f", "f2"):
input idx time
KFB 2f 322.121
KFB f2 322.174
BFK 2f 323.775
BFK f2 324.232
KBF f2 341.624
BKF 2f 341.690
KBF 2f 341.697
BKF f2 341.883
FBK 2f 440.158
FKB 2f 440.267
FBK f2 440.390
FKB f2 440.422
lkct commented 1 year ago

CP-like layers -- 2-operand einsum, 3-D input, 2/3-D param, 3-D output

A 2-D param is used when the param is shared across folds; a 3-D param is used when it is not.

code

import itertools
import subprocess
import sys
import time

import torch
from torch import Tensor

from benchmark.utils.gpu_benchmark import timer

# Device on which every benchmark tensor is allocated.
device = "cuda"

# Axis sizes; each letter of an axis-order string is mapped through `sizes`.
B = 128
K = 512
F = 288

# Size of the "R" axis (the einsum output replaces "K" with "R").
# Only the LAST assignment takes effect; reorder these lines to pick the
# configuration (the tables below cover R = 1, 128 and 512).
R = 1
R = 128
R = 512

sizes = {
    "B": B,
    "K": K,
    "F": F,
    "R": R,
}

# Number of timed repetitions averaged by run() (after its warm-up loop).
N = 200

def bench(inputs: str, params: str, outputs: str, order: str) -> float:
    """Time a two-operand einsum between an input tensor and a parameter tensor.

    Args:
        inputs: Axis-order string of the input tensor (letters keyed into ``sizes``).
        params: Axis-order string of the parameter tensor.
        outputs: Axis-order string of the einsum output.
        order: Operand order within the einsum: ``"ipo"`` (input, then param)
            or ``"pio"`` (param, then input).

    Returns:
        The time reported by ``timer`` for a single call, including a
        ``.contiguous()`` copy so that non-contiguous outputs are penalized.

    Raises:
        ValueError: If ``order`` is neither ``"ipo"`` nor ``"pio"``.
    """
    # Validate before allocating any GPU memory. Previously, any unknown order
    # silently benchmarked "pio".
    if order not in ("ipo", "pio"):
        raise ValueError(f"order must be 'ipo' or 'pio', got {order!r}")

    input_data = torch.rand([sizes[x] for x in inputs], device=device)
    param_data = torch.rand([sizes[x] for x in params], device=device)

    equation_ipo = f"{inputs},{params}->{outputs}"
    equation_pio = f"{params},{inputs}->{outputs}"

    def proc_ipo() -> Tensor:
        result = torch.einsum(equation_ipo, input_data, param_data)
        return result.contiguous()  # penalize uncontiguous

    def proc_pio() -> Tensor:
        result = torch.einsum(equation_pio, param_data, input_data)
        return result.contiguous()  # penalize uncontiguous

    procs = {"ipo": proc_ipo, "pio": proc_pio}
    _, t = timer(procs[order])
    return t

def run(inputs: str, params: str, outputs: str, order: str) -> None:
    """Warm up, then print ``order`` and the mean time (x1000) as CSV fields.

    The output has the form ``"<order>,<time>,"`` with no newline. The
    trailing comma is why the caller drops the last field after splitting
    on ``","``.
    """
    bench_args = (inputs, params, outputs, order)

    # Discarded warm-up calls.
    for _warmup in range(100):
        bench(*bench_args)

    total = sum((bench(*bench_args) for _ in range(N)), 0.0)
    print(order, f"{total / N * 1000:9.3f}", sep=",", end=",")

def main() -> None:
    """Sweep input/param axis orders, benchmarking each in a subprocess.

    For every input permutation of ``"BKF"`` and param permutation of
    ``"KRF"``, the output order is the input order with ``K`` replaced by
    ``R``. Both einsum operand orders (``"ipo"`` and ``"pio"``) are timed, each
    by re-invoking this script with arguments. Rows are echoed to stderr as
    they finish. The final list, ranked by the faster of the two orders, goes
    to stdout.

    Raises:
        subprocess.CalledProcessError: If a benchmark subprocess fails.
    """
    ans = []

    for input_perm in itertools.permutations("BKF"):
        inputs = "".join(input_perm)
        outputs = inputs.replace("K", "R")
        for param_perm in itertools.permutations("KRF"):  # "KR" for w/o F
            params = "".join(param_perm)
            ans.append(["i", inputs, "p", params, "o", outputs])
            for order in ("ipo", "pio"):
                # Use the running interpreter and this file's real path rather
                # than a hard-coded "python bench.py".
                result = subprocess.run(
                    [sys.executable, __file__, inputs, params, outputs, order],
                    capture_output=True,
                    check=True,
                    text=True,
                )
                # stdout is "<order>,<time>,"; drop the empty trailing field.
                ans[-1].extend(result.stdout.split(",")[:-1])
                time.sleep(1)
            print(ans[-1], file=sys.stderr)

    # Each row ends with [..., "ipo", t_ipo, "pio", t_pio]. Rank by the faster
    # order and compare numerically: the times are padded strings, and a
    # lexicographic min breaks once their widths differ.
    ans.sort(key=lambda x: min(float(x[-1]), float(x[-3])))
    for item in ans:
        print(*item)

if __name__ == "__main__":
    # No arguments: run the full sweep via main(), which spawns one subprocess
    # per (inputs, params, outputs, order) configuration.
    # With arguments (as passed by those subprocesses): benchmark that single
    # configuration via run().
    if len(sys.argv) == 1:
        main()
    else:
        run(*sys.argv[1:])

param w/ F -- R=1

input param output ipo pio
FBK KRF FBR 113.788 119.096
FBK KFR FBR 113.789 119.216
FBK RKF FBR 113.912 119.255
FBK FRK FBR 114.691 114.610
FBK RFK FBR 114.623 114.810
FBK FKR FBR 114.643 114.854
FKB FKR FRB 117.210 117.104
FKB FRK FRB 117.171 117.119
FKB RFK FRB 117.231 117.146
FKB RKF FRB 117.993 120.959
FKB KRF FRB 118.009 120.986
FKB KFR FRB 118.018 121.041
KFB RFK RFB 118.541 118.592
KFB FKR RFB 118.595 118.558
KFB FRK RFB 118.566 118.568
KFB KFR RFB 119.701 122.484
KFB KRF RFB 119.773 122.480
KFB RKF RFB 119.831 122.524
BFK FKR BFR 119.897 119.930
BFK RFK BFR 119.958 121.007
BFK FRK BFR 119.960 120.003
BFK RKF BFR 127.243 124.769
BFK KFR BFR 127.366 124.811
BFK KRF BFR 127.302 124.972
BKF KFR BRF 1801.846 1957.580
KBF RFK RBF 1955.285 1803.298
BKF RFK BRF 1803.699 1947.905
KBF FRK RBF 1955.196 1804.161
BKF RKF BRF 1804.927 1959.471
BKF FRK BRF 1805.174 1954.135
BKF KRF BRF 1805.285 1951.911
BKF FKR BRF 1806.474 1946.075
KBF KFR RBF 1951.512 1809.683
KBF RKF RBF 1953.504 1810.281
KBF FKR RBF 1952.605 1811.003
KBF KRF RBF 1953.056 1831.348

param w/ F -- R=128

input param output ipo pio
FKB FKR FRB 332.943 271.832
FKB KFR FRB 334.048 273.649
FKB FRK FRB 350.555 290.169
FBK FKR FBR 290.291 350.455
FBK KFR FBR 291.597 353.179
FKB RFK FRB 354.897 294.839
KFB KFR RFB 367.332 329.672
KFB FKR RFB 367.710 329.767
KFB FRK RFB 385.651 347.770
BFK FKR BFR 349.022 388.739
KFB RFK RFB 388.376 349.378
BFK KFR BFR 351.018 388.244
FBK FRK FBR 403.333 464.001
FBK RFK FBR 425.401 486.442
BFK FRK BFR 476.782 515.319
BFK RFK BFR 530.967 568.218
FBK KRF FBR 1985.983 2309.411
FKB RKF FRB 2168.719 1988.884
FKB KRF FRB 2009.858 2123.959
KFB RKF RFB 2200.865 2034.394
KBF KFR RBF 2265.354 2035.632
KBF FKR RBF 2265.921 2037.660
BFK KRF BFR 2049.831 2367.700
BKF KFR BRF 2051.550 2241.513
KBF FRK RBF 2378.948 2053.770
BKF FKR BRF 2054.935 2239.223
KBF RFK RBF 2391.599 2058.657
KFB KRF RFB 2065.316 2191.561
FBK RKF FBR 2121.255 2163.598
BKF FRK BRF 2184.023 2255.856
BFK RKF BFR 2188.250 2199.440
BKF RFK BRF 2204.185 2259.588
KBF RKF RBF 4100.396 3748.783
BKF KRF BRF 3752.156 4108.266
BKF RKF BRF 3912.224 3961.133
KBF KRF RBF 3948.498 3917.926

param w/ F -- R=512

input param output ipo pio
FKB KFR FRB 1187.450 943.467
FKB FKR FRB 1180.317 948.735
FBK FKR FBR 964.116 1225.254
FBK KFR FBR 964.148 1228.849
FKB FRK FRB 1247.073 1004.016
FKB RFK FRB 1278.828 1028.722
FBK FRK FBR 1054.420 1308.914
FBK RFK FBR 1093.861 1352.954
KFB FKR RFB 1532.598 1152.051
KFB KFR RFB 1533.872 1156.363
BFK KFR BFR 1168.978 2314.418
BFK FKR BFR 1172.014 2332.021
KFB FRK RFB 1565.030 1215.263
KFB RFK RFB 1594.081 1252.846
BFK FRK BFR 1264.582 2425.236
BFK RFK BFR 1719.074 2866.844
BKF FKR BRF 3362.928 5286.904
BKF KFR BRF 3364.260 5370.066
KBF KFR RBF 3595.469 3372.718
KBF FKR RBF 3601.643 3379.303
KBF FRK RBF 3699.549 3445.736
KBF RFK RBF 3749.442 3460.732
BKF FRK BRF 3473.244 5420.293
BKF RFK BRF 3512.626 5408.833
FBK KRF FBR 11434.301 12477.709
FKB RKF FRB 12378.257 11467.728
BFK KRF BFR 11633.416 13637.055
FKB KRF FRB 11635.588 12212.634
KFB RKF RFB 12772.965 11691.638
FBK RKF FBR 12164.019 11777.333
KFB KRF RFB 12122.000 12442.438
BFK RKF BFR 12386.609 12970.214
BKF KRF BRF 13848.587 16703.520
KBF RKF RBF 14832.411 14118.670
KBF KRF RBF 14172.575 14668.185
BKF RKF BRF 14619.742 16115.491

param w/o F -- R=1

input param output ipo pio
KFB KR RFB 116.458 116.501
KBF KR RBF 119.778 116.487
KBF RK RBF 116.583 116.493
KFB RK RFB 116.495 116.550
FBK RK FBR 117.029 117.115
BFK KR BFR 117.116 117.057
BFK RK BFR 117.084 117.155
FBK KR FBR 117.092 117.174
FKB RK FRB 358.544 339.967
FKB KR FRB 358.303 340.043
BKF RK BRF 365.138 350.999
BKF KR BRF 365.350 351.294

param w/o F -- R=128

input param output ipo pio
KFB KR RFB 336.120 221.956
KBF KR RBF 335.887 227.228
KBF RK RBF 341.490 229.775
KFB RK RFB 342.979 230.091
BFK KR BFR 248.398 344.448
FBK KR FBR 248.870 344.915
FBK RK FBR 253.334 349.901
BFK RK BFR 255.264 351.808
FKB KR FRB 559.920 496.469
FKB RK FRB 568.355 502.368
BKF KR BRF 585.876 507.658
BKF RK BRF 592.483 513.791

param w/o F -- R=512

input param output ipo pio
KFB KR RFB 1418.012 811.970
KBF KR RBF 1446.548 814.111
BFK KR BFR 829.528 1316.020
FBK KR FBR 831.009 1311.129
KFB RK RFB 1512.291 839.328
KBF RK RBF 1511.988 840.207
BFK RK BFR 1001.170 1482.165
FBK RK FBR 1010.110 1475.463
FKB KR FRB 1316.934 1228.216
BKF KR BRF 1359.575 1237.370
FKB RK FRB 1505.308 1259.962
BKF RK BRF 1530.362 1270.497
lkct commented 1 year ago

Mixing (sum) layers -- 2-operand einsum, 4-D input, 3-D param, 3-D output

code

differs from above by

B = 128
K = 512
# Size of the "C" axis. The output drops "C" (see outputs below), so the
# einsum sums over it.
C = 2

# Only the LAST assignment takes effect; reorder to choose the configuration
# (results below cover F=288 and F=32).
F = 32
F = 288

sizes = {
    "B": B,
    "K": K,
    "F": F,
    "C": C,
}

...

    for inputs in itertools.permutations("BKFC"):
        for params in itertools.permutations("KFC"):
            inputs = "".join(inputs)
            params = "".join(params)
            outputs = inputs.replace("C", "")

result -- F=288

input param output ipo pio
FKCB FKC FKB 367.531 367.391
KFCB KFC KFB 367.505 367.414
FKCB CFK FKB 371.657 368.825
CFKB FKC FKB 370.326 369.524
FKCB KFC FKB 371.624 369.978
KFCB KCF KFB 370.246 370.936
KFCB CKF KFB 370.284 373.721
KFCB FKC KFB 371.914 370.394
CKFB KFC KFB 370.408 370.402
FKCB FCK FKB 370.740 370.582
CFKB FCK FKB 373.479 371.812
KFCB FCK KFB 373.505 372.520
FKCB CKF FKB 372.763 373.647
KFCB CFK KFB 373.185 373.923
FKCB KCF FKB 373.819 373.542
CKFB KCF KFB 373.818 373.547
CKFB CKF KFB 373.556 373.685
CFKB CFK FKB 376.040 373.648
CFKB KFC FKB 373.743 375.004
CKFB FKC KFB 375.028 378.774
CFKB KCF FKB 376.018 375.233
CFKB CKF FKB 376.886 375.875
CKFB CFK KFB 377.047 376.714
CKFB FCK KFB 377.676 376.757
KCFB KFC KFB 1540.021 810.576
FCKB FKC FKB 1540.502 811.118
KCFB FKC KFB 1547.769 813.003
FCKB KFC FKB 1546.890 813.744
KCFB CKF KFB 1579.278 813.753
FCKB CFK FKB 1586.879 814.755
KCFB CFK KFB 1567.768 815.677
FCKB KCF FKB 1556.495 815.865
KCFB FCK KFB 1545.086 815.905
KCFB KCF KFB 1543.966 815.987
FCKB CKF FKB 1550.423 817.128
FCKB FCK FKB 1554.190 817.297
FCBK FKC FBK 1827.928 1088.175
FCBK FCK FBK 1831.495 1090.527
FCBK CFK FBK 1889.879 1090.808
FCBK KFC FBK 1834.945 1091.186
FCBK CKF FBK 1837.739 1091.674
FCBK KCF FBK 1838.717 1091.918
FKBC FKC FKB 1103.236 1094.623
CFBK FKC FBK 1844.333 1094.778
CFBK FCK FBK 1843.633 1097.725
CFBK KFC FBK 1843.886 1098.803
KFBC FKC KFB 1099.041 1099.132
CFBK CFK FBK 1877.116 1099.701
KFBC CKF KFB 1151.676 1099.901
KFBC KFC KFB 1115.185 1100.159
KFBC KCF KFB 1108.606 1100.315
FKBC KCF FKB 1124.793 1100.353
FKBC CKF FKB 1114.925 1100.431
CFBK KCF FBK 1855.247 1100.536
FKBC KFC FKB 1100.970 1112.996
CFBK CKF FBK 1841.919 1101.455
FKBC FCK FKB 1102.943 1113.840
KFBC FCK KFB 1104.292 1122.477
KFBC CFK KFB 1123.885 1115.799
KBFC CKF KBF 1846.746 1118.036
KBFC CFK KBF 1813.970 1119.177
FKBC CFK FKB 1135.750 1119.635
KBFC KFC KBF 1809.347 1119.766
KCBF KFC KBF 1830.465 1122.838
KBFC KCF KBF 1814.562 1123.999
KBFC FKC KBF 1813.773 1124.754
KCBF KCF KBF 1836.964 1125.530
KCBF FKC KBF 1837.573 1126.583
KBFC FCK KBF 1815.755 1127.580
KCBF CKF KBF 1870.705 1127.629
CKBF CFK KBF 1847.998 1128.662
KCBF FCK KBF 1836.367 1128.818
KCBF CFK KBF 1840.175 1128.863
CKBF KFC KBF 1844.957 1130.018
CKBF CKF KBF 1874.794 1130.837
CKBF KCF KBF 1843.871 1132.791
CKBF FCK KBF 1843.299 1133.625
CKBF FKC KBF 1848.352 1135.878
FBKC FCK FBK 1810.536 1137.477
FBKC FKC FBK 1805.791 1137.899
FBKC KFC FBK 1856.439 1139.148
FBKC KCF FBK 1812.757 1141.217
FBKC CFK FBK 1846.279 1142.776
FBKC CKF FBK 1817.862 1144.043
KBCF KFC KBF 1835.764 1156.568
KBCF KCF KBF 1841.768 1157.567
KBCF FKC KBF 1854.994 1160.787
KBCF CFK KBF 1837.696 1162.510
KBCF CKF KBF 1866.511 1162.532
KBCF FCK KBF 1836.835 1162.555
FBCK FKC FBK 1830.111 1178.649
FBCK FCK FBK 1836.005 1180.750
FBCK KFC FBK 1853.291 1181.822
FBCK KCF FBK 1837.737 1181.836
FBCK CFK FBK 1868.068 1182.880
FBCK CKF FBK 1840.923 1183.692
BKFC KFC BKF 2691.465 2688.253
BFKC FKC BFK 2694.193 2695.982
BKFC KCF BKF 2695.214 2697.435
BFKC CFK BFK 2741.626 2696.012
BFKC KFC BFK 2704.749 2700.642
BKFC FKC BKF 2701.903 2704.405
BFKC CKF BFK 2702.773 2704.576
BFKC FCK BFK 2703.092 2703.769
BFKC KCF BFK 2704.245 2705.473
BKFC FCK BKF 2704.247 2707.089
BKFC CKF BKF 2730.353 2708.528
BKFC CFK BKF 2719.952 2709.968
CBKF KFC BKF 4263.530 3587.400
CBFK FKC BFK 4238.395 3591.081
CBKF FKC BKF 4269.617 3592.677
CBKF KCF BKF 4258.191 3593.320
CBFK KFC BFK 4248.524 3594.745
CBKF CFK BKF 4258.850 3594.814
CBKF CKF BKF 4283.232 3595.522
CBFK CKF BFK 4249.298 3596.531
CBFK CFK BFK 4304.656 3596.606
CBFK FCK BFK 4257.009 3597.092
CBKF FCK BKF 4251.767 3597.939
CBFK KCF BFK 4248.681 3599.644
BKCF FKC BKF 3736.378 3949.921
BKCF FCK BKF 3738.488 3961.672
BKCF CFK BKF 3740.531 3956.702
BFCK KFC BFK 3747.073 3967.060
BFCK KCF BFK 3749.707 3968.432
BFCK FCK BFK 3752.576 3970.594
BFCK FKC BFK 3753.589 3955.989
BKCF KCF BKF 3753.634 3955.264
BKCF KFC BKF 3757.144 3971.041
BFCK CKF BFK 3762.168 3967.541
BKCF CKF BKF 3772.368 3957.740
BFCK CFK BFK 3793.712 3971.665
BCFK FCK BFK 4335.131 3994.408
BCKF KFC BKF 4330.439 3994.670
BCKF FKC BKF 4336.629 3995.968
BCFK KFC BFK 4329.605 3998.567
BCFK KCF BFK 4329.171 3999.583
BCFK CFK BFK 4368.991 4003.686
BCKF FCK BKF 4352.477 4005.444
BCFK FKC BFK 4334.350 4009.904
BCKF CFK BKF 4333.976 4010.229
BCKF CKF BKF 4371.505 4011.900
BCKF KCF BKF 4331.937 4013.636
BCFK CKF BFK 4335.028 4013.674

result -- F=32

input param output ipo pio
CFKB FKC FKB 79.598 79.153
CKFB KFC KFB 79.566 81.817
FKCB FKC FKB 79.955 97.287
CKFB CKF KFB 80.210 92.526
KFCB KFC KFB 80.294 80.422
CFKB CFK FKB 80.497 99.437
FKCB CFK FKB 80.926 91.666
KFCB CKF KFB 82.341 90.974
FKCB FCK FKB 92.289 92.344
CKFB FCK KFB 92.920 92.574
KFCB FCK KFB 93.374 92.578
CKFB FKC KFB 92.658 94.965
KFCB CFK KFB 93.099 92.690
CKFB KCF KFB 92.756 110.558
CKFB CFK KFB 92.970 92.800
KFCB FKC KFB 93.085 92.824
KFCB KCF KFB 92.919 92.848
CFKB KCF FKB 94.191 92.948
CFKB KFC FKB 93.094 93.425
CFKB CKF FKB 94.684 93.358
CFKB FCK FKB 93.367 93.559
FKCB KFC FKB 93.462 93.947
FKCB CKF FKB 94.052 94.461
FKCB KCF FKB 94.733 94.473
KCFB KFC KFB 211.477 130.044
FCKB FKC FKB 210.620 130.102
KCFB CKF KFB 215.384 132.373
FCKB CFK FKB 215.085 133.753
KCFB KCF KFB 213.093 137.435
KCFB CFK KFB 214.065 137.480
FCKB KFC FKB 213.780 137.704
KCFB FCK KFB 213.443 137.870
KCFB FKC KFB 213.155 137.957
FCKB FCK FKB 214.328 138.155
FCKB KCF FKB 214.054 138.580
FCKB CKF FKB 213.860 138.975
KBFC KFC KBF 238.672 157.704
KBCF KFC KBF 240.246 158.713
KCBF KFC KBF 255.124 158.941
CKBF KFC KBF 241.658 160.352
FKBC FKC FKB 161.263 160.596
KFBC KFC KFB 160.904 161.044
KBFC CKF KBF 243.652 161.269
KBCF CKF KBF 246.309 162.309
FCBK FKC FBK 243.916 162.420
FKBC CFK FKB 162.771 172.927
CKBF CKF KBF 247.095 162.829
KCBF CKF KBF 246.178 163.045
KFBC CKF KFB 163.281 173.811
CFBK FKC FBK 246.422 164.112
FBKC FKC FBK 242.724 165.993
CFBK CFK FBK 250.655 166.025
FCBK CFK FBK 249.122 166.101
KCBF FKC KBF 244.671 166.421
KBCF CFK KBF 243.866 166.584
KBFC FKC KBF 242.311 166.680
KCBF CFK KBF 244.461 166.827
KCBF KCF KBF 244.174 166.907
KCBF FCK KBF 244.131 166.992
KBFC KCF KBF 241.301 167.124
KBCF FCK KBF 243.406 167.519
CKBF KCF KBF 244.784 167.618
KBCF KCF KBF 243.740 167.685
KBFC FCK KBF 241.152 167.970
KBFC CFK KBF 241.688 168.077
CKBF FCK KBF 244.305 168.107
FBCK FKC FBK 244.686 168.707
KBCF FKC KBF 244.034 169.067
CKBF CFK KBF 245.370 169.176
FBKC CFK FBK 247.511 169.392
BKCF KFC BKF 248.457 169.842
BFCK FKC BFK 248.430 169.880
CKBF FKC KBF 244.315 169.925
FCBK KCF FBK 247.931 170.809
FCBK CKF FBK 256.509 170.834
CFBK KCF FBK 249.278 171.005
CFBK KFC FBK 249.086 171.034
FBCK CFK FBK 249.342 171.165
CFBK FCK FBK 249.204 171.501
CFBK CKF FBK 248.849 172.072
FCBK KFC FBK 248.645 172.564
FBKC KCF FBK 268.819 172.959
FKBC FCK FKB 173.119 174.640
FBKC KFC FBK 245.996 173.209
FBKC FCK FBK 245.642 173.410
FBKC CKF FBK 246.213 173.414
KFBC FKC KFB 173.680 175.566
FKBC KCF FKB 173.894 174.260
BKCF CKF BKF 251.848 173.899
KFBC KCF KFB 173.975 176.086
FKBC KFC FKB 173.998 174.837
KFBC FCK KFB 174.170 174.683
KFBC CFK KFB 174.195 175.851
FKBC CKF FKB 174.294 175.056
CBFK FKC BFK 258.601 174.520
BFCK CFK BFK 253.164 174.660
BCFK FKC BFK 259.629 174.748
CBKF KFC BKF 258.871 175.114
BCKF KFC BKF 259.257 175.316
BCFK CFK BFK 263.828 176.710
FBCK CKF FBK 247.714 176.801
FBCK KCF FBK 247.170 177.100
BCKF CKF BKF 265.588 177.627
FBCK FCK FBK 247.046 178.085
FBCK KFC FBK 247.356 178.092
CBFK CFK BFK 264.375 178.169
BKCF CFK BKF 250.636 178.263
BKCF FCK BKF 250.638 178.286
CBKF CKF BKF 263.181 178.454
BFCK KCF BFK 252.320 178.705
BKCF KCF BKF 250.549 178.750
BFCK KFC BFK 251.316 179.074
BFCK CKF BFK 251.315 179.160
BKCF FKC BKF 250.134 180.515
CBFK KFC BFK 266.976 182.966
FCBK FCK FBK 247.765 183.024
BCFK KFC BFK 268.268 183.291
BCKF KCF BKF 267.738 183.309
CBFK KCF BFK 267.570 183.435
BCKF FKC BKF 269.241 183.463
BCKF CFK BKF 268.594 183.560
CBKF CFK BKF 267.082 183.668
CBKF FKC BKF 267.374 183.678
CBKF FCK BKF 266.321 183.830
BCFK CKF BFK 269.935 183.875
CBFK CKF BFK 275.909 184.310
CBKF KCF BKF 267.492 184.363
BCFK KCF BFK 267.446 184.509
BCFK FCK BFK 268.312 184.614
CBFK FCK BFK 266.665 184.909
BCKF FCK BKF 267.597 185.701
BFCK FCK BFK 250.912 189.132
BFKC FKC BFK 207.639 207.256
BKFC KFC BKF 208.248 275.361
BFKC CFK BFK 209.477 217.516
BKFC CKF BKF 209.764 219.431
BFKC FCK BFK 218.109 219.913
BKFC FKC BKF 218.851 219.514
BFKC CKF BFK 219.123 218.918
BFKC KCF BFK 218.981 219.248
BKFC CFK BKF 220.066 219.263
BFKC KFC BFK 220.841 219.490
BKFC KCF BKF 220.788 219.631
BKFC FCK BKF 220.577 220.304
lkct commented 1 year ago

Tucker (EinNet) layers -- 3-operand einsum, 2x 3-D input, 4-D param, 3-D output

K=512 causes OOM and K=128 is too slow, so K=64 is used below.

code

differs from the CP code by

# K=512 runs out of memory and K=128 is too slow for this 3-operand einsum.
K = 64

sizes = {
    "B": B,
    "K": K,
    "F": F,
    # The two inputs' unit axes ("I" and "J") have the same size as the
    # output unit axis "K".
    "I": K,
    "J": K,
}

...

def bench(inputs: str, inputsj: str, params: str, outputs: str, order: str) -> float:
    """Time a three-operand einsum (two inputs and one parameter tensor).

    Args:
        inputs: Axis-order string of the first input (letters keyed into ``sizes``).
        inputsj: Axis-order string of the second input.
        params: Axis-order string of the parameter tensor.
        outputs: Axis-order string of the einsum output.
        order: Operand order within the einsum: ``"ijpo"``, ``"ipjo"`` or
            ``"pijo"`` (i = first input, j = second input, p = param).

    Returns:
        The time reported by ``timer`` for a single call, including a
        ``.contiguous()`` copy so that non-contiguous outputs are penalized.

    Raises:
        ValueError: If ``order`` is not one of the three supported orders.
    """
    # Validate before allocating any GPU memory. Previously, any unknown order
    # silently benchmarked "pijo".
    if order not in ("ijpo", "ipjo", "pijo"):
        raise ValueError(f"order must be 'ijpo', 'ipjo' or 'pijo', got {order!r}")

    input_data = torch.rand([sizes[x] for x in inputs], device=device)
    inputj_data = torch.rand([sizes[x] for x in inputsj], device=device)
    param_data = torch.rand([sizes[x] for x in params], device=device)

    equation_ijpo = f"{inputs},{inputsj},{params}->{outputs}"  # jipo should be the same
    equation_ipjo = f"{inputs},{params},{inputsj}->{outputs}"  # jpio should be the same
    equation_pijo = f"{params},{inputs},{inputsj}->{outputs}"  # pjio should be the same

    def proc_ijpo() -> Tensor:
        result = torch.einsum(equation_ijpo, input_data, inputj_data, param_data)
        return result.contiguous()  # penalize uncontiguous

    def proc_ipjo() -> Tensor:
        result = torch.einsum(equation_ipjo, input_data, param_data, inputj_data)
        return result.contiguous()  # penalize uncontiguous

    def proc_pijo() -> Tensor:
        result = torch.einsum(equation_pijo, param_data, input_data, inputj_data)
        return result.contiguous()  # penalize uncontiguous

    procs = {"ijpo": proc_ijpo, "ipjo": proc_ipjo, "pijo": proc_pijo}
    _, t = timer(procs[order])
    return t

...

    for inputs in itertools.permutations("BIF"):
        for params in itertools.permutations("IJKF"):
            inputs = "".join(inputs)
            inputsj = inputs.replace("I", "J")
            params = "".join(params)
            outputs = inputs.replace("I", "K")
            ans.append(["i", inputs, inputsj, "p", params, "o", outputs])
            for order in ("ijpo", "ipjo", "pijo"):

...

    ans.sort(key=lambda x: min(x[-1], x[-3], x[-5]))

result

input_i input_j param output ijpo ipjo pijo
FBI FBJ IFKJ FBK 3908.628 2377.637 4540.433
FBI FBJ FIKJ FBK 3836.590 2412.078 4520.354
IFB JFB FIKJ KFB 4609.518 2422.983 4597.182
FIB FJB FIKJ FKB 3870.240 2430.586 4614.175
IFB JFB IFKJ KFB 4684.945 2434.715 4612.133
FIB FJB IFKJ FKB 3942.264 2439.313 4604.340
FBI FBJ FKJI FBK 4219.890 2495.123 4641.883
IFB JFB FKJI KFB 4953.674 2554.761 4723.499
FIB FJB FKJI FKB 4252.273 2559.678 4692.171
FBI FBJ KJFI FBK 8490.646 2727.708 4868.239
IFB JFB KJFI KFB 9283.930 2772.792 4922.008
FIB FJB KJFI FKB 8507.524 2781.015 4916.187
FBI FBJ FIJK FBK 2996.839 3284.207 5895.220
FBI FBJ IJFK FBK 3005.752 3358.581 10215.086
FIB FJB IJFK FKB 3033.671 3426.286 10236.571
BFI BFJ FIJK BFK 3035.726 5004.268 15060.040
FIB FJB FIJK FKB 3036.479 3360.912 5961.381
FBI FBJ FKIJ FBK 3055.769 3299.696 5487.676
FIB FJB FKIJ FKB 3058.600 3360.504 5559.431
BFI BFJ IJFK BFK 3068.995 5082.056 19591.741
FBI FBJ KFIJ FBK 3075.738 3310.673 5573.188
FIB FJB KFIJ FKB 3099.481 3370.923 5636.986
BFI BFJ FKIJ BFK 3113.800 5056.174 14790.037
FBI FBJ FJIK FBK 3803.535 3287.835 5724.976
FBI FBJ KIFJ FBK 4724.184 3337.369 5569.674
IFB JFB FKIJ KFB 3797.939 3352.448 5561.146
IFB JFB FJIK KFB 4588.908 3352.739 5794.903
IFB JFB FIJK KFB 3754.884 3353.990 5969.906
FBI FBJ IFJK FBK 3834.165 3356.635 9768.044
FIB FJB FJIK FKB 3835.691 3363.097 5775.134
IFB JFB KFIJ KFB 3814.360 3364.032 5635.572
FBI FBJ IKFJ FBK 3912.668 3385.963 6434.167
IFB JFB IKFJ KFB 4697.286 3397.757 6445.800
IFB JFB KIFJ KFB 5409.265 3398.406 5643.911
FIB FJB IKFJ FKB 3938.763 3404.906 6463.434
IFB JFB IFJK KFB 4616.902 3419.826 9821.942
FIB FJB IFJK FKB 3848.678 3420.680 9809.559
IFB JFB IJFK KFB 3780.082 3421.240 10328.697
FIB FJB KIFJ FKB 4753.905 3475.389 5647.986
FBI FBJ FJKI FBK 4061.292 3592.565 5481.301
BFI BFJ KFIJ BFK 3634.948 5039.832 14910.924
FIB FJB FJKI FKB 4106.275 3677.640 5539.523
IFB JFB FJKI KFB 4804.897 3680.428 5553.464
FBI FBJ KFJI FBK 8011.818 3825.199 5505.485
FBI FBJ JFIK FBK 3838.637 4056.771 6103.444
FIB FJB JFIK FKB 3856.767 4117.399 6184.915
BFI BFJ FJIK BFK 3858.775 5006.212 14866.041
FBI FBJ JIFK FBK 3864.836 4140.201 6516.918
BFI BFJ JFIK BFK 3874.823 5709.612 15325.226
BFI BFJ FIKJ BFK 3880.999 4171.772 13823.464
BFI BFJ IFJK BFK 3886.648 5081.302 19082.015
FIB FJB JIFK FKB 3895.076 4194.846 6578.682
BFI BFJ JIFK BFK 3914.844 5917.165 15772.189
FIB FJB KFJI FKB 8009.399 3919.569 5574.401
IFB JFB KFJI KFB 8816.070 3924.129 5578.458
BFI BFJ IKFJ BFK 3961.212 5080.253 15505.795
BFI BFJ IFKJ BFK 3965.693 4144.332 13822.074
IFB JFB JFIK KFB 4610.678 4119.573 6187.474
BFI BFJ FJKI BFK 4132.121 5280.227 14866.309
IBF JBF FIKJ KBF 27502.828 4217.491 13988.614
IBF JBF IFKJ KBF 27565.654 4240.776 13989.740
BIF BJF IFKJ BKF 25198.791 4242.847 14001.018
BFI BFJ FKJI BFK 4274.590 4245.031 13959.578
BIF BJF FIKJ BKF 25149.931 4248.589 13990.308
IFB JFB JIFK KFB 4692.126 4256.036 6582.374
IBF JBF FKJI KBF 27787.650 4345.416 14047.371
BIF BJF FKJI BKF 25354.145 4350.012 14065.164
FBI FBJ JFKI FBK 4452.561 7527.268 5503.660
BFI BFJ KJFI BFK 8576.348 4476.279 14196.162
FIB FJB JFKI FKB 4492.634 7585.673 5571.520
BFI BFJ JFKI BFK 4509.255 9251.932 14863.500
IBF JBF KJFI KBF 32316.516 4578.232 14280.320
BIF BJF KJFI BKF 29864.192 4589.368 14318.623
BFI BFJ KIFJ BFK 4798.598 5087.361 14921.010
FBI FBJ JKFI FBK 4814.459 8022.303 5563.618
FIB FJB JKFI FKB 4849.267 8087.695 5616.444
BFI BFJ JKFI BFK 4875.337 9739.322 14910.316
IBF JBF FIJK KBF 26759.803 5108.515 15160.152
IBF JBF FJIK KBF 27506.458 5112.371 15014.336
IBF JBF FKIJ KBF 26837.963 5114.460 14877.931
BIF BJF FIJK BKF 24192.515 5115.876 15173.191
BIF BJF FJIK BKF 25152.414 5120.059 15036.619
IBF JBF KFIJ KBF 26863.712 5124.499 14989.815
BIF BJF FKIJ BKF 24274.233 5129.827 14953.184
IBF JBF KIFJ KBF 28279.598 5155.905 15007.012
BIF BJF KFIJ BKF 24265.052 5158.428 14946.707
IBF JBF IKFJ KBF 27563.420 5160.645 15704.193
BIF BJF KIFJ BKF 25875.848 5170.673 15018.984
BIF BJF IFJK BKF 25144.943 5171.858 19147.608
BIF BJF IKFJ BKF 25176.313 5175.582 15735.124
IBF JBF IFJK KBF 27518.212 5184.266 19173.475
IFB JFB JFKI KFB 5187.517 7563.389 5566.586
BIF BJF IJFK BKF 24138.406 5189.429 19573.731
IBF JBF IJFK KBF 26763.469 5189.476 19678.936
IBF JBF FJKI KBF 27574.461 5417.717 14941.064
BIF BJF FJKI BKF 25214.373 5440.717 14966.369
BFI BFJ KFJI BFK 8066.536 5535.890 14883.697
IFB JFB JKFI KFB 5571.748 8084.470 5613.728
IBF JBF KFJI KBF 31738.359 5644.784 14951.942
BIF BJF KFJI BKF 29385.698 5670.004 14973.997
IBF JBF JFIK KBF 27594.927 5830.726 15416.097
BIF BJF JFIK BKF 25172.149 5888.308 15422.273
IBF JBF JIFK KBF 27551.375 5945.994 15865.789
BIF BJF JIFK BKF 25206.292 6029.212 15922.171
BIF BJF JFKI BKF 25612.969 9409.653 14967.835
IBF JBF JFKI KBF 27972.215 9420.192 14950.724
IBF JBF JKFI KBF 28428.455 9893.919 14986.498
BIF BJF JKFI BKF 26063.977 9935.555 15000.447
FBI FBJ KIJF FBK 14408.882 13102.867 16097.338
FBI FBJ IKJF FBK 14404.375 13131.181 16037.176
IFB JFB KIJF KFB 15232.959 13169.579 16141.654
FIB FJB KIJF FKB 14401.401 13185.572 16049.590
IFB JFB IKJF KFB 15242.951 13200.043 16086.763
FIB FJB IKJF FKB 14377.524 13210.020 16100.169
FBI FBJ JIKF FBK 13402.432 14042.264 16180.082
FIB FJB IJKF FKB 13423.102 14013.476 16088.511
FIB FJB JIKF FKB 13487.001 14122.503 16237.219
FBI FBJ IJKF FBK 13543.469 14039.374 16032.647
BFI BFJ JIKF BFK 13552.916 15769.723 25733.966
BFI BFJ IJKF BFK 13600.586 15665.860 25701.877
IFB JFB KJIF KFB 15255.800 13976.967 15237.502
FBI FBJ KJIF FBK 14341.887 14003.328 15188.944
FIB FJB KJIF FKB 14406.042 14078.892 15246.011
FBI FBJ JKIF FBK 14426.273 14091.921 15197.412
IFB JFB JIKF KFB 14328.805 14104.929 16224.251
IFB JFB IJKF KFB 14390.491 14105.099 16204.005
IFB JFB JKIF KFB 15256.116 14158.904 15214.613
FIB FJB JKIF FKB 14316.826 14171.641 15126.375
BFI BFJ KIJF BFK 14451.512 14831.610 25648.642
BFI BFJ IKJF BFK 14452.288 14849.012 25409.622
BFI BFJ JKIF BFK 14469.899 15705.037 24719.888
BFI BFJ KJIF BFK 14470.129 15723.971 24585.098
IBF JBF KIJF KBF 38229.830 14972.211 25722.334
BIF BJF KIJF BKF 35864.951 14990.372 25745.150
IBF JBF IKJF KBF 38231.584 14996.291 25470.399
BIF BJF IKJF BKF 35800.384 15007.770 25687.052
IBF JBF KJIF KBF 38252.489 15779.133 24648.695
BIF BJF KJIF BKF 35867.406 15791.552 24823.959
IBF JBF IJKF KBF 37369.621 15907.600 25611.660
IBF JBF JIKF KBF 37326.194 15908.640 25807.500
BIF BJF IJKF BKF 34922.554 15912.334 25612.394
BIF BJF JIKF BKF 34942.249 15927.255 25822.737
IBF JBF JKIF KBF 38252.931 15963.755 24788.428
BIF BJF JKIF BKF 35872.001 15972.883 24653.693