Closed bclarkson-code closed 2 months ago
I ran this benchmark:
import cupy as cp

# Benchmark configuration: (batch, sequence, embedding) input, square weight.
N_LOOPS = 100
IN_SHAPE = (64, 768, 384)
OUT_SHAPE = (384, 384)
DEVICE = 0

# Select the GPU *before* allocating any arrays.  In the original, the
# arrays were created first and therefore lived on the default device,
# which could make the benchmark measure cross-device traffic.
cp.cuda.Device(DEVICE).use()

cp.random.seed(0)
x = cp.random.random(IN_SHAPE).astype(cp.float32)
g = cp.random.random(IN_SHAPE).astype(cp.float32)
w = cp.random.random(OUT_SHAPE).astype(cp.float32)

# Sanity check: contracting the first two axes with tensordot must match
# the einsum expression it is meant to replace.
assert cp.allclose(
    cp.tensordot(x, g, axes=[[0, 1], [0, 1]]), cp.einsum("zTb,zTW->bW", x, g)
)
def tensordot_full_precision():
    """Accumulate N_LOOPS float32 tensordot contractions.

    Bug fix: the accumulator was created as float16, so every `+=`
    silently down-cast the float32 tensordot result.  A full-precision
    benchmark must accumulate in float32, matching the einsum
    full-precision case it is compared against.
    """
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float32)
    g = cp.random.random(IN_SHAPE).astype(cp.float32)
    out = cp.zeros(OUT_SHAPE).astype(cp.float32)
    for _ in range(N_LOOPS):
        out += cp.tensordot(x, g, axes=[[0, 1], [0, 1]])
def tensordot_half_precision():
    """Accumulate N_LOOPS half-precision tensordot contractions."""
    cp.random.seed(0)
    lhs = cp.random.random(IN_SHAPE).astype(cp.float16)
    rhs = cp.random.random(IN_SHAPE).astype(cp.float16)
    acc = cp.zeros(OUT_SHAPE).astype(cp.float16)
    for _loop in range(N_LOOPS):
        acc += cp.tensordot(lhs, rhs, axes=[[0, 1], [0, 1]])
def einsum_full_precision():
    """Accumulate N_LOOPS float32 einsum contractions (the baseline)."""
    cp.random.seed(0)
    lhs = cp.random.random(IN_SHAPE).astype(cp.float32)
    rhs = cp.random.random(IN_SHAPE).astype(cp.float32)
    acc = cp.zeros(OUT_SHAPE).astype(cp.float32)
    for _loop in range(N_LOOPS):
        acc += cp.einsum("zTb,zTW->bW", lhs, rhs)
def einsum_half_precision():
    """Accumulate N_LOOPS half-precision einsum contractions."""
    cp.random.seed(0)
    lhs = cp.random.random(IN_SHAPE).astype(cp.float16)
    rhs = cp.random.random(IN_SHAPE).astype(cp.float16)
    acc = cp.zeros(OUT_SHAPE).astype(cp.float16)
    for _loop in range(N_LOOPS):
        acc += cp.einsum("zTb,zTW->bW", lhs, rhs)
# richbench pairs: (baseline, candidate, label).  einsum_full_precision is
# the baseline every candidate is timed against.
__benchmarks__ = [
    (einsum_full_precision, tensordot_full_precision, "tensordot full precision"),
    (einsum_full_precision, tensordot_half_precision, "tensordot half precision"),
    (einsum_full_precision, einsum_full_precision, "einsum full precision"),
    (einsum_full_precision, einsum_half_precision, "einsum half precision"),
]
Which produced this output
$ richbench benchmarks/
Benchmarks, repeat=5, number=5
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Benchmark ┃ Min ┃ Max ┃ Mean ┃ Min (+) ┃ Max (+) ┃ Mean (+) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ tensordot full precision │ 0.831 │ 1.030 │ 0.990 │ 0.413 (2.0x) │ 0.525 (2.0x) │ 0.439 (2.3x) │
│ tensordot half precision │ 0.924 │ 1.034 │ 1.012 │ 0.125 (7.4x) │ 0.365 (2.8x) │ 0.173 (5.8x) │
│ einsum full precision │ 0.869 │ 1.037 │ 1.002 │ 1.036 (-1.2x) │ 1.038 (-1.0x) │ 1.037 (-1.0x) │
│ einsum half precision │ 1.038 │ 1.041 │ 1.039 │ 0.977 (1.1x) │ 0.994 (1.0x) │ 0.980 (1.1x) │
└──────────────────────────┴─────────┴─────────┴─────────┴─────────────────┴─────────────────┴─────────────────┘
It looks like tensordot is both significantly faster than einsum and benefits much more from 16-bit floats. The correct optimisation therefore appears to be to switch to tensordot and use 16-bit floats.
Where an operation could be written with matmul, tensordot seems to have identical performance:
import cupy as cp

# Benchmark configuration: (batch, sequence, embedding) input, square weight.
N_LOOPS = 100
IN_SHAPE = (64, 768, 384)
OUT_SHAPE = (384, 384)
DEVICE = 0

# Select the GPU *before* allocating any arrays.  In the original, the
# arrays were created first and therefore lived on the default device.
cp.cuda.Device(DEVICE).use()

cp.random.seed(0)
x = cp.random.random(IN_SHAPE).astype(cp.float32)
# NOTE(review): `g` is never used in this script; kept for parity with the
# first benchmark file.
g = cp.random.random(IN_SHAPE).astype(cp.float32)
w = cp.random.random(OUT_SHAPE).astype(cp.float32)

# Sanity check: contracting x's last axis with w's first axis via tensordot
# must match the einsum expression it is meant to replace.
assert cp.allclose(
    cp.tensordot(x, w, axes=[[2], [0]]), cp.einsum("zTb,bW->zTW", x, w)
)
def matmul_full_precision():
    """Accumulate N_LOOPS float32 matrix products via the @ operator."""
    cp.random.seed(0)
    inp = cp.random.random(IN_SHAPE).astype(cp.float32)
    weight = cp.random.random(OUT_SHAPE).astype(cp.float32)
    acc = cp.zeros(IN_SHAPE).astype(cp.float32)
    for _loop in range(N_LOOPS):
        acc += inp @ weight
def matmul_half_precision():
    """Accumulate N_LOOPS half-precision matrix products via @.

    Bug fix: every array here was float32, so this "half precision"
    benchmark actually re-ran the full-precision case — which is why the
    published timings for the two are indistinguishable.  All arrays are
    now float16 so the benchmark measures what its label claims.
    """
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float16)
    w = cp.random.random(OUT_SHAPE).astype(cp.float16)
    out = cp.zeros(IN_SHAPE).astype(cp.float16)
    for _ in range(N_LOOPS):
        out += x @ w
def tensordot_full_precision():
    """Accumulate N_LOOPS float32 tensordot products (last axis vs last axis)."""
    cp.random.seed(0)
    inp = cp.random.random(IN_SHAPE).astype(cp.float32)
    weight = cp.random.random(OUT_SHAPE).astype(cp.float32)
    acc = cp.zeros(IN_SHAPE).astype(cp.float32)
    for _loop in range(N_LOOPS):
        acc += cp.tensordot(inp, weight, axes=[-1, -1])
def tensordot_half_precision():
    """Accumulate N_LOOPS half-precision tensordot products.

    Bug fix: every array here was float32, so this "half precision"
    benchmark duplicated the full-precision one.  All arrays are now
    float16 so the benchmark measures what its label claims.
    """
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float16)
    w = cp.random.random(OUT_SHAPE).astype(cp.float16)
    out = cp.zeros(IN_SHAPE).astype(cp.float16)
    for _ in range(N_LOOPS):
        out += cp.tensordot(x, w, axes=[-1, -1])
def einsum_full_precision():
    """Accumulate N_LOOPS float32 einsum products (the baseline)."""
    cp.random.seed(0)
    inp = cp.random.random(IN_SHAPE).astype(cp.float32)
    weight = cp.random.random(OUT_SHAPE).astype(cp.float32)
    acc = cp.zeros(IN_SHAPE).astype(cp.float32)
    for _loop in range(N_LOOPS):
        acc += cp.einsum("zTb,bW->zTW", inp, weight)
def einsum_half_precision():
    """Accumulate N_LOOPS half-precision einsum products.

    Bug fix: every array here was float32, so this "half precision"
    benchmark duplicated the full-precision one.  All arrays are now
    float16 so the benchmark measures what its label claims.
    """
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float16)
    w = cp.random.random(OUT_SHAPE).astype(cp.float16)
    out = cp.zeros(IN_SHAPE).astype(cp.float16)
    for _ in range(N_LOOPS):
        out += cp.einsum("zTb,bW->zTW", x, w)
# richbench pairs: (baseline, candidate, label).  einsum_full_precision is
# the baseline every candidate is timed against.
__benchmarks__ = [
    (einsum_full_precision, tensordot_full_precision, "tensordot full precision"),
    (einsum_full_precision, tensordot_half_precision, "tensordot half precision"),
    (einsum_full_precision, einsum_full_precision, "einsum full precision"),
    (einsum_full_precision, einsum_half_precision, "einsum half precision"),
    (einsum_full_precision, matmul_full_precision, "matmul full precision"),
    (einsum_full_precision, matmul_half_precision, "matmul half precision"),
]
Returned
$ richbench benchmarks/
Benchmarks, repeat=5, number=5
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Benchmark ┃ Min ┃ Max ┃ Mean ┃ Min (+) ┃ Max (+) ┃ Mean (+) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ tensordot full precision │ 1.670 │ 2.068 │ 1.987 │ 0.536 (3.1x) │ 0.833 (2.5x) │ 0.597 (3.3x) │
│ tensordot half precision │ 1.781 │ 2.073 │ 2.013 │ 0.538 (3.3x) │ 0.832 (2.5x) │ 0.599 (3.4x) │
│ einsum full precision │ 1.785 │ 2.076 │ 2.018 │ 2.076 (-1.2x) │ 2.078 (-1.0x) │ 2.077 (-1.0x) │
│ einsum half precision │ 2.078 │ 2.079 │ 2.078 │ 2.078 (-1.0x) │ 2.079 (-1.0x) │ 2.079 (-1.0x) │
│ matmul full precision │ 2.078 │ 2.079 │ 2.079 │ 0.528 (3.9x) │ 0.826 (2.5x) │ 0.594 (3.5x) │
│ matmul half precision │ 1.787 │ 2.082 │ 2.022 │ 0.529 (3.4x) │ 0.828 (2.5x) │ 0.594 (3.4x) │
└──────────────────────────┴─────────┴─────────┴─────────┴─────────────────┴─────────────────┴─────────────────┘
Looks like tensordot is strictly better than einsum and is equivalent to matmul where matmul is possible. Given that it also benefits from 16 bit float calculations, it should be the preferred choice.
The next step is to replace the einsums in Dense
with tensordots and measure performance.
I ran the benchmark above, and the output it returned suggests two surprising things:
First, it clearly suggests that einsum should be replaced with traditional matrix multiplication wherever possible. Second, half precision should be used sparingly: matrix multiplication appears to be more optimised for full precision than for half precision. It remains to be seen whether the memory savings are worth the increase in processing time (each operation might be slower, but we might be able to do more in parallel with half precision).