bclarkson-code / Tricycle

Autograd to GPT-2 completely from scratch
104 stars 9 forks source link

Optimise matrix multiplication #78

Closed bclarkson-code closed 2 months ago

bclarkson-code commented 3 months ago

I ran the following benchmark:

import cupy as cp

N_LOOPS = 250
IN_SHAPE = (64, 768, 384)
WEIGHT_SHAPE = (384, 384)
DEVICE = 0

cp.random.seed(0)
x = cp.random.random(IN_SHAPE).astype(cp.float32)
w = cp.random.random(WEIGHT_SHAPE).astype(cp.float32)

# select GPU
cp.cuda.Device(DEVICE).use()

# Make sure the calculations produce the same output
assert cp.allclose(x @ w, cp.einsum("zTb,bW->zTW", x, w))

def matmul_full_precision_with_at_symbol():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float32)
    w = cp.random.random(WEIGHT_SHAPE).astype(cp.float32)
    out = cp.zeros(IN_SHAPE).astype(cp.float32)

    for _ in range(N_LOOPS):
        out += x @ w

def matmul_half_precision_with_at_symbol():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float16)
    w = cp.random.random(WEIGHT_SHAPE).astype(cp.float16)
    out = cp.zeros(IN_SHAPE).astype(cp.float16)

    for _ in range(N_LOOPS):
        out += x @ w

def matmul_full_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float32)
    w = cp.random.random(WEIGHT_SHAPE).astype(cp.float32)
    out = cp.zeros(IN_SHAPE).astype(cp.float32)

    for _ in range(N_LOOPS):
        out += cp.matmul(x, w)

def matmul_half_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float16)
    w = cp.random.random(WEIGHT_SHAPE).astype(cp.float16)
    out = cp.zeros(IN_SHAPE).astype(cp.float16)

    for _ in range(N_LOOPS):
        out += cp.matmul(x, w)

def einsum_full_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float32)
    w = cp.random.random(WEIGHT_SHAPE).astype(cp.float32)
    out = cp.zeros(IN_SHAPE).astype(cp.float32)

    for _ in range(N_LOOPS):
        out += cp.einsum("zTb,bW->zTW", x, w)

def einsum_half_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float16)
    w = cp.random.random(WEIGHT_SHAPE).astype(cp.float16)
    out = cp.zeros(IN_SHAPE).astype(cp.float16)

    for _ in range(N_LOOPS):
        out += cp.einsum("zTb,bW->zTW", x, w)

__benchmarks__ = [
    (einsum_full_precision, matmul_full_precision, "matmul full precision"),
    (einsum_full_precision, matmul_half_precision, "matmul half precision"),
    (
        einsum_full_precision,
        matmul_full_precision_with_at_symbol,
        "matmul full precision with at symbol",
    ),
    (
        einsum_full_precision,
        matmul_half_precision_with_at_symbol,
        "matmul half precision with at symbol",
    ),
    (einsum_full_precision, einsum_full_precision, "einsum full precision"),
    (einsum_full_precision, einsum_half_precision, "einsum half precision"),
]

Which returned the following:

$ richbench benchmarks/
                                               Benchmarks, repeat=5, number=5
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃                            Benchmark ┃ Min     ┃ Max     ┃ Mean    ┃ Min (+)         ┃ Max (+)         ┃ Mean (+)        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│                matmul full precision │ 4.126   │ 5.106   │ 4.906   │ 1.360 (3.0x)    │ 2.098 (2.4x)    │ 1.511 (3.2x)    │
│                matmul half precision │ 4.387   │ 5.127   │ 4.976   │ 1.517 (2.9x)    │ 2.255 (2.3x)    │ 1.667 (3.0x)    │
│ matmul full precision with at symbol │ 4.402   │ 5.152   │ 4.998   │ 1.317 (3.3x)    │ 2.093 (2.5x)    │ 1.476 (3.4x)    │
│ matmul half precision with at symbol │ 4.387   │ 5.172   │ 5.012   │ 1.473 (3.0x)    │ 2.316 (2.2x)    │ 1.645 (3.0x)    │
│                einsum full precision │ 4.331   │ 5.175   │ 5.006   │ 5.175 (-1.2x)   │ 5.178 (-1.0x)   │ 5.176 (-1.0x)   │
│                einsum half precision │ 5.178   │ 5.188   │ 5.184   │ 4.861 (1.1x)    │ 5.197 (-1.0x)   │ 4.928 (1.1x)    │
└──────────────────────────────────────┴─────────┴─────────┴─────────┴─────────────────┴─────────────────┴─────────────────┘

This suggests two surprising things:

  1. Matmul is significantly faster than einsum
  2. For Matmul, half precision is slightly slower than full precision

This clearly suggests that einsum should be replace with traditional matrix multiplication wherever possible. Second, half precision should be used sparingly. It appears that matrix multiplication is more optimised for full precision than half precision. It remains to be seen whether the memory savings are worth the increase in processing time (each operation might be slower but we might be able to do more in parallel with half precision).

bclarkson-code commented 3 months ago

I ran this benchmark:

import cupy as cp

N_LOOPS = 100
IN_SHAPE = (64, 768, 384)
OUT_SHAPE = (384, 384)
DEVICE = 0

cp.random.seed(0)
x = cp.random.random(IN_SHAPE).astype(cp.float32)
g = cp.random.random(IN_SHAPE).astype(cp.float32)
w = cp.random.random(OUT_SHAPE).astype(cp.float32)

# select GPU
cp.cuda.Device(DEVICE).use()

# Make sure the calculations produce the same output
assert cp.allclose(
    cp.tensordot(x, g, axes=[[0, 1], [0, 1]]), cp.einsum("zTb,zTW->bW", x, g)
)

def tensordot_full_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float32)
    g = cp.random.random(IN_SHAPE).astype(cp.float32)
    out = cp.zeros(OUT_SHAPE).astype(cp.float16)

    for _ in range(N_LOOPS):
        out += cp.tensordot(x, g, axes=[[0, 1], [0, 1]])

def tensordot_half_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float16)
    g = cp.random.random(IN_SHAPE).astype(cp.float16)
    out = cp.zeros(OUT_SHAPE).astype(cp.float16)

    for _ in range(N_LOOPS):
        out += cp.tensordot(x, g, axes=[[0, 1], [0, 1]])

def einsum_full_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float32)
    g = cp.random.random(IN_SHAPE).astype(cp.float32)
    out = cp.zeros(OUT_SHAPE).astype(cp.float32)

    for _ in range(N_LOOPS):
        out += cp.einsum("zTb,zTW->bW", x, g)

def einsum_half_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float16)
    g = cp.random.random(IN_SHAPE).astype(cp.float16)
    out = cp.zeros(OUT_SHAPE).astype(cp.float16)

    for _ in range(N_LOOPS):
        out += cp.einsum("zTb,zTW->bW", x, g)

__benchmarks__ = [
    (
        einsum_full_precision,
        tensordot_full_precision,
        "tensordot full precision",
    ),
    (
        einsum_full_precision,
        tensordot_half_precision,
        "tensordot half precision",
    ),
    (einsum_full_precision, einsum_full_precision, "einsum full precision"),
    (einsum_full_precision, einsum_half_precision, "einsum half precision"),
]

Which produced this output

$ richbench benchmarks/
                                         Benchmarks, repeat=5, number=5
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃                Benchmark ┃ Min     ┃ Max     ┃ Mean    ┃ Min (+)         ┃ Max (+)         ┃ Mean (+)        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ tensordot full precision │ 0.831   │ 1.030   │ 0.990   │ 0.413 (2.0x)    │ 0.525 (2.0x)    │ 0.439 (2.3x)    │
│ tensordot half precision │ 0.924   │ 1.034   │ 1.012   │ 0.125 (7.4x)    │ 0.365 (2.8x)    │ 0.173 (5.8x)    │
│    einsum full precision │ 0.869   │ 1.037   │ 1.002   │ 1.036 (-1.2x)   │ 1.038 (-1.0x)   │ 1.037 (-1.0x)   │
│    einsum half precision │ 1.038   │ 1.041   │ 1.039   │ 0.977 (1.1x)    │ 0.994 (1.0x)    │ 0.980 (1.1x)    │
└──────────────────────────┴─────────┴─────────┴─────────┴─────────────────┴─────────────────┴─────────────────┘

Looks like tensordot is both significantly faster than einsum, and works much faster with 16 bit floats. Looks like the correct optimisation is to switch to tensordot and. use 16 bit floats

bclarkson-code commented 3 months ago

Where an operation could be written with matmul, tensordot seems to have identical performance:

import cupy as cp

N_LOOPS = 100
IN_SHAPE = (64, 768, 384)
OUT_SHAPE = (384, 384)
DEVICE = 0

cp.random.seed(0)
x = cp.random.random(IN_SHAPE).astype(cp.float32)
g = cp.random.random(IN_SHAPE).astype(cp.float32)
w = cp.random.random(OUT_SHAPE).astype(cp.float32)

# select GPU
cp.cuda.Device(DEVICE).use()

# Make sure the calculations produce the same output
assert cp.allclose(
    cp.tensordot(x, w, axes=[[2], [0]]), cp.einsum("zTb,bW->zTW", x, w)
)

def matmul_full_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float32)
    w = cp.random.random(OUT_SHAPE).astype(cp.float32)
    out = cp.zeros(IN_SHAPE).astype(cp.float32)

    for _ in range(N_LOOPS):
        out += x @ w

def matmul_half_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float32)
    w = cp.random.random(OUT_SHAPE).astype(cp.float32)
    out = cp.zeros(IN_SHAPE).astype(cp.float32)

    for _ in range(N_LOOPS):
        out += x @ w

def tensordot_full_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float32)
    w = cp.random.random(OUT_SHAPE).astype(cp.float32)
    out = cp.zeros(IN_SHAPE).astype(cp.float32)

    for _ in range(N_LOOPS):
        out += cp.tensordot(x, w, axes=[-1, -1])

def tensordot_half_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float32)
    w = cp.random.random(OUT_SHAPE).astype(cp.float32)
    out = cp.zeros(IN_SHAPE).astype(cp.float32)

    for _ in range(N_LOOPS):
        out += cp.tensordot(x, w, axes=[-1, -1])

def einsum_full_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float32)
    w = cp.random.random(OUT_SHAPE).astype(cp.float32)
    out = cp.zeros(IN_SHAPE).astype(cp.float32)

    for _ in range(N_LOOPS):
        out += cp.einsum("zTb,bW->zTW", x, w)

def einsum_half_precision():
    cp.random.seed(0)
    x = cp.random.random(IN_SHAPE).astype(cp.float32)
    w = cp.random.random(OUT_SHAPE).astype(cp.float32)
    out = cp.zeros(IN_SHAPE).astype(cp.float32)

    for _ in range(N_LOOPS):
        out += cp.einsum("zTb,bW->zTW", x, w)

__benchmarks__ = [
    (
        einsum_full_precision,
        tensordot_full_precision,
        "tensordot full precision",
    ),
    (
        einsum_full_precision,
        tensordot_half_precision,
        "tensordot half precision",
    ),
    (einsum_full_precision, einsum_full_precision, "einsum full precision"),
    (einsum_full_precision, einsum_half_precision, "einsum half precision"),
    (einsum_full_precision, matmul_full_precision, "matmul full precision"),
    (einsum_full_precision, matmul_half_precision, "matmul half precision"),
]

Returned

$ richbench benchmarks/
                                         Benchmarks, repeat=5, number=5
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃                Benchmark ┃ Min     ┃ Max     ┃ Mean    ┃ Min (+)         ┃ Max (+)         ┃ Mean (+)        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ tensordot full precision │ 1.670   │ 2.068   │ 1.987   │ 0.536 (3.1x)    │ 0.833 (2.5x)    │ 0.597 (3.3x)    │
│ tensordot half precision │ 1.781   │ 2.073   │ 2.013   │ 0.538 (3.3x)    │ 0.832 (2.5x)    │ 0.599 (3.4x)    │
│    einsum full precision │ 1.785   │ 2.076   │ 2.018   │ 2.076 (-1.2x)   │ 2.078 (-1.0x)   │ 2.077 (-1.0x)   │
│    einsum half precision │ 2.078   │ 2.079   │ 2.078   │ 2.078 (-1.0x)   │ 2.079 (-1.0x)   │ 2.079 (-1.0x)   │
│    matmul full precision │ 2.078   │ 2.079   │ 2.079   │ 0.528 (3.9x)    │ 0.826 (2.5x)    │ 0.594 (3.5x)    │
│    matmul half precision │ 1.787   │ 2.082   │ 2.022   │ 0.529 (3.4x)    │ 0.828 (2.5x)    │ 0.594 (3.4x)    │
└──────────────────────────┴─────────┴─────────┴─────────┴─────────────────┴─────────────────┴─────────────────┘

Looks like tensordot is strictly better than einsum and is equivalent to matmul where matmul is possible. Given that it also benefits from 16 bit float calculations, it should be the preferred choice.

The next step is to replace the einsums in Dense with tensordots and measure performance.