bclarkson-code / Tricycle

Autograd to GPT-2 completely from scratch

Optimised GPU kernels #71

Open bclarkson-code opened 2 months ago

bclarkson-code commented 2 months ago

Andrej Karpathy has just ~upstaged me~ released llm.c, which contains some highly optimised CUDA kernels. If we incorporate these into Tricycle, we can probably get a significant performance boost for operations like attention.

bclarkson-code commented 1 month ago

CuPy has support for custom kernels: https://docs.cupy.dev/en/stable/user_guide/kernel.html#raw-modules. It should be possible to take Andrej's flash attention kernel from llm.c and import it into the MultiHeadSelfAttention block.
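
For reference, a minimal sketch of what that could look like, using a toy elementwise ReLU kernel rather than the actual llm.c flash attention kernel: compile the CUDA source with cupy.RawModule, fetch the compiled function, and launch it over CuPy arrays.

import cupy as cp

# Toy CUDA kernel for illustration only -- not one of the llm.c kernels.
relu_source = r"""
extern "C" __global__
void relu_forward(const float *x, float *out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        out[i] = x[i] > 0.0f ? x[i] : 0.0f;
    }
}
"""

module = cp.RawModule(code=relu_source)
relu_forward = module.get_function("relu_forward")

x = cp.random.uniform(-1, 1, (32, 2**15), dtype=cp.float32)
out = cp.empty_like(x)
n = x.size

# One thread per element, 256 threads per block.
threads = 256
blocks = (n + threads - 1) // threads
relu_forward((blocks,), (threads,), (x, out, cp.int32(n)))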

I think this should sit behind an optional flag: the current implementation is pretty easy to follow, and a custom kernel would be harder for Python users to read.
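
Rough sketch of the flag idea (hypothetical names, not Tricycle's actual API): keep the readable CuPy path as the default and only use a hand-written kernel when explicitly requested. Here the "kernel" path is a simple cupy.ElementwiseKernel so the example runs on its own; the real version would load the llm.c kernel via cp.RawModule as above.

import cupy as cp

class ReLU:
    # Hypothetical sketch: the flag selects between the easy-to-follow CuPy
    # path (default) and a compiled kernel.
    def __init__(self, use_cuda_kernel: bool = False):
        self.use_cuda_kernel = use_cuda_kernel
        if use_cuda_kernel:
            self._kernel = cp.ElementwiseKernel(
                "float32 x", "float32 y", "y = x > 0 ? x : 0", "relu_fwd"
            )

    def __call__(self, x):
        if self.use_cuda_kernel:
            return self._kernel(x)
        return cp.maximum(x, 0)

# Example usage (float32 input to match the kernel's parameter types):
layer = ReLU(use_cuda_kernel=True)
y = layer(cp.random.randn(4, 8, dtype=cp.float32))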

bclarkson-code commented 1 month ago

I added llm.c kernels for ReLU and GeLU (the easiest ones to start with). I ran the following benchmark with richbench:

import numpy as np

from tricycle.activation import CudaGeLU, CudaReLU, GeLU, ReLU
from tricycle.tensor import Tensor

N_LOOPS = 1_000                  # forward + backward passes per benchmark
INPUT_SHAPE = (32, int(2**15))   # 32 rows of 32,768 values


def _run(layer):
    # Same setup for every benchmark: a fixed random input in [-1, 1) moved to
    # the GPU, then N_LOOPS forward and backward passes through the layer.
    np.random.seed(0)
    tensor = Tensor(np.random.random(INPUT_SHAPE) * 2 - 1)
    tensor.to_gpu()
    for _ in range(N_LOOPS):
        output = layer(tensor)
        output.backward()


def bench_vanilla_relu():
    _run(ReLU())


def bench_cuda_relu():
    _run(CudaReLU())


def bench_vanilla_gelu():
    _run(GeLU())


def bench_cuda_gelu():
    _run(CudaGeLU())


# richbench runs each (baseline, candidate, label) tuple; the "(+)" columns in
# its output are the candidate (Cuda*) timings.
__benchmarks__ = [
    (bench_vanilla_gelu, bench_cuda_gelu, "handcraft kernel for gelu"),
    (bench_vanilla_relu, bench_cuda_relu, "handcraft kernel for relu"),
]

This resulted in:

$ richbench benchmarks/
                                         Benchmarks, repeat=5, number=5
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃                 Benchmark ┃ Min     ┃ Max     ┃ Mean    ┃ Min (+)         ┃ Max (+)         ┃ Mean (+)        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ handcraft kernel for gelu │ 1.895   │ 1.907   │ 1.904   │ 0.521 (3.6x)    │ 0.524 (3.6x)    │ 0.522 (3.6x)    │
│ handcraft kernel for relu │ 0.677   │ 0.684   │ 0.680   │ 0.551 (1.2x)    │ 0.553 (1.2x)    │ 0.552 (1.2x)    │
└───────────────────────────┴─────────┴─────────┴─────────┴─────────────────┴─────────────────┴─────────────────┘

Looks like hand-crafted kernels are really fast! If this trend continues for attention and dense layers, we could see a huge speedup.