Open lkct opened 1 year ago
1-D indexing for simple folding, 2-D indexing for mixing (include a dimension for input)
import itertools
import subprocess
import sys
import time
from typing import List
import torch
from torch import Tensor
from benchmark.utils.gpu_benchmark import timer
# Device on which every benchmark tensor is allocated.
device = "cuda"
# Axis sizes, keyed below by the single-character labels used in layout strings.
B = 128  # batch
K = 512  # unit
F = 784  # folding
f = 288  # exclusive upper bound of the random indices, and size of an "f" axis
# Lookup table from axis label to axis size; "2" is a literal axis of size 2.
sizes = {
"B": B,
"K": K,
"F": F,
"f": f,
"2": 2,
}
# Number of timed repetitions that run() averages over.
N = 200
def bench(inputs: List[str], idx_shape: str, dim: int) -> float:
    """Time one concatenate-then-index pass over freshly generated tensors.

    Args:
        inputs: Layout strings, one per tensor to concatenate. Every character
            of a layout is looked up in ``sizes`` to get that axis' length.
        idx_shape: Layout string of the integer index tensor (e.g. ``"f"`` or
            ``"2f"``); its values are drawn from ``[0, f)``.
        dim: Axis along which the inputs are concatenated and then indexed.

    Returns:
        The second value returned by ``timer(proc)`` (presumably the elapsed
        time of one call; the unit is whatever ``timer`` reports -- confirm).
    """
    input_data = [
        torch.rand([sizes[axis] for axis in layout], device=device)
        for layout in inputs
    ]
    idx_data = torch.randint(f, [sizes[axis] for axis in idx_shape], device=device)
    # Multi-axis indices must be a tuple: passing a list of slices/tensors only
    # works through deprecated "sequence as tuple" handling. Building the tuple
    # here also keeps the conversion out of the timed region.
    idx = tuple(
        idx_data if axis == dim else slice(None) for axis in range(len(inputs[0]))
    )

    def proc() -> Tensor:
        # The operation under test: concatenate, then advanced-index along dim.
        result = torch.cat(input_data, dim=dim)
        return result[idx]

    _, t = timer(proc)
    return t
def run(inputs: str, idx_shape: str) -> None:
    """Benchmark one layout and print its mean timing (no trailing newline).

    A layout containing ``"F"`` is indexed directly. Otherwise the layout must
    contain ``"f"``, and it is benchmarked as the concatenation of itself with
    the same layout where ``"f"`` is replaced by ``"F"``.

    Args:
        inputs: Layout string of the input tensor.
        idx_shape: Layout string of the index tensor.
    """
    single_input = "F" in inputs
    dim = inputs.index("F" if single_input else "f")
    operands = [inputs] if single_input else [inputs, inputs.replace("f", "F")]
    args = (operands, idx_shape, dim)

    warmup_rounds = 100  # long warm-up is essential?
    for _ in range(warmup_rounds):
        bench(*args)

    # Accumulate from 0.0 in call order, then report the mean scaled by 1000.
    s = sum((bench(*args) for _ in range(N)), 0.0)
    print(f"{s / N * 1000:9.3f}", end="")
def main() -> None:
    """Benchmark every permutation of the "BKF" layout and print a sorted table.

    Each configuration is measured in a fresh ``python bench.py <inputs>
    <idx_shape>`` child process whose stdout is taken as the timing. Rows are
    echoed to stderr as they complete and finally printed to stdout in order of
    increasing time.

    Raises:
        subprocess.CalledProcessError: If a child process exits with a non-zero
            status (its stderr is forwarded to this process' stderr first).
        ValueError: If a child's stdout cannot be parsed as a number.
    """
    ans = []
    for layout in itertools.permutations("BKF"):
        inputs = "".join(layout)
        for idx_shape in ("f",):
            result = subprocess.run(
                ["python", "bench.py", inputs, idx_shape],
                capture_output=True,
                check=False,
                text=True,
            )
            if result.returncode:
                # Explicit check instead of `assert`, which is stripped by -O.
                print(result.stderr, file=sys.stderr)
                result.check_returncode()
            ans.append(["i", inputs, "idx", idx_shape, result.stdout])
            time.sleep(1)  # pause between child runs
            print(ans[-1], file=sys.stderr)
    # Compare timings numerically: the stored values are strings, and string
    # order only matches numeric order while the fixed-width padding holds.
    ans.sort(key=lambda row: float(row[-1]))
    for item in ans:
        print(*item)
if __name__ == "__main__":
    # Without CLI arguments drive the whole sweep; with arguments benchmark the
    # single configuration they describe (this is how main() calls back in).
    cli_args = sys.argv[1:]
    if len(sys.argv) != 1:
        run(*cli_args)
    else:
        main()
input | time |
---|---|
BFK | 793.532 |
KFB | 794.392 |
FKB | 831.200 |
FBK | 831.399 |
KBF | 844.904 |
BKF | 845.616 |
Code differs by:
for inputs in itertools.permutations("BKf"):
input | time |
---|---|
KFB | 1041.293 |
BFK | 1043.517 |
KBF | 1090.914 |
BKF | 1091.461 |
FKB | 1104.790 |
FBK | 1109.793 |
Code differs by:
def proc() -> Tensor:
    # Variant without concatenation: index the first input tensor directly.
    return input_data[0][idx]
...
for idx_shape in ("2f", "f2"):
input | idx | time |
---|---|---|
KFB | 2f | 322.121 |
KFB | f2 | 322.174 |
BFK | 2f | 323.775 |
BFK | f2 | 324.232 |
KBF | f2 | 341.624 |
BKF | 2f | 341.690 |
KBF | 2f | 341.697 |
BKF | f2 | 341.883 |
FBK | 2f | 440.158 |
FKB | 2f | 440.267 |
FBK | f2 | 440.390 |
FKB | f2 | 440.422 |
2-D param when the param is shared among foldings, 3-D when it is not
import itertools
import subprocess
import sys
import time
import torch
from torch import Tensor
from benchmark.utils.gpu_benchmark import timer
# Device on which every benchmark tensor is allocated.
device = "cuda"
# Axis sizes, keyed below by the single-character labels used in layout strings.
B = 128  # batch
K = 512  # unit
F = 288  # folding
# Rank. Only the last assignment takes effect; the earlier lines are the other
# values that were benchmarked (see the R=1 / R=128 / R=512 result tables).
R = 1
R = 128
R = 512
# Lookup table from axis label to axis size.
sizes = {
"B": B,
"K": K,
"F": F,
"R": R,
}
# Number of timed repetitions that run() averages over.
N = 200
def bench(inputs: str, params: str, outputs: str, order: str) -> float:
    """Time one two-operand ``torch.einsum`` call on freshly generated tensors.

    Args:
        inputs: Layout string of the input tensor (characters are keys of
            ``sizes``).
        params: Layout string of the parameter tensor.
        outputs: Layout string of the einsum output.
        order: ``"ipo"`` passes the input first; any other value passes the
            parameter first.

    Returns:
        The second value returned by ``timer(proc)`` (presumably the elapsed
        time of one call -- confirm against ``timer``).
    """
    input_data = torch.rand([sizes[axis] for axis in inputs], device=device)
    param_data = torch.rand([sizes[axis] for axis in params], device=device)

    # Pick the equation and the matching operand order up front.
    if order == "ipo":
        equation = f"{inputs},{params}->{outputs}"
        operands = (input_data, param_data)
    else:
        equation = f"{params},{inputs}->{outputs}"
        operands = (param_data, input_data)

    def proc() -> Tensor:
        result = torch.einsum(equation, *operands)
        return result.contiguous()  # penalize non-contiguous results

    _, t = timer(proc)
    return t
def run(inputs: str, params: str, outputs: str, order: str) -> None:
    """Benchmark one einsum configuration and print ``<order>,<mean time>,``.

    The trailing comma (and absence of a newline) is what lets the parent
    process split the captured output into fields.

    Args:
        inputs: Layout string of the input tensor.
        params: Layout string of the parameter tensor.
        outputs: Layout string of the output tensor.
        order: Operand order label forwarded to ``bench`` and echoed back.
    """
    args = (inputs, params, outputs, order)

    warmup_rounds = 100  # long warm-up is essential?
    for _ in range(warmup_rounds):
        bench(*args)

    # Accumulate from 0.0 in call order, then report the mean scaled by 1000.
    s = sum((bench(*args) for _ in range(N)), 0.0)
    print(order, f"{s / N * 1000:9.3f}", sep=",", end=",")
def main() -> None:
    """Sweep input/param layouts for einsum and print a table sorted by time.

    For every permutation of the "BKF" input layout and the "KRF" param layout,
    both operand orders ("ipo" and "pio") are measured in separate ``python
    bench.py ...`` child processes. Each finished row is echoed to stderr, and
    the full table is printed to stdout ordered by the faster of the two
    timings.

    Raises:
        subprocess.CalledProcessError: If a child process exits non-zero.
        ValueError: If a child's timing field cannot be parsed as a number.
    """
    ans = []
    for input_axes in itertools.permutations("BKF"):
        inputs = "".join(input_axes)
        # The output keeps the input layout with the unit axis replaced by rank.
        outputs = inputs.replace("K", "R")
        for param_axes in itertools.permutations("KRF"):  # "KR" for w/o F
            params = "".join(param_axes)
            ans.append(["i", inputs, "p", params, "o", outputs])
            for order in ("ipo", "pio"):
                result = subprocess.run(
                    ["python", "bench.py", inputs, params, outputs, order],
                    capture_output=True,
                    check=True,
                    text=True,
                )
                # Child prints "<order>,<time>,"; drop the empty last field.
                ans[-1].extend(result.stdout.split(",")[:-1])
                time.sleep(1)  # pause between child runs
            print(ans[-1], file=sys.stderr)
    # Row layout ends with [..., "ipo", t_ipo, "pio", t_pio]. The timings are
    # strings, so convert before comparing; string order only matches numeric
    # order while the fixed-width padding holds.
    ans.sort(key=lambda row: min(float(row[-1]), float(row[-3])))
    for item in ans:
        print(*item)
if __name__ == "__main__":
    # Without CLI arguments drive the whole sweep; with arguments benchmark the
    # single configuration they describe (this is how main() calls back in).
    cli_args = sys.argv[1:]
    if len(sys.argv) != 1:
        run(*cli_args)
    else:
        main()
F -- R=1

input | param | output | ipo | pio |
---|---|---|---|---|
FBK | KRF | FBR | 113.788 | 119.096 |
FBK | KFR | FBR | 113.789 | 119.216 |
FBK | RKF | FBR | 113.912 | 119.255 |
FBK | FRK | FBR | 114.691 | 114.610 |
FBK | RFK | FBR | 114.623 | 114.810 |
FBK | FKR | FBR | 114.643 | 114.854 |
FKB | FKR | FRB | 117.210 | 117.104 |
FKB | FRK | FRB | 117.171 | 117.119 |
FKB | RFK | FRB | 117.231 | 117.146 |
FKB | RKF | FRB | 117.993 | 120.959 |
FKB | KRF | FRB | 118.009 | 120.986 |
FKB | KFR | FRB | 118.018 | 121.041 |
KFB | RFK | RFB | 118.541 | 118.592 |
KFB | FKR | RFB | 118.595 | 118.558 |
KFB | FRK | RFB | 118.566 | 118.568 |
KFB | KFR | RFB | 119.701 | 122.484 |
KFB | KRF | RFB | 119.773 | 122.480 |
KFB | RKF | RFB | 119.831 | 122.524 |
BFK | FKR | BFR | 119.897 | 119.930 |
BFK | RFK | BFR | 119.958 | 121.007 |
BFK | FRK | BFR | 119.960 | 120.003 |
BFK | RKF | BFR | 127.243 | 124.769 |
BFK | KFR | BFR | 127.366 | 124.811 |
BFK | KRF | BFR | 127.302 | 124.972 |
BKF | KFR | BRF | 1801.846 | 1957.580 |
KBF | RFK | RBF | 1955.285 | 1803.298 |
BKF | RFK | BRF | 1803.699 | 1947.905 |
KBF | FRK | RBF | 1955.196 | 1804.161 |
BKF | RKF | BRF | 1804.927 | 1959.471 |
BKF | FRK | BRF | 1805.174 | 1954.135 |
BKF | KRF | BRF | 1805.285 | 1951.911 |
BKF | FKR | BRF | 1806.474 | 1946.075 |
KBF | KFR | RBF | 1951.512 | 1809.683 |
KBF | RKF | RBF | 1953.504 | 1810.281 |
KBF | FKR | RBF | 1952.605 | 1811.003 |
KBF | KRF | RBF | 1953.056 | 1831.348 |
F -- R=128

input | param | output | ipo | pio |
---|---|---|---|---|
FKB | FKR | FRB | 332.943 | 271.832 |
FKB | KFR | FRB | 334.048 | 273.649 |
FKB | FRK | FRB | 350.555 | 290.169 |
FBK | FKR | FBR | 290.291 | 350.455 |
FBK | KFR | FBR | 291.597 | 353.179 |
FKB | RFK | FRB | 354.897 | 294.839 |
KFB | KFR | RFB | 367.332 | 329.672 |
KFB | FKR | RFB | 367.710 | 329.767 |
KFB | FRK | RFB | 385.651 | 347.770 |
BFK | FKR | BFR | 349.022 | 388.739 |
KFB | RFK | RFB | 388.376 | 349.378 |
BFK | KFR | BFR | 351.018 | 388.244 |
FBK | FRK | FBR | 403.333 | 464.001 |
FBK | RFK | FBR | 425.401 | 486.442 |
BFK | FRK | BFR | 476.782 | 515.319 |
BFK | RFK | BFR | 530.967 | 568.218 |
FBK | KRF | FBR | 1985.983 | 2309.411 |
FKB | RKF | FRB | 2168.719 | 1988.884 |
FKB | KRF | FRB | 2009.858 | 2123.959 |
KFB | RKF | RFB | 2200.865 | 2034.394 |
KBF | KFR | RBF | 2265.354 | 2035.632 |
KBF | FKR | RBF | 2265.921 | 2037.660 |
BFK | KRF | BFR | 2049.831 | 2367.700 |
BKF | KFR | BRF | 2051.550 | 2241.513 |
KBF | FRK | RBF | 2378.948 | 2053.770 |
BKF | FKR | BRF | 2054.935 | 2239.223 |
KBF | RFK | RBF | 2391.599 | 2058.657 |
KFB | KRF | RFB | 2065.316 | 2191.561 |
FBK | RKF | FBR | 2121.255 | 2163.598 |
BKF | FRK | BRF | 2184.023 | 2255.856 |
BFK | RKF | BFR | 2188.250 | 2199.440 |
BKF | RFK | BRF | 2204.185 | 2259.588 |
KBF | RKF | RBF | 4100.396 | 3748.783 |
BKF | KRF | BRF | 3752.156 | 4108.266 |
BKF | RKF | BRF | 3912.224 | 3961.133 |
KBF | KRF | RBF | 3948.498 | 3917.926 |
F -- R=512

input | param | output | ipo | pio |
---|---|---|---|---|
FKB | KFR | FRB | 1187.450 | 943.467 |
FKB | FKR | FRB | 1180.317 | 948.735 |
FBK | FKR | FBR | 964.116 | 1225.254 |
FBK | KFR | FBR | 964.148 | 1228.849 |
FKB | FRK | FRB | 1247.073 | 1004.016 |
FKB | RFK | FRB | 1278.828 | 1028.722 |
FBK | FRK | FBR | 1054.420 | 1308.914 |
FBK | RFK | FBR | 1093.861 | 1352.954 |
KFB | FKR | RFB | 1532.598 | 1152.051 |
KFB | KFR | RFB | 1533.872 | 1156.363 |
BFK | KFR | BFR | 1168.978 | 2314.418 |
BFK | FKR | BFR | 1172.014 | 2332.021 |
KFB | FRK | RFB | 1565.030 | 1215.263 |
KFB | RFK | RFB | 1594.081 | 1252.846 |
BFK | FRK | BFR | 1264.582 | 2425.236 |
BFK | RFK | BFR | 1719.074 | 2866.844 |
BKF | FKR | BRF | 3362.928 | 5286.904 |
BKF | KFR | BRF | 3364.260 | 5370.066 |
KBF | KFR | RBF | 3595.469 | 3372.718 |
KBF | FKR | RBF | 3601.643 | 3379.303 |
KBF | FRK | RBF | 3699.549 | 3445.736 |
KBF | RFK | RBF | 3749.442 | 3460.732 |
BKF | FRK | BRF | 3473.244 | 5420.293 |
BKF | RFK | BRF | 3512.626 | 5408.833 |
FBK | KRF | FBR | 11434.301 | 12477.709 |
FKB | RKF | FRB | 12378.257 | 11467.728 |
BFK | KRF | BFR | 11633.416 | 13637.055 |
FKB | KRF | FRB | 11635.588 | 12212.634 |
KFB | RKF | RFB | 12772.965 | 11691.638 |
FBK | RKF | FBR | 12164.019 | 11777.333 |
KFB | KRF | RFB | 12122.000 | 12442.438 |
BFK | RKF | BFR | 12386.609 | 12970.214 |
BKF | KRF | BRF | 13848.587 | 16703.520 |
KBF | RKF | RBF | 14832.411 | 14118.670 |
KBF | KRF | RBF | 14172.575 | 14668.185 |
BKF | RKF | BRF | 14619.742 | 16115.491 |
F -- R=1

input | param | output | ipo | pio |
---|---|---|---|---|
KFB | KR | RFB | 116.458 | 116.501 |
KBF | KR | RBF | 119.778 | 116.487 |
KBF | RK | RBF | 116.583 | 116.493 |
KFB | RK | RFB | 116.495 | 116.550 |
FBK | RK | FBR | 117.029 | 117.115 |
BFK | KR | BFR | 117.116 | 117.057 |
BFK | RK | BFR | 117.084 | 117.155 |
FBK | KR | FBR | 117.092 | 117.174 |
FKB | RK | FRB | 358.544 | 339.967 |
FKB | KR | FRB | 358.303 | 340.043 |
BKF | RK | BRF | 365.138 | 350.999 |
BKF | KR | BRF | 365.350 | 351.294 |
F -- R=128

input | param | output | ipo | pio |
---|---|---|---|---|
KFB | KR | RFB | 336.120 | 221.956 |
KBF | KR | RBF | 335.887 | 227.228 |
KBF | RK | RBF | 341.490 | 229.775 |
KFB | RK | RFB | 342.979 | 230.091 |
BFK | KR | BFR | 248.398 | 344.448 |
FBK | KR | FBR | 248.870 | 344.915 |
FBK | RK | FBR | 253.334 | 349.901 |
BFK | RK | BFR | 255.264 | 351.808 |
FKB | KR | FRB | 559.920 | 496.469 |
FKB | RK | FRB | 568.355 | 502.368 |
BKF | KR | BRF | 585.876 | 507.658 |
BKF | RK | BRF | 592.483 | 513.791 |
F -- R=512

input | param | output | ipo | pio |
---|---|---|---|---|
KFB | KR | RFB | 1418.012 | 811.970 |
KBF | KR | RBF | 1446.548 | 814.111 |
BFK | KR | BFR | 829.528 | 1316.020 |
FBK | KR | FBR | 831.009 | 1311.129 |
KFB | RK | RFB | 1512.291 | 839.328 |
KBF | RK | RBF | 1511.988 | 840.207 |
BFK | RK | BFR | 1001.170 | 1482.165 |
FBK | RK | FBR | 1010.110 | 1475.463 |
FKB | KR | FRB | 1316.934 | 1228.216 |
BKF | KR | BRF | 1359.575 | 1237.370 |
FKB | RK | FRB | 1505.308 | 1259.962 |
BKF | RK | BRF | 1530.362 | 1270.497 |
Differs from above by:
B = 128
K = 512
C = 2
F = 32
F = 288
sizes = {
"B": B,
"K": K,
"F": F,
"C": C,
}
...
for inputs in itertools.permutations("BKFC"):
for params in itertools.permutations("KFC"):
inputs = "".join(inputs)
params = "".join(params)
outputs = inputs.replace("C", "")
F=288

input | param | output | ipo | pio |
---|---|---|---|---|
FKCB | FKC | FKB | 367.531 | 367.391 |
KFCB | KFC | KFB | 367.505 | 367.414 |
FKCB | CFK | FKB | 371.657 | 368.825 |
CFKB | FKC | FKB | 370.326 | 369.524 |
FKCB | KFC | FKB | 371.624 | 369.978 |
KFCB | KCF | KFB | 370.246 | 370.936 |
KFCB | CKF | KFB | 370.284 | 373.721 |
KFCB | FKC | KFB | 371.914 | 370.394 |
CKFB | KFC | KFB | 370.408 | 370.402 |
FKCB | FCK | FKB | 370.740 | 370.582 |
CFKB | FCK | FKB | 373.479 | 371.812 |
KFCB | FCK | KFB | 373.505 | 372.520 |
FKCB | CKF | FKB | 372.763 | 373.647 |
KFCB | CFK | KFB | 373.185 | 373.923 |
FKCB | KCF | FKB | 373.819 | 373.542 |
CKFB | KCF | KFB | 373.818 | 373.547 |
CKFB | CKF | KFB | 373.556 | 373.685 |
CFKB | CFK | FKB | 376.040 | 373.648 |
CFKB | KFC | FKB | 373.743 | 375.004 |
CKFB | FKC | KFB | 375.028 | 378.774 |
CFKB | KCF | FKB | 376.018 | 375.233 |
CFKB | CKF | FKB | 376.886 | 375.875 |
CKFB | CFK | KFB | 377.047 | 376.714 |
CKFB | FCK | KFB | 377.676 | 376.757 |
KCFB | KFC | KFB | 1540.021 | 810.576 |
FCKB | FKC | FKB | 1540.502 | 811.118 |
KCFB | FKC | KFB | 1547.769 | 813.003 |
FCKB | KFC | FKB | 1546.890 | 813.744 |
KCFB | CKF | KFB | 1579.278 | 813.753 |
FCKB | CFK | FKB | 1586.879 | 814.755 |
KCFB | CFK | KFB | 1567.768 | 815.677 |
FCKB | KCF | FKB | 1556.495 | 815.865 |
KCFB | FCK | KFB | 1545.086 | 815.905 |
KCFB | KCF | KFB | 1543.966 | 815.987 |
FCKB | CKF | FKB | 1550.423 | 817.128 |
FCKB | FCK | FKB | 1554.190 | 817.297 |
FCBK | FKC | FBK | 1827.928 | 1088.175 |
FCBK | FCK | FBK | 1831.495 | 1090.527 |
FCBK | CFK | FBK | 1889.879 | 1090.808 |
FCBK | KFC | FBK | 1834.945 | 1091.186 |
FCBK | CKF | FBK | 1837.739 | 1091.674 |
FCBK | KCF | FBK | 1838.717 | 1091.918 |
FKBC | FKC | FKB | 1103.236 | 1094.623 |
CFBK | FKC | FBK | 1844.333 | 1094.778 |
CFBK | FCK | FBK | 1843.633 | 1097.725 |
CFBK | KFC | FBK | 1843.886 | 1098.803 |
KFBC | FKC | KFB | 1099.041 | 1099.132 |
CFBK | CFK | FBK | 1877.116 | 1099.701 |
KFBC | CKF | KFB | 1151.676 | 1099.901 |
KFBC | KFC | KFB | 1115.185 | 1100.159 |
KFBC | KCF | KFB | 1108.606 | 1100.315 |
FKBC | KCF | FKB | 1124.793 | 1100.353 |
FKBC | CKF | FKB | 1114.925 | 1100.431 |
CFBK | KCF | FBK | 1855.247 | 1100.536 |
FKBC | KFC | FKB | 1100.970 | 1112.996 |
CFBK | CKF | FBK | 1841.919 | 1101.455 |
FKBC | FCK | FKB | 1102.943 | 1113.840 |
KFBC | FCK | KFB | 1104.292 | 1122.477 |
KFBC | CFK | KFB | 1123.885 | 1115.799 |
KBFC | CKF | KBF | 1846.746 | 1118.036 |
KBFC | CFK | KBF | 1813.970 | 1119.177 |
FKBC | CFK | FKB | 1135.750 | 1119.635 |
KBFC | KFC | KBF | 1809.347 | 1119.766 |
KCBF | KFC | KBF | 1830.465 | 1122.838 |
KBFC | KCF | KBF | 1814.562 | 1123.999 |
KBFC | FKC | KBF | 1813.773 | 1124.754 |
KCBF | KCF | KBF | 1836.964 | 1125.530 |
KCBF | FKC | KBF | 1837.573 | 1126.583 |
KBFC | FCK | KBF | 1815.755 | 1127.580 |
KCBF | CKF | KBF | 1870.705 | 1127.629 |
CKBF | CFK | KBF | 1847.998 | 1128.662 |
KCBF | FCK | KBF | 1836.367 | 1128.818 |
KCBF | CFK | KBF | 1840.175 | 1128.863 |
CKBF | KFC | KBF | 1844.957 | 1130.018 |
CKBF | CKF | KBF | 1874.794 | 1130.837 |
CKBF | KCF | KBF | 1843.871 | 1132.791 |
CKBF | FCK | KBF | 1843.299 | 1133.625 |
CKBF | FKC | KBF | 1848.352 | 1135.878 |
FBKC | FCK | FBK | 1810.536 | 1137.477 |
FBKC | FKC | FBK | 1805.791 | 1137.899 |
FBKC | KFC | FBK | 1856.439 | 1139.148 |
FBKC | KCF | FBK | 1812.757 | 1141.217 |
FBKC | CFK | FBK | 1846.279 | 1142.776 |
FBKC | CKF | FBK | 1817.862 | 1144.043 |
KBCF | KFC | KBF | 1835.764 | 1156.568 |
KBCF | KCF | KBF | 1841.768 | 1157.567 |
KBCF | FKC | KBF | 1854.994 | 1160.787 |
KBCF | CFK | KBF | 1837.696 | 1162.510 |
KBCF | CKF | KBF | 1866.511 | 1162.532 |
KBCF | FCK | KBF | 1836.835 | 1162.555 |
FBCK | FKC | FBK | 1830.111 | 1178.649 |
FBCK | FCK | FBK | 1836.005 | 1180.750 |
FBCK | KFC | FBK | 1853.291 | 1181.822 |
FBCK | KCF | FBK | 1837.737 | 1181.836 |
FBCK | CFK | FBK | 1868.068 | 1182.880 |
FBCK | CKF | FBK | 1840.923 | 1183.692 |
BKFC | KFC | BKF | 2691.465 | 2688.253 |
BFKC | FKC | BFK | 2694.193 | 2695.982 |
BKFC | KCF | BKF | 2695.214 | 2697.435 |
BFKC | CFK | BFK | 2741.626 | 2696.012 |
BFKC | KFC | BFK | 2704.749 | 2700.642 |
BKFC | FKC | BKF | 2701.903 | 2704.405 |
BFKC | CKF | BFK | 2702.773 | 2704.576 |
BFKC | FCK | BFK | 2703.092 | 2703.769 |
BFKC | KCF | BFK | 2704.245 | 2705.473 |
BKFC | FCK | BKF | 2704.247 | 2707.089 |
BKFC | CKF | BKF | 2730.353 | 2708.528 |
BKFC | CFK | BKF | 2719.952 | 2709.968 |
CBKF | KFC | BKF | 4263.530 | 3587.400 |
CBFK | FKC | BFK | 4238.395 | 3591.081 |
CBKF | FKC | BKF | 4269.617 | 3592.677 |
CBKF | KCF | BKF | 4258.191 | 3593.320 |
CBFK | KFC | BFK | 4248.524 | 3594.745 |
CBKF | CFK | BKF | 4258.850 | 3594.814 |
CBKF | CKF | BKF | 4283.232 | 3595.522 |
CBFK | CKF | BFK | 4249.298 | 3596.531 |
CBFK | CFK | BFK | 4304.656 | 3596.606 |
CBFK | FCK | BFK | 4257.009 | 3597.092 |
CBKF | FCK | BKF | 4251.767 | 3597.939 |
CBFK | KCF | BFK | 4248.681 | 3599.644 |
BKCF | FKC | BKF | 3736.378 | 3949.921 |
BKCF | FCK | BKF | 3738.488 | 3961.672 |
BKCF | CFK | BKF | 3740.531 | 3956.702 |
BFCK | KFC | BFK | 3747.073 | 3967.060 |
BFCK | KCF | BFK | 3749.707 | 3968.432 |
BFCK | FCK | BFK | 3752.576 | 3970.594 |
BFCK | FKC | BFK | 3753.589 | 3955.989 |
BKCF | KCF | BKF | 3753.634 | 3955.264 |
BKCF | KFC | BKF | 3757.144 | 3971.041 |
BFCK | CKF | BFK | 3762.168 | 3967.541 |
BKCF | CKF | BKF | 3772.368 | 3957.740 |
BFCK | CFK | BFK | 3793.712 | 3971.665 |
BCFK | FCK | BFK | 4335.131 | 3994.408 |
BCKF | KFC | BKF | 4330.439 | 3994.670 |
BCKF | FKC | BKF | 4336.629 | 3995.968 |
BCFK | KFC | BFK | 4329.605 | 3998.567 |
BCFK | KCF | BFK | 4329.171 | 3999.583 |
BCFK | CFK | BFK | 4368.991 | 4003.686 |
BCKF | FCK | BKF | 4352.477 | 4005.444 |
BCFK | FKC | BFK | 4334.350 | 4009.904 |
BCKF | CFK | BKF | 4333.976 | 4010.229 |
BCKF | CKF | BKF | 4371.505 | 4011.900 |
BCKF | KCF | BKF | 4331.937 | 4013.636 |
BCFK | CKF | BFK | 4335.028 | 4013.674 |
F=32

input | param | output | ipo | pio |
---|---|---|---|---|
CFKB | FKC | FKB | 79.598 | 79.153 |
CKFB | KFC | KFB | 79.566 | 81.817 |
FKCB | FKC | FKB | 79.955 | 97.287 |
CKFB | CKF | KFB | 80.210 | 92.526 |
KFCB | KFC | KFB | 80.294 | 80.422 |
CFKB | CFK | FKB | 80.497 | 99.437 |
FKCB | CFK | FKB | 80.926 | 91.666 |
KFCB | CKF | KFB | 82.341 | 90.974 |
FKCB | FCK | FKB | 92.289 | 92.344 |
CKFB | FCK | KFB | 92.920 | 92.574 |
KFCB | FCK | KFB | 93.374 | 92.578 |
CKFB | FKC | KFB | 92.658 | 94.965 |
KFCB | CFK | KFB | 93.099 | 92.690 |
CKFB | KCF | KFB | 92.756 | 110.558 |
CKFB | CFK | KFB | 92.970 | 92.800 |
KFCB | FKC | KFB | 93.085 | 92.824 |
KFCB | KCF | KFB | 92.919 | 92.848 |
CFKB | KCF | FKB | 94.191 | 92.948 |
CFKB | KFC | FKB | 93.094 | 93.425 |
CFKB | CKF | FKB | 94.684 | 93.358 |
CFKB | FCK | FKB | 93.367 | 93.559 |
FKCB | KFC | FKB | 93.462 | 93.947 |
FKCB | CKF | FKB | 94.052 | 94.461 |
FKCB | KCF | FKB | 94.733 | 94.473 |
KCFB | KFC | KFB | 211.477 | 130.044 |
FCKB | FKC | FKB | 210.620 | 130.102 |
KCFB | CKF | KFB | 215.384 | 132.373 |
FCKB | CFK | FKB | 215.085 | 133.753 |
KCFB | KCF | KFB | 213.093 | 137.435 |
KCFB | CFK | KFB | 214.065 | 137.480 |
FCKB | KFC | FKB | 213.780 | 137.704 |
KCFB | FCK | KFB | 213.443 | 137.870 |
KCFB | FKC | KFB | 213.155 | 137.957 |
FCKB | FCK | FKB | 214.328 | 138.155 |
FCKB | KCF | FKB | 214.054 | 138.580 |
FCKB | CKF | FKB | 213.860 | 138.975 |
KBFC | KFC | KBF | 238.672 | 157.704 |
KBCF | KFC | KBF | 240.246 | 158.713 |
KCBF | KFC | KBF | 255.124 | 158.941 |
CKBF | KFC | KBF | 241.658 | 160.352 |
FKBC | FKC | FKB | 161.263 | 160.596 |
KFBC | KFC | KFB | 160.904 | 161.044 |
KBFC | CKF | KBF | 243.652 | 161.269 |
KBCF | CKF | KBF | 246.309 | 162.309 |
FCBK | FKC | FBK | 243.916 | 162.420 |
FKBC | CFK | FKB | 162.771 | 172.927 |
CKBF | CKF | KBF | 247.095 | 162.829 |
KCBF | CKF | KBF | 246.178 | 163.045 |
KFBC | CKF | KFB | 163.281 | 173.811 |
CFBK | FKC | FBK | 246.422 | 164.112 |
FBKC | FKC | FBK | 242.724 | 165.993 |
CFBK | CFK | FBK | 250.655 | 166.025 |
FCBK | CFK | FBK | 249.122 | 166.101 |
KCBF | FKC | KBF | 244.671 | 166.421 |
KBCF | CFK | KBF | 243.866 | 166.584 |
KBFC | FKC | KBF | 242.311 | 166.680 |
KCBF | CFK | KBF | 244.461 | 166.827 |
KCBF | KCF | KBF | 244.174 | 166.907 |
KCBF | FCK | KBF | 244.131 | 166.992 |
KBFC | KCF | KBF | 241.301 | 167.124 |
KBCF | FCK | KBF | 243.406 | 167.519 |
CKBF | KCF | KBF | 244.784 | 167.618 |
KBCF | KCF | KBF | 243.740 | 167.685 |
KBFC | FCK | KBF | 241.152 | 167.970 |
KBFC | CFK | KBF | 241.688 | 168.077 |
CKBF | FCK | KBF | 244.305 | 168.107 |
FBCK | FKC | FBK | 244.686 | 168.707 |
KBCF | FKC | KBF | 244.034 | 169.067 |
CKBF | CFK | KBF | 245.370 | 169.176 |
FBKC | CFK | FBK | 247.511 | 169.392 |
BKCF | KFC | BKF | 248.457 | 169.842 |
BFCK | FKC | BFK | 248.430 | 169.880 |
CKBF | FKC | KBF | 244.315 | 169.925 |
FCBK | KCF | FBK | 247.931 | 170.809 |
FCBK | CKF | FBK | 256.509 | 170.834 |
CFBK | KCF | FBK | 249.278 | 171.005 |
CFBK | KFC | FBK | 249.086 | 171.034 |
FBCK | CFK | FBK | 249.342 | 171.165 |
CFBK | FCK | FBK | 249.204 | 171.501 |
CFBK | CKF | FBK | 248.849 | 172.072 |
FCBK | KFC | FBK | 248.645 | 172.564 |
FBKC | KCF | FBK | 268.819 | 172.959 |
FKBC | FCK | FKB | 173.119 | 174.640 |
FBKC | KFC | FBK | 245.996 | 173.209 |
FBKC | FCK | FBK | 245.642 | 173.410 |
FBKC | CKF | FBK | 246.213 | 173.414 |
KFBC | FKC | KFB | 173.680 | 175.566 |
FKBC | KCF | FKB | 173.894 | 174.260 |
BKCF | CKF | BKF | 251.848 | 173.899 |
KFBC | KCF | KFB | 173.975 | 176.086 |
FKBC | KFC | FKB | 173.998 | 174.837 |
KFBC | FCK | KFB | 174.170 | 174.683 |
KFBC | CFK | KFB | 174.195 | 175.851 |
FKBC | CKF | FKB | 174.294 | 175.056 |
CBFK | FKC | BFK | 258.601 | 174.520 |
BFCK | CFK | BFK | 253.164 | 174.660 |
BCFK | FKC | BFK | 259.629 | 174.748 |
CBKF | KFC | BKF | 258.871 | 175.114 |
BCKF | KFC | BKF | 259.257 | 175.316 |
BCFK | CFK | BFK | 263.828 | 176.710 |
FBCK | CKF | FBK | 247.714 | 176.801 |
FBCK | KCF | FBK | 247.170 | 177.100 |
BCKF | CKF | BKF | 265.588 | 177.627 |
FBCK | FCK | FBK | 247.046 | 178.085 |
FBCK | KFC | FBK | 247.356 | 178.092 |
CBFK | CFK | BFK | 264.375 | 178.169 |
BKCF | CFK | BKF | 250.636 | 178.263 |
BKCF | FCK | BKF | 250.638 | 178.286 |
CBKF | CKF | BKF | 263.181 | 178.454 |
BFCK | KCF | BFK | 252.320 | 178.705 |
BKCF | KCF | BKF | 250.549 | 178.750 |
BFCK | KFC | BFK | 251.316 | 179.074 |
BFCK | CKF | BFK | 251.315 | 179.160 |
BKCF | FKC | BKF | 250.134 | 180.515 |
CBFK | KFC | BFK | 266.976 | 182.966 |
FCBK | FCK | FBK | 247.765 | 183.024 |
BCFK | KFC | BFK | 268.268 | 183.291 |
BCKF | KCF | BKF | 267.738 | 183.309 |
CBFK | KCF | BFK | 267.570 | 183.435 |
BCKF | FKC | BKF | 269.241 | 183.463 |
BCKF | CFK | BKF | 268.594 | 183.560 |
CBKF | CFK | BKF | 267.082 | 183.668 |
CBKF | FKC | BKF | 267.374 | 183.678 |
CBKF | FCK | BKF | 266.321 | 183.830 |
BCFK | CKF | BFK | 269.935 | 183.875 |
CBFK | CKF | BFK | 275.909 | 184.310 |
CBKF | KCF | BKF | 267.492 | 184.363 |
BCFK | KCF | BFK | 267.446 | 184.509 |
BCFK | FCK | BFK | 268.312 | 184.614 |
CBFK | FCK | BFK | 266.665 | 184.909 |
BCKF | FCK | BKF | 267.597 | 185.701 |
BFCK | FCK | BFK | 250.912 | 189.132 |
BFKC | FKC | BFK | 207.639 | 207.256 |
BKFC | KFC | BKF | 208.248 | 275.361 |
BFKC | CFK | BFK | 209.477 | 217.516 |
BKFC | CKF | BKF | 209.764 | 219.431 |
BFKC | FCK | BFK | 218.109 | 219.913 |
BKFC | FKC | BKF | 218.851 | 219.514 |
BFKC | CKF | BFK | 219.123 | 218.918 |
BFKC | KCF | BFK | 218.981 | 219.248 |
BKFC | CFK | BKF | 220.066 | 219.263 |
BFKC | KFC | BFK | 220.841 | 219.490 |
BKFC | KCF | BKF | 220.788 | 219.631 |
BKFC | FCK | BKF | 220.577 | 220.304 |
K=512 causes OOM, 128 is too slow.
Differs from CP by:
K = 64
sizes = {
"B": B,
"K": K,
"F": F,
"I": K,
"J": K,
}
...
def bench(inputs: str, inputsj: str, params: str, outputs: str, order: str) -> float:
    """Time one three-operand ``torch.einsum`` call on freshly generated tensors.

    Args:
        inputs: Layout string of the first input tensor (characters are keys
            of ``sizes``).
        inputsj: Layout string of the second input tensor.
        params: Layout string of the parameter tensor.
        outputs: Layout string of the einsum output.
        order: ``"ijpo"`` or ``"ipjo"`` select those operand orders; any other
            value selects ``pijo``.

    Returns:
        The second value returned by ``timer(proc)`` (presumably the elapsed
        time of one call -- confirm against ``timer``).
    """
    input_data = torch.rand([sizes[axis] for axis in inputs], device=device)
    inputj_data = torch.rand([sizes[axis] for axis in inputsj], device=device)
    param_data = torch.rand([sizes[axis] for axis in params], device=device)

    # Pick the equation and the matching operand order up front.
    if order == "ijpo":
        equation = f"{inputs},{inputsj},{params}->{outputs}"  # jipo should be the same
        operands = (input_data, inputj_data, param_data)
    elif order == "ipjo":
        equation = f"{inputs},{params},{inputsj}->{outputs}"  # jpio should be the same
        operands = (input_data, param_data, inputj_data)
    else:
        equation = f"{params},{inputs},{inputsj}->{outputs}"  # pjio should be the same
        operands = (param_data, input_data, inputj_data)

    def proc() -> Tensor:
        result = torch.einsum(equation, *operands)
        return result.contiguous()  # penalize non-contiguous results

    _, t = timer(proc)
    return t
...
for inputs in itertools.permutations("BIF"):
for params in itertools.permutations("IJKF"):
inputs = "".join(inputs)
inputsj = inputs.replace("I", "J")
params = "".join(params)
outputs = inputs.replace("I", "K")
ans.append(["i", inputs, inputsj, "p", params, "o", outputs])
for order in ("ijpo", "ipjo", "pijo"):
...
ans.sort(key=lambda x: min(x[-1], x[-3], x[-5]))
input_i | input_j | param | output | ijpo | ipjo | pijo |
---|---|---|---|---|---|---|
FBI | FBJ | IFKJ | FBK | 3908.628 | 2377.637 | 4540.433 |
FBI | FBJ | FIKJ | FBK | 3836.590 | 2412.078 | 4520.354 |
IFB | JFB | FIKJ | KFB | 4609.518 | 2422.983 | 4597.182 |
FIB | FJB | FIKJ | FKB | 3870.240 | 2430.586 | 4614.175 |
IFB | JFB | IFKJ | KFB | 4684.945 | 2434.715 | 4612.133 |
FIB | FJB | IFKJ | FKB | 3942.264 | 2439.313 | 4604.340 |
FBI | FBJ | FKJI | FBK | 4219.890 | 2495.123 | 4641.883 |
IFB | JFB | FKJI | KFB | 4953.674 | 2554.761 | 4723.499 |
FIB | FJB | FKJI | FKB | 4252.273 | 2559.678 | 4692.171 |
FBI | FBJ | KJFI | FBK | 8490.646 | 2727.708 | 4868.239 |
IFB | JFB | KJFI | KFB | 9283.930 | 2772.792 | 4922.008 |
FIB | FJB | KJFI | FKB | 8507.524 | 2781.015 | 4916.187 |
FBI | FBJ | FIJK | FBK | 2996.839 | 3284.207 | 5895.220 |
FBI | FBJ | IJFK | FBK | 3005.752 | 3358.581 | 10215.086 |
FIB | FJB | IJFK | FKB | 3033.671 | 3426.286 | 10236.571 |
BFI | BFJ | FIJK | BFK | 3035.726 | 5004.268 | 15060.040 |
FIB | FJB | FIJK | FKB | 3036.479 | 3360.912 | 5961.381 |
FBI | FBJ | FKIJ | FBK | 3055.769 | 3299.696 | 5487.676 |
FIB | FJB | FKIJ | FKB | 3058.600 | 3360.504 | 5559.431 |
BFI | BFJ | IJFK | BFK | 3068.995 | 5082.056 | 19591.741 |
FBI | FBJ | KFIJ | FBK | 3075.738 | 3310.673 | 5573.188 |
FIB | FJB | KFIJ | FKB | 3099.481 | 3370.923 | 5636.986 |
BFI | BFJ | FKIJ | BFK | 3113.800 | 5056.174 | 14790.037 |
FBI | FBJ | FJIK | FBK | 3803.535 | 3287.835 | 5724.976 |
FBI | FBJ | KIFJ | FBK | 4724.184 | 3337.369 | 5569.674 |
IFB | JFB | FKIJ | KFB | 3797.939 | 3352.448 | 5561.146 |
IFB | JFB | FJIK | KFB | 4588.908 | 3352.739 | 5794.903 |
IFB | JFB | FIJK | KFB | 3754.884 | 3353.990 | 5969.906 |
FBI | FBJ | IFJK | FBK | 3834.165 | 3356.635 | 9768.044 |
FIB | FJB | FJIK | FKB | 3835.691 | 3363.097 | 5775.134 |
IFB | JFB | KFIJ | KFB | 3814.360 | 3364.032 | 5635.572 |
FBI | FBJ | IKFJ | FBK | 3912.668 | 3385.963 | 6434.167 |
IFB | JFB | IKFJ | KFB | 4697.286 | 3397.757 | 6445.800 |
IFB | JFB | KIFJ | KFB | 5409.265 | 3398.406 | 5643.911 |
FIB | FJB | IKFJ | FKB | 3938.763 | 3404.906 | 6463.434 |
IFB | JFB | IFJK | KFB | 4616.902 | 3419.826 | 9821.942 |
FIB | FJB | IFJK | FKB | 3848.678 | 3420.680 | 9809.559 |
IFB | JFB | IJFK | KFB | 3780.082 | 3421.240 | 10328.697 |
FIB | FJB | KIFJ | FKB | 4753.905 | 3475.389 | 5647.986 |
FBI | FBJ | FJKI | FBK | 4061.292 | 3592.565 | 5481.301 |
BFI | BFJ | KFIJ | BFK | 3634.948 | 5039.832 | 14910.924 |
FIB | FJB | FJKI | FKB | 4106.275 | 3677.640 | 5539.523 |
IFB | JFB | FJKI | KFB | 4804.897 | 3680.428 | 5553.464 |
FBI | FBJ | KFJI | FBK | 8011.818 | 3825.199 | 5505.485 |
FBI | FBJ | JFIK | FBK | 3838.637 | 4056.771 | 6103.444 |
FIB | FJB | JFIK | FKB | 3856.767 | 4117.399 | 6184.915 |
BFI | BFJ | FJIK | BFK | 3858.775 | 5006.212 | 14866.041 |
FBI | FBJ | JIFK | FBK | 3864.836 | 4140.201 | 6516.918 |
BFI | BFJ | JFIK | BFK | 3874.823 | 5709.612 | 15325.226 |
BFI | BFJ | FIKJ | BFK | 3880.999 | 4171.772 | 13823.464 |
BFI | BFJ | IFJK | BFK | 3886.648 | 5081.302 | 19082.015 |
FIB | FJB | JIFK | FKB | 3895.076 | 4194.846 | 6578.682 |
BFI | BFJ | JIFK | BFK | 3914.844 | 5917.165 | 15772.189 |
FIB | FJB | KFJI | FKB | 8009.399 | 3919.569 | 5574.401 |
IFB | JFB | KFJI | KFB | 8816.070 | 3924.129 | 5578.458 |
BFI | BFJ | IKFJ | BFK | 3961.212 | 5080.253 | 15505.795 |
BFI | BFJ | IFKJ | BFK | 3965.693 | 4144.332 | 13822.074 |
IFB | JFB | JFIK | KFB | 4610.678 | 4119.573 | 6187.474 |
BFI | BFJ | FJKI | BFK | 4132.121 | 5280.227 | 14866.309 |
IBF | JBF | FIKJ | KBF | 27502.828 | 4217.491 | 13988.614 |
IBF | JBF | IFKJ | KBF | 27565.654 | 4240.776 | 13989.740 |
BIF | BJF | IFKJ | BKF | 25198.791 | 4242.847 | 14001.018 |
BFI | BFJ | FKJI | BFK | 4274.590 | 4245.031 | 13959.578 |
BIF | BJF | FIKJ | BKF | 25149.931 | 4248.589 | 13990.308 |
IFB | JFB | JIFK | KFB | 4692.126 | 4256.036 | 6582.374 |
IBF | JBF | FKJI | KBF | 27787.650 | 4345.416 | 14047.371 |
BIF | BJF | FKJI | BKF | 25354.145 | 4350.012 | 14065.164 |
FBI | FBJ | JFKI | FBK | 4452.561 | 7527.268 | 5503.660 |
BFI | BFJ | KJFI | BFK | 8576.348 | 4476.279 | 14196.162 |
FIB | FJB | JFKI | FKB | 4492.634 | 7585.673 | 5571.520 |
BFI | BFJ | JFKI | BFK | 4509.255 | 9251.932 | 14863.500 |
IBF | JBF | KJFI | KBF | 32316.516 | 4578.232 | 14280.320 |
BIF | BJF | KJFI | BKF | 29864.192 | 4589.368 | 14318.623 |
BFI | BFJ | KIFJ | BFK | 4798.598 | 5087.361 | 14921.010 |
FBI | FBJ | JKFI | FBK | 4814.459 | 8022.303 | 5563.618 |
FIB | FJB | JKFI | FKB | 4849.267 | 8087.695 | 5616.444 |
BFI | BFJ | JKFI | BFK | 4875.337 | 9739.322 | 14910.316 |
IBF | JBF | FIJK | KBF | 26759.803 | 5108.515 | 15160.152 |
IBF | JBF | FJIK | KBF | 27506.458 | 5112.371 | 15014.336 |
IBF | JBF | FKIJ | KBF | 26837.963 | 5114.460 | 14877.931 |
BIF | BJF | FIJK | BKF | 24192.515 | 5115.876 | 15173.191 |
BIF | BJF | FJIK | BKF | 25152.414 | 5120.059 | 15036.619 |
IBF | JBF | KFIJ | KBF | 26863.712 | 5124.499 | 14989.815 |
BIF | BJF | FKIJ | BKF | 24274.233 | 5129.827 | 14953.184 |
IBF | JBF | KIFJ | KBF | 28279.598 | 5155.905 | 15007.012 |
BIF | BJF | KFIJ | BKF | 24265.052 | 5158.428 | 14946.707 |
IBF | JBF | IKFJ | KBF | 27563.420 | 5160.645 | 15704.193 |
BIF | BJF | KIFJ | BKF | 25875.848 | 5170.673 | 15018.984 |
BIF | BJF | IFJK | BKF | 25144.943 | 5171.858 | 19147.608 |
BIF | BJF | IKFJ | BKF | 25176.313 | 5175.582 | 15735.124 |
IBF | JBF | IFJK | KBF | 27518.212 | 5184.266 | 19173.475 |
IFB | JFB | JFKI | KFB | 5187.517 | 7563.389 | 5566.586 |
BIF | BJF | IJFK | BKF | 24138.406 | 5189.429 | 19573.731 |
IBF | JBF | IJFK | KBF | 26763.469 | 5189.476 | 19678.936 |
IBF | JBF | FJKI | KBF | 27574.461 | 5417.717 | 14941.064 |
BIF | BJF | FJKI | BKF | 25214.373 | 5440.717 | 14966.369 |
BFI | BFJ | KFJI | BFK | 8066.536 | 5535.890 | 14883.697 |
IFB | JFB | JKFI | KFB | 5571.748 | 8084.470 | 5613.728 |
IBF | JBF | KFJI | KBF | 31738.359 | 5644.784 | 14951.942 |
BIF | BJF | KFJI | BKF | 29385.698 | 5670.004 | 14973.997 |
IBF | JBF | JFIK | KBF | 27594.927 | 5830.726 | 15416.097 |
BIF | BJF | JFIK | BKF | 25172.149 | 5888.308 | 15422.273 |
IBF | JBF | JIFK | KBF | 27551.375 | 5945.994 | 15865.789 |
BIF | BJF | JIFK | BKF | 25206.292 | 6029.212 | 15922.171 |
BIF | BJF | JFKI | BKF | 25612.969 | 9409.653 | 14967.835 |
IBF | JBF | JFKI | KBF | 27972.215 | 9420.192 | 14950.724 |
IBF | JBF | JKFI | KBF | 28428.455 | 9893.919 | 14986.498 |
BIF | BJF | JKFI | BKF | 26063.977 | 9935.555 | 15000.447 |
FBI | FBJ | KIJF | FBK | 14408.882 | 13102.867 | 16097.338 |
FBI | FBJ | IKJF | FBK | 14404.375 | 13131.181 | 16037.176 |
IFB | JFB | KIJF | KFB | 15232.959 | 13169.579 | 16141.654 |
FIB | FJB | KIJF | FKB | 14401.401 | 13185.572 | 16049.590 |
IFB | JFB | IKJF | KFB | 15242.951 | 13200.043 | 16086.763 |
FIB | FJB | IKJF | FKB | 14377.524 | 13210.020 | 16100.169 |
FBI | FBJ | JIKF | FBK | 13402.432 | 14042.264 | 16180.082 |
FIB | FJB | IJKF | FKB | 13423.102 | 14013.476 | 16088.511 |
FIB | FJB | JIKF | FKB | 13487.001 | 14122.503 | 16237.219 |
FBI | FBJ | IJKF | FBK | 13543.469 | 14039.374 | 16032.647 |
BFI | BFJ | JIKF | BFK | 13552.916 | 15769.723 | 25733.966 |
BFI | BFJ | IJKF | BFK | 13600.586 | 15665.860 | 25701.877 |
IFB | JFB | KJIF | KFB | 15255.800 | 13976.967 | 15237.502 |
FBI | FBJ | KJIF | FBK | 14341.887 | 14003.328 | 15188.944 |
FIB | FJB | KJIF | FKB | 14406.042 | 14078.892 | 15246.011 |
FBI | FBJ | JKIF | FBK | 14426.273 | 14091.921 | 15197.412 |
IFB | JFB | JIKF | KFB | 14328.805 | 14104.929 | 16224.251 |
IFB | JFB | IJKF | KFB | 14390.491 | 14105.099 | 16204.005 |
IFB | JFB | JKIF | KFB | 15256.116 | 14158.904 | 15214.613 |
FIB | FJB | JKIF | FKB | 14316.826 | 14171.641 | 15126.375 |
BFI | BFJ | KIJF | BFK | 14451.512 | 14831.610 | 25648.642 |
BFI | BFJ | IKJF | BFK | 14452.288 | 14849.012 | 25409.622 |
BFI | BFJ | JKIF | BFK | 14469.899 | 15705.037 | 24719.888 |
BFI | BFJ | KJIF | BFK | 14470.129 | 15723.971 | 24585.098 |
IBF | JBF | KIJF | KBF | 38229.830 | 14972.211 | 25722.334 |
BIF | BJF | KIJF | BKF | 35864.951 | 14990.372 | 25745.150 |
IBF | JBF | IKJF | KBF | 38231.584 | 14996.291 | 25470.399 |
BIF | BJF | IKJF | BKF | 35800.384 | 15007.770 | 25687.052 |
IBF | JBF | KJIF | KBF | 38252.489 | 15779.133 | 24648.695 |
BIF | BJF | KJIF | BKF | 35867.406 | 15791.552 | 24823.959 |
IBF | JBF | IJKF | KBF | 37369.621 | 15907.600 | 25611.660 |
IBF | JBF | JIKF | KBF | 37326.194 | 15908.640 | 25807.500 |
BIF | BJF | IJKF | BKF | 34922.554 | 15912.334 | 25612.394 |
BIF | BJF | JIKF | BKF | 34942.249 | 15927.255 | 25822.737 |
IBF | JBF | JKIF | KBF | 38252.931 | 15963.755 | 24788.428 |
BIF | BJF | JKIF | BKF | 35872.001 | 15972.883 | 24653.693 |
This issue is the tracker for benchmarks on PyTorch ops with different axis orders.
All times are in $\mu s$.

Notations:
- `B`: batch
- `K` (also `I`, `J`): unit
- `F`: folding
- `R`: rank
- `C`: component

Conclusions:
- `torch.bmm` (possibly with a `Tensor.expand` to manually broadcast) and `torch.matmul` are the same as `torch.einsum` in underlying calculations (all reduced to `bmm` -> `baddbmm` -> bgemm in BLAS).