FMInference / DejaVu


In fp16, the sparse kernel is slower than PyTorch dense gemm #21

Closed: sleepcoo closed this issue 6 months ago

sleepcoo commented 6 months ago

These are the results of the test:

| Operation | Compute Type | Triton Cost (ms) | Torch Cost (ms) |
| --- | --- | --- | --- |
| gather_gemv | torch.float32 | 0.09 | 0.32 |
| gather_gemv | torch.float16 | 0.09 | 0.16 |
| gather_transposed_gemv | torch.float32 | 0.10 | 0.33 |
| gather_transposed_gemv | torch.float16 | 0.12 | 0.15 |
| mlp_sparse | torch.float32 | 0.34 | 0.63 |
| mlp_sparse | torch.float16 | 0.36 | 0.32 |

Suspicion

I suspect the issue might be that the Triton kernel hardcodes the multiplication type, for example: `acc0 += tl.sum(a.to(tl.float32) * x0.to(tl.float32)[None, :], 1)`. However, when I changed float32 to float16 there, nothing changed.
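
For reference, here is a minimal sketch of the kind of row-wise GEMV accumulation I mean (illustrative only, not the actual DejaVu kernel; the kernel name, block size, and launcher are made up):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _gemv_rows_kernel(w_ptr, x_ptr, y_ptr, D, stride_wn, BLOCK_D: tl.constexpr):
    # One program instance per output row: y[n] = sum_d W[n, d] * x[d]
    n = tl.program_id(0)
    offs = tl.arange(0, BLOCK_D)
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)
    for d0 in range(0, D, BLOCK_D):
        mask = d0 + offs < D
        w = tl.load(w_ptr + n * stride_wn + d0 + offs, mask=mask, other=0.0)
        x = tl.load(x_ptr + d0 + offs, mask=mask, other=0.0)
        # This cast is the "hardcoded" fp32 accumulation. When the inputs are fp16,
        # the loads above are already fp16, so changing only this cast mainly
        # affects accumulation precision, not memory traffic.
        acc += w.to(tl.float32) * x.to(tl.float32)
    tl.store(y_ptr + n, tl.sum(acc, 0))

def gemv_rows(w, x, block_d=1024):
    # w: (N, D), x: (D,) -> y: (N,), accumulated in fp32
    y = torch.empty(w.shape[0], device=w.device, dtype=torch.float32)
    _gemv_rows_kernel[(w.shape[0],)](w, x, y, w.shape[1], w.stride(0), BLOCK_D=block_d)
    return y
```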

In summary

In fp16, whether it's gather_gemv, gather_transposed_gemv, or mlp_sparse, there is no improvement compared to native torch dense GEMM. Was the improvement mentioned in the paper observed under fp32?

Test code

```python
import torch

# NOTE: gather_gemv and gather_transposed_gemv are the DejaVu Triton kernels;
# import them from this repo before running the script.
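
# Sparse MLP: gate/up projections over only the neurons selected by idx,
# SiLU gating, then a gathered (transposed) down projection.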
def mlp_sparse(x, gate_w, up_w, down_w, idx):
    gate_x = gather_gemv(x, gate_w, idx, None)
    up = gather_gemv(x, up_w, idx, None)
    x = torch.nn.functional.silu(gate_x) * up
    x = gather_transposed_gemv(x, down_w, idx, None)
    return x

def benchmark(func, N_REPEATS):
    import time
    # warm up (the first calls also pay the Triton JIT/autotune cost)
    for _ in range(N_REPEATS):
        func()
    torch.cuda.synchronize()
    t1 = time.time()
    for _ in range(N_REPEATS):
        func()
    torch.cuda.synchronize()
    t2 = time.time()
    return (t2 - t1) / N_REPEATS * 1000  # ms per call

def run_benchmark(dtype, batch_size, hidden_size, intermediate_size, preserve_neurons, N_REPEATS):
    from torch import nn
    from einops import rearrange
    import torch.cuda.nvtx as nvtx
    print(f"compute type is: {dtype}")
    x = torch.randn(batch_size, hidden_size, device='cuda').type(dtype)
    weight = torch.randn(intermediate_size, hidden_size, device='cuda').type(dtype)
    weight_b = torch.randn(hidden_size, intermediate_size, device='cuda').type(dtype)
    idx = torch.arange(0, preserve_neurons, dtype=torch.int32, device='cuda')
    down_weight = torch.randn(hidden_size, intermediate_size, device='cuda').type(dtype)
    down_weight_t = down_weight.t().contiguous().to('cuda')
    # warm up
    gather_gemv(x, weight, idx)
    gather_transposed_gemv(gather_gemv(x, weight, idx), down_weight, idx, None)
    torch.matmul(torch.matmul(x, weight_b), down_weight_t)
    torch.cuda.synchronize()

    # Benchmark gather_gemv
    nvtx.range_push("gather_gemv")
    triton_cost = benchmark(lambda: gather_gemv(x, weight, idx), N_REPEATS)
    nvtx.range_pop()
    nvtx.range_push("matmul")
    torch_cost = benchmark(lambda: torch.matmul(x, weight_b), N_REPEATS)
    nvtx.range_pop()
    print(f"gather_gemv triton_cost: {triton_cost:.2f} ms, torch_cost: {torch_cost:.2f} ms")

    # Benchmark gather_transposed_gemv
    output = gather_gemv(x, weight, idx)
    result = torch.matmul(x, weight_b)
    triton_cost = benchmark(lambda: gather_transposed_gemv(output, down_weight, idx, None), N_REPEATS)
    torch_cost = benchmark(lambda: torch.matmul(result, down_weight_t), N_REPEATS)
    print(f"gather_transposed_gemv triton_cost: {triton_cost:.2f} ms, torch_cost: {torch_cost:.2f} ms")
    # Benchmark mlp_sparse
    x = torch.randn(batch_size, 1, hidden_size, device='cuda').type(dtype)
    gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False).cuda().type(dtype)
    up_proj = nn.Linear(hidden_size, intermediate_size, bias=False).cuda().type(dtype)
    down_proj = nn.Linear(intermediate_size, hidden_size, bias=False).cuda().type(dtype)
    triton_cost = benchmark(lambda: mlp_sparse(rearrange(x, "b 1 d -> b d"), weight, weight, down_weight, idx), N_REPEATS)
    torch_cost = benchmark(lambda: down_proj(torch.nn.functional.silu(gate_proj(x)) * up_proj(x)), N_REPEATS)
    print(f"mlp_sparse triton_cost: {triton_cost:.2f} ms, torch_cost: {torch_cost:.2f} ms")

if __name__ == "__main__":
    batch_size = 1
    hidden_size = 5120
    intermediate_size = 13824
    preserve_neurons = 1
    N_REPEATS = 1000
    run_benchmark(torch.float32, batch_size, hidden_size, intermediate_size, preserve_neurons, N_REPEATS)
    run_benchmark(torch.float16, batch_size, hidden_size, intermediate_size, preserve_neurons, N_REPEATS)
```
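
Side note on timing: an alternative to time.time() is CUDA events, which measure the GPU execution time directly. A minimal sketch with the same interface as benchmark() above (the helper name is my own):

```python
import torch

def benchmark_cuda_events(func, n_repeats):
    # Same interface as benchmark() above, but timed with CUDA events so the
    # measurement reflects GPU execution rather than host wall-clock time.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(n_repeats):  # warm up (also covers Triton JIT/autotune)
        func()
    torch.cuda.synchronize()
    start.record()
    for _ in range(n_repeats):
        func()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_repeats  # milliseconds per call
```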