fattorib / fusedswiglu

Fused SwiGLU Triton kernels

Wall clock speed is slower than PyTorch primitives #1

Open rustic-snob opened 2 weeks ago

rustic-snob commented 2 weeks ago

Hi! Thank you for your amazing work!

I'm having some trouble comparing the fused SwiGLU kernel with its plain PyTorch equivalent.

I checked the wall-clock time with the code below, and the fused kernel runs at roughly 0.5x the speed of the PyTorch version.

import torch
import torch.nn as nn

from kernels.kernels_bf16 import fused_swiglu_fwd as fused_swiglu_fwd_bf16
from kernels.kernels_fp16 import fused_swiglu_fwd as fused_swiglu_fwd_fp16

silu = nn.SiLU()

def swiglu(x, w_gate, w_up):
    return (silu(x @ w_gate)) * (x @ w_up)

d_model = 4096
d_intermediate = 14336
bs = 4
sq = 128

w_gate = torch.randn((d_model, d_intermediate), dtype=torch.bfloat16, device='cuda')
w_up = torch.randn((d_model, d_intermediate), dtype=torch.bfloat16, device='cuda')
x = torch.randn((bs, sq, d_model), dtype=torch.bfloat16, device='cuda')

# Benchmark function
def benchmark(func, *args, num_runs=500):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(num_runs):
        result = func(*args)
        torch.cuda.synchronize()
    end.record()

    torch.cuda.synchronize()
    return start.elapsed_time(end) / num_runs  # Average time per run in milliseconds

with torch.no_grad():
    # Warmup: trigger kernel compilation/autotuning before timing
    for _ in range(50):
        fused_swiglu_fwd_bf16(x, w_gate, w_up)
        swiglu(x, w_gate, w_up)

    # Benchmark
    time_fused = benchmark(fused_swiglu_fwd_bf16, x, w_gate, w_up)
    time_pytorch = benchmark(swiglu, x, w_gate, w_up)

    print(f"Fused: {time_fused:.2f} ms")
    print(f"Pytorch: {time_pytorch:.2f} ms")

    # Verify correctness
    pytorch_result = swiglu(x, w_gate, w_up)
    fused_result = fused_swiglu_fwd_bf16(x, w_gate, w_up)

    print(f"Max difference: {(fused_result[0] - pytorch_result).abs().max().item()}")

Did I do something wrong with this?

Thanks.

fattorib commented 2 weeks ago

Thanks for the interest and for providing code. I was able to replicate this. The problem size you are using (short sequence length, large inner dimension) is different from what the kernel was tuned for: all of my testing was done with longer contexts of 2048 tokens and smaller model dimensions. You might have to play around with adding additional autotune settings here to get better performance; a rough sketch of what that could look like is below.
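
To illustrate: this is not the actual kernel from this repo, just a minimal standalone sketch of the same fused-SwiGLU-forward idea, computing silu(x @ w_gate) * (x @ w_up) one output tile at a time with a shared pass over K. The triton.Config entries with small BLOCK_M are the kind of additions that would target your short-sequence case. All names, block sizes, and warp counts here are illustrative guesses, not tuned values.

import torch
import triton
import triton.language as tl

# Illustrative configs only: the small BLOCK_M entries target short
# sequences (small M = batch * seq). None of these are tuned values.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8),
    ],
    key=["M", "N", "K"],  # re-tune whenever the problem shape changes
)
@triton.jit
def swiglu_fwd_kernel(
    x_ptr, wg_ptr, wu_ptr, out_ptr,
    M, N, K,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_om, stride_on,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # One program computes a (BLOCK_M, BLOCK_N) tile of silu(x @ wg) * (x @ wu),
    # reusing each loaded x block for both matmuls in a single K loop.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    wg_ptrs = wg_ptr + offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn
    wu_ptrs = wu_ptr + offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn

    acc_g = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_u = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        k_mask = offs_k + k < K
        x_blk = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & k_mask[None, :], other=0.0)
        wg_blk = tl.load(wg_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
        wu_blk = tl.load(wu_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
        acc_g += tl.dot(x_blk, wg_blk)
        acc_u += tl.dot(x_blk, wu_blk)
        x_ptrs += BLOCK_K * stride_xk
        wg_ptrs += BLOCK_K * stride_wk
        wu_ptrs += BLOCK_K * stride_wk

    out = acc_g * tl.sigmoid(acc_g) * acc_u  # silu(gate) * up, accumulated in fp32
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, out.to(out_ptr.dtype.element_ty), mask=out_mask)

def swiglu_fwd(x, w_gate, w_up):
    # Flatten (batch, seq, d_model) into a 2D matmul problem.
    # Assumes w_up has the same layout as w_gate (true for your script).
    bs, sq, d_model = x.shape
    M, K, N = bs * sq, d_model, w_gate.shape[1]
    x2d = x.reshape(M, K)
    out = torch.empty((M, N), device=x.device, dtype=x.dtype)
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"]))
    swiglu_fwd_kernel[grid](
        x2d, w_gate, w_up, out, M, N, K,
        x2d.stride(0), x2d.stride(1),
        w_gate.stride(0), w_gate.stride(1),
        out.stride(0), out.stride(1),
    )
    return out.reshape(bs, sq, N)

One thing to keep in mind: with @triton.autotune keyed on (M, N, K), the first call at a new problem shape benchmarks every config, so the warmup loop in your script matters before timing. triton.testing.do_bench is another option that handles warmup and synchronization for you.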