triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License
12.78k stars 1.54k forks source link

[Help wanted] How to profile my slow matmul kernel that uses a split-K strategy? #4397

Closed sleepwalker2017 closed 1 month ago

sleepwalker2017 commented 1 month ago

I'm generating code for a matmul of shape (M, N, K) = (48, 5120, 13824). I modified the 03-matrix-multiplication.py tutorial to support a split-K strategy.

However, with the same config and split_k = 1, my split-K kernel is much slower than the original kernel.

I use the following config for the split-K kernel: `config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "num_stages": 4, "num_warps": 4, "split_k": 1}`, and this config for the original matmul kernel: `config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "num_stages": 4, "num_warps": 4}`.

Since split_k is 1, the two kernels should do the same work, yet the split-K kernel is more than 3x slower. I'm quite confused by this.

I can see in the profiles below that the memory-access pattern has changed, but why does this change happen?

[Image: profile of the original kernel]

[Image: profile of my split-K kernel]

Could anyone give some advice? Thank you!

This is the slow split-K kernel:

@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        a_scale_ptr, w_scale_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, 
        BLOCK_SIZE_N: tl.constexpr, 
        BLOCK_SIZE_K: tl.constexpr,
        # Number of contiguous K slices; one program per slice along grid axis 1.
        split_k: tl.constexpr,
):
    """Split-K scaled matmul: adds (A x B) * a_scale * w_scale into C.

    A has shape (M, K), B has shape (K, N) and C has shape (M, N).
    Grid axis 0 enumerates the BLOCK_SIZE_M x BLOCK_SIZE_N tiles of C and
    grid axis 1 (`pid_k`) selects one of `split_k` contiguous slices of K.
    Every program adds its partial tile into C with `tl.atomic_add` (also
    when split_k == 1), so C must be zero-filled before the launch for the
    result to equal the product.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)  # not used below
    # Column-major tile order: consecutive pids walk down M, then move along N.
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    # ----------------------------------------------------------
    # Offsets past M / N wrap around (modulo), so the A/B loads need no M/N
    # mask; out-of-range rows/cols are dropped by the masked atomic add below.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N

    pid_k = tl.program_id(axis=1) 
    # num_blocks_k = tl.cdiv(K, BLOCK_SIZE_K * split_k)

    total_block_num_k = tl.cdiv(K, BLOCK_SIZE_K)

    # Every split gets floor(total / split_k) K-blocks; the last split also
    # takes the remainder. If total_block_num_k < split_k this is 0 and the
    # last split does all of the work.
    blocks_per_split = total_block_num_k // split_k

    # Absolute K offsets of the first K-block handled by this split.
    offs_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * blocks_per_split * BLOCK_SIZE_K

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # remaining_k: exclusive end (in K elements) of this split's K range.
    if pid_k == split_k - 1:
        # Last split: (split_k - pid_k - 1) is 0 here, so this evaluates to K.
        remaining_k = K - (split_k - pid_k - 1) * blocks_per_split * BLOCK_SIZE_K
        num_blocks_k = total_block_num_k - (split_k -1) * blocks_per_split
    else:
        remaining_k = (pid_k + 1) * blocks_per_split * BLOCK_SIZE_K
        num_blocks_k = blocks_per_split

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # A single scalar is loaded from each scale pointer.
    a_scale = tl.load(a_scale_ptr)
    w_scale = tl.load(w_scale_ptr)
    for k in range(0, num_blocks_k):
        # At step k the loaded K index is offs_k + k * BLOCK_SIZE_K; the mask
        # keeps only indices < remaining_k and reads zeros beyond it.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < remaining_k - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < remaining_k - k * BLOCK_SIZE_K, other=0.0)

        accumulator = tl.dot(a, b, accumulator)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # Scaling each partial before the reduction equals scaling the total,
    # because every split applies the same two scales.
    accumulator = accumulator * a_scale * w_scale
    # Partials are rounded to fp16 before being added into C (presumably C is
    # fp16 too), so with split_k > 1 the cross-split reduction is in fp16.
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    #tl.store(c_ptrs, c, mask=c_mask)
    # Atomic add so that the split_k programs covering one tile sum their partials.
    tl.atomic_add(c_ptrs, c, mask=c_mask)

The whole file, split_k.py:

import torch
import json
import triton
import triton.language as tl

def cdiv(a, b):
    """Return ceil(a / b) for a non-negative int ``a`` and positive int ``b``."""
    padded = a + b - 1
    return padded // b

@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        a_scale_ptr, w_scale_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        split_k: tl.constexpr,
):
    """Split-K scaled matmul: adds (A x B) * a_scale * w_scale into C.

    A has shape (M, K), B has shape (K, N) and C has shape (M, N).
    Grid axis 0 enumerates the BLOCK_SIZE_M x BLOCK_SIZE_N tiles of C and
    grid axis 1 selects one of `split_k` contiguous slices of K. Each program
    adds its partial tile into C with `tl.atomic_add` (also when
    split_k == 1), so C must be zero-filled before the launch for the result
    to equal the product.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    # Column-major tile order: consecutive pids walk down M, then move along N.
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    # ----------------------------------------------------------
    # Offsets past M / N wrap around (modulo), so the A/B loads need no M/N
    # mask; out-of-range rows/cols are dropped by the masked atomic add below.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N

    pid_k = tl.program_id(axis=1)
    total_block_num_k = tl.cdiv(K, BLOCK_SIZE_K)
    # Every split gets floor(total / split_k) K-blocks; the last split also
    # takes the remainder. If total_block_num_k < split_k this is 0 and the
    # last split does all of the work.
    blocks_per_split = total_block_num_k // split_k

    # Absolute K offsets of the first K-block handled by this split.
    offs_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * blocks_per_split * BLOCK_SIZE_K

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # remaining_k: exclusive end (in K elements) of this split's K range.
    if pid_k == split_k - 1:
        # Last split: (split_k - pid_k - 1) is 0 here, so this evaluates to K.
        remaining_k = K - (split_k - pid_k - 1) * blocks_per_split * BLOCK_SIZE_K
        num_blocks_k = total_block_num_k - (split_k - 1) * blocks_per_split
    else:
        remaining_k = (pid_k + 1) * blocks_per_split * BLOCK_SIZE_K
        num_blocks_k = blocks_per_split

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # A single scalar is loaded from each scale pointer.
    a_scale = tl.load(a_scale_ptr)
    w_scale = tl.load(w_scale_ptr)
    for k in range(0, num_blocks_k):
        # At step k the loaded K index is offs_k + k * BLOCK_SIZE_K; the mask
        # keeps only indices < remaining_k and reads zeros beyond it.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < remaining_k - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < remaining_k - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # Scaling each partial before the reduction equals scaling the total,
    # because every split applies the same two scales.
    accumulator = accumulator * a_scale * w_scale
    # Partials are rounded to fp16 before being added into C, so with
    # split_k > 1 the cross-split reduction happens in fp16.
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    # Atomic add so that the split_k programs covering one tile sum their partials.
    tl.atomic_add(c_ptrs, c, mask=c_mask)

def benchmark(a, b, d, a_scale, w_scale, config, *, profile_only=True, cnt=10, replays=10):
    """Launch the split-K ``matmul_kernel`` with ``config`` and optionally time it.

    The kernel is always launched six times into ``d`` first; the first launch
    also triggers compilation. The kernel adds into its output with
    ``tl.atomic_add``, so afterwards ``d`` holds the sum of those six launches.

    Args:
        a: left operand, shape (M, K); must be contiguous.
        b: right operand, shape (K, N).
        d: output buffer of shape (M, N) that the warm-up launches add into.
        a_scale: tensor holding the scale for ``a`` (passed as a pointer).
        w_scale: tensor holding the scale for ``b`` (passed as a pointer).
        config: kernel meta-parameters; must contain ``BLOCK_SIZE_M``,
            ``BLOCK_SIZE_N`` and ``split_k``, which size the launch grid.
        profile_only: if True (the default), return right after the warm-up
            launches so that an external profiler captures only those kernels.
        cnt: number of kernel launches captured in the CUDA graph when timing.
        replays: number of times the captured graph is replayed when timing.

    Returns:
        1 when ``profile_only`` is True; otherwise the average wall-clock time
        of a single kernel launch in milliseconds.
    """
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # 2-D grid: axis 0 enumerates output tiles, axis 1 the K splits.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), META['split_k'])

    def launch(out):
        # One kernel launch that atomically adds the scaled product into ``out``.
        matmul_kernel[grid](
            a, b, out,  #
            a_scale, w_scale,
            M, N, K,  #
            a.stride(0), a.stride(1),  #
            b.stride(0), b.stride(1),  #
            out.stride(0), out.stride(1),  #
            **config
        )

    # Compile + warm up: one launch plus five more.
    for _ in range(6):
        launch(d)
    if profile_only:
        return 1

    import time
    # Separate scratch buffer so the timed launches do not accumulate into d.
    c = torch.zeros_like(d)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        for _ in range(cnt):
            launch(c)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(replays):
        g.replay()
    torch.cuda.synchronize()
    end = time.time()
    return 1000 * (end - start) / cnt / replays

def generate_config():
    """Enumerate candidate configurations for the split-K kernel.

    Sweeps BLOCK_SIZE_M in {32, 64, 128}, BLOCK_SIZE_N and BLOCK_SIZE_K in
    {32, 64, 128, 256}, 2-6 pipeline stages, 2-16 warps and split_k in
    {1, 2, 4, 8, 16}. For each block shape, the stage sweep stops at the first
    stage count where ``(M*K + K*N) * stages + N*M`` exceeds 116224
    (presumably a shared-memory budget in elements - confirm).

    Returns:
        list of config dicts, M outermost and split_k innermost.
    """
    budget = 116224
    configs = []
    for bm in (32, 64, 128):  # 16 is the minimum block size
        for bn in (32, 64, 128, 256):
            for bk in (32, 64, 128, 256):
                per_stage = bm * bk + bk * bn
                for stages in range(2, 7):
                    if per_stage * stages + bn * bm > budget:
                        break
                    configs.extend(
                        {'BLOCK_SIZE_M': bm,
                         'BLOCK_SIZE_N': bn,
                         'BLOCK_SIZE_K': bk,
                         'num_stages': stages,
                         'num_warps': warps,
                         'split_k': sk}
                        for warps in (2, 4, 8, 16)
                        for sk in (1, 2, 4, 8, 16)
                    )
    return configs

def torch_fp8(a, b, a_scale, w_scale):
    """Time ``torch._scaled_mm`` on the given FP8 operands as a reference.

    Runs 5 untimed warm-up calls, then times 10 calls using wall-clock time
    bracketed by ``torch.cuda.synchronize()``.

    NOTE(review): the call result is unpacked as a 2-tuple; this matches only
    PyTorch versions whose ``_scaled_mm`` returns (out, amax) - confirm.

    Returns:
        (average milliseconds per call, output tensor of the last call)
    """
    import time

    def run():
        out, _ = torch._scaled_mm(
            a,
            b,
            scale_a=a_scale,
            scale_b=w_scale,
            out_dtype=torch.float16,
        )
        return out

    for _ in range(5):
        ret = run()
    torch.cuda.synchronize()

    iters = 10
    t0 = time.time()
    for _ in range(iters):
        ret = run()
    torch.cuda.synchronize()
    t1 = time.time()
    return 1000 * (t1 - t0) / iters, ret

def cutlass_fp8(qinput, weight, x_scale, weight_scale):
    """Time ``ops.cutlass_scaled_mm`` as an alternative FP8 reference.

    Runs 5 untimed warm-up calls, then times 10 calls using wall-clock time
    bracketed by ``torch.cuda.synchronize()``.

    NOTE(review): ``ops`` is not imported anywhere in the visible file
    (presumably vLLM's ``_custom_ops``); as written, calling this raises
    NameError - confirm and add the import before use.

    Returns:
        (average milliseconds per call, output of the last call)
    """
    import time

    def run():
        return ops.cutlass_scaled_mm(qinput,
                                     weight,
                                     out_dtype=torch.float16,
                                     scale_a=x_scale,
                                     scale_b=weight_scale)

    for _ in range(5):
        ret = run()
    torch.cuda.synchronize()

    iters = 10
    t0 = time.time()
    for _ in range(iters):
        ret = run()
    torch.cuda.synchronize()
    t1 = time.time()
    return 1000 * (t1 - t0) / iters, ret

import torch.nn.functional as F
def compare_result(torch_output, triton_output):
    """Print and return error statistics of ``triton_output`` vs ``torch_output``.

    ``torch_output`` is the reference. Both tensors are compared in float32.
    The elementwise relative difference is ``|triton - torch| / |torch|``.
    Exact matches count as 0, which avoids a NaN from 0/0 where the reference
    is 0. A mismatch against a 0 reference is still inf.

    Returns:
        (cosine similarity of the flattened tensors, max relative difference)
    """
    print(torch_output)
    print(triton_output)
    torch_output = torch_output.to(torch.float32)
    triton_output = triton_output.to(torch.float32)
    diff = torch.abs(triton_output - torch_output)
    relative_diff = torch.where(diff == 0, torch.zeros_like(diff), torch.abs(diff / torch_output))
    print('diff avg max min', "%.4f"%diff.mean().item(), "%.4f"%diff.max().item(), "%.4f"%diff.min().item())
    print('relative diff avg max min', "%.4f"%relative_diff.mean().item(), "%.4f"%relative_diff.max().item(), "%.4f"%relative_diff.min().item())
    cos_sim = F.cosine_similarity(torch_output.reshape(-1),
                                  triton_output.reshape(-1), dim=0)
    print("cos_sim", cos_sim.item())
    return cos_sim.item(), relative_diff.max().item()

def check_consistency(M, N, K):
    """Compare split-K kernel outputs with ``torch._scaled_mm`` for (M, N, K).

    Walks the generated configs in reverse order and stops after config
    index 30. For every config that runs end to end:
    - fresh FP8 inputs are drawn, seeded by the config index;
    - the kernel runs via ``benchmark`` and the reference via ``torch_fp8``;
    - the two outputs are compared.

    ``benchmark`` launches the atomic-add kernel several times into the same
    buffer, so the kernel output is a positive multiple of the product. The
    cosine similarity used as the pass criterion is insensitive to that
    scale.

    Returns:
        (best_config, best_cost); best_config is None if no config ran.
    """
    torch.manual_seed(0)
    dtype = torch.float8_e4m3fn
    # These draws are not used, but they advance the CUDA RNG and therefore
    # fix the a_scale / w_scale values below; kept for reproducibility.
    torch.randn((M, K), device='cuda', dtype=torch.float16)
    torch.randn((N, K), device='cuda', dtype=torch.float16)
    a_scale = torch.randn((), device='cuda', dtype=torch.float32)
    w_scale = torch.randn((), device='cuda', dtype=torch.float32)

    configs = generate_config()
    configs.reverse()
    print('begin to tune', M, N, K)
    best_cost = 1000
    best_config = None

    # The cuBLAS reference is not timed here; 10 is a placeholder cost that
    # only feeds the "speedup" print.
    cublas_cost = 10
    print('cublas cost', '%.4f'%cublas_cost)

    # Kernel output / reference output of the latest fully successful config.
    c = d = None
    cos_sim = None
    for i, config in enumerate(configs):
        # Checked up front so a failing config 30 still ends the walk.
        if i > 30:
            break
        print(config)
        try:
            torch.manual_seed(i)
            a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(dtype)
            b = torch.randn((N, K), device='cuda', dtype=torch.float16).T.to(dtype)
            out = torch.zeros((M, N), device='cuda', dtype=torch.float16)
            time_cost = benchmark(a, b, out, a_scale, w_scale, config)
            _, ref = torch_fp8(a, b, a_scale, w_scale)
        except Exception as ex:
            # Configs that fail to compile or launch are reported and skipped.
            print(ex)
            continue
        # Only pair outputs produced by the same successful iteration.
        c, d = out, ref

        if time_cost < best_cost:
            best_config = config
            best_cost = time_cost
        print(f"{i}/{len(configs)}", '%.4f'%time_cost, '/', '%.4f'%best_cost, config)
        cos_sim, _ = compare_result(d, c)

    if cos_sim is None:
        print("no config succeeded for", M, N, K)
        return best_config, best_cost

    compare_result(d, c)
    print("best config for", M, N, K, ":", best_config, '%.4f'%best_cost, "cublas", '%.4f'%cublas_cost, "speedup", '%.4f'%(cublas_cost/best_cost), cos_sim)
    if cos_sim < 0.9999:
        print("fuck!!!!")
    return best_config, best_cost

def tune_gemm(M, N, K):
    """Benchmark the split-K kernel for (M, N, K) and verify it against cuBLAS.

    Only the hard-coded ``target`` config is benchmarked; every other
    generated config is skipped. The selected config is then re-run on a
    fresh zero buffer and compared with the ``torch._scaled_mm`` reference.

    Returns:
        (best_config, best_cost); best_config is None when the target config
        is not generated or fails to run.
    """
    torch.manual_seed(0)
    dtype = torch.float8_e4m3fn
    a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(dtype)
    b = torch.randn((N, K), device='cuda', dtype=torch.float16).T.to(dtype)
    a_scale = torch.randn((), device='cuda', dtype=torch.float32)
    w_scale = torch.randn((), device='cuda', dtype=torch.float32)

    configs = generate_config()
    configs.reverse()
    print('begin to tune', M, N, K)
    best_cost = 1000
    best_config = None

    cublas_cost, d = torch_fp8(a, b, a_scale, w_scale)
    print('cublas cost', '%.4f'%cublas_cost)
    target = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "num_stages": 4, "num_warps": 4, "split_k": 1}
    for i, config in enumerate(configs):
        if config != target:
            continue
        try:
            c = torch.zeros((M, N), device='cuda', dtype=torch.float16)
            time_cost = benchmark(a, b, c, a_scale, w_scale, config)
        except Exception as ex:
            # Configs that fail to compile or launch are reported and skipped.
            print(ex)
            continue

        if time_cost < best_cost:
            best_config = config
            best_cost = time_cost
        print(f"{i}/{len(configs)}", '%.4f'%time_cost, '/', '%.4f'%best_cost, config)

    if best_config is None:
        print("no config succeeded for", M, N, K)
        return best_config, best_cost

    # Verify the *selected* config (not whatever the loop variable last held)
    # on a fresh zero buffer, since the kernel accumulates with atomic adds.
    c = torch.zeros((M, N), device='cuda', dtype=torch.float16)
    benchmark(a, b, c, a_scale, w_scale, best_config)
    cos_sim, _ = compare_result(d, c)
    print("best config for", M, N, K, ":", best_config, '%.4f'%best_cost, "cublas", '%.4f'%cublas_cost, "speedup", '%.4f'%(cublas_cost/best_cost))
    if cos_sim < 0.9999:
        print("fuck!!!!")
    return best_config, best_cost

def tune_random():
    """Run ``check_consistency`` on one random (M, N, K) shape, then exit.

    Each dimension is a multiple of 16 drawn from [512, 20480].
    NOTE(review): the lower bound is 512 // 16 * 16 == 512 but the upper bound
    is 10240 // 8 * 16 == 20480; if 10240 was meant as the maximum, the
    divisor should presumably be 16 - confirm.
    """
    import random

    def draw_dim():
        return random.randint(512//16, 10240//8) * 16

    M = draw_dim()
    N = draw_dim()
    K = draw_dim()
    check_consistency(M, N, K)
    exit(0)

import sys
if __name__ == '__main__':
    # tune_random()
    tune_gemm(48, 5120, 13824)
    exit(0)
    # Unreachable while the exit(0) above is in place: sweep M = 8, 16, ..., 256
    # for every (N, K) pair and append the results to one JSON file per pair.
    # Other shapes used before: (15360, 5120), (5120, 5120), (27648, 5120).
    result = {}
    n_k_list = [(5120, 13824)]
    for N, K in n_k_list:
        for m in range(8, 257, 8):
            result[m] = tune_gemm(m, N, K)
        with open(f'best_config_{N}_{K}.json', 'a+') as f:
            json.dump(result, f)

column.py (the original kernel, without split-K):

import torch

import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        a_scale_ptr, w_scale_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
):
    """Kernel for computing the scaled matmul C = (A x B) * a_scale * w_scale.

    A has shape (M, K), B has shape (K, N) and C has shape (M, N).
    A single scalar is loaded from each of `a_scale_ptr` / `w_scale_ptr` and
    applied to the fp32 accumulator before the cast to fp16. Each program
    (1-D grid, axis 0) computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C
    over the full K range and writes it with a plain masked store.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # Tiles are enumerated column-major: consecutive pids walk down M first,
    # then move to the next N column. The tutorial's grouped ordering (for L2
    # reuse) is kept below inside a string literal and is NOT active; it would
    # also need a GROUP_SIZE_M meta-parameter, which this kernel does not take.

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) # number of tile rows (along M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # number of tile columns (along N); not used below
    # e.g. with num_pid_m == 8, pid 8 maps to pid_m = 0, pid_n = 1
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    # print(M, BLOCK_SIZE_M)
    '''
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    '''
    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # Offsets past M / N wrap around (modulo), so the loads need no M/N mask;
    # out-of-range rows/cols are dropped by the masked store at the end.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16)
    a_scale = tl.load(a_scale_ptr)
    w_scale = tl.load(w_scale_ptr)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator = tl.dot(a, b, accumulator)
        #accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float16)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # Apply both per-tensor scales once, after the full K reduction.
    accumulator = accumulator * a_scale * w_scale
    c = accumulator.to(tl.float16)
    #c = accumulator

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def benchmark(a, b, c, a_scale, w_scale, config, *, profile_only=True, cnt=10, replays=10):
    """Launch the (non-split) ``matmul_kernel`` with ``config`` and optionally time it.

    The kernel is always launched six times into ``c`` first; the first launch
    also triggers compilation.

    Args:
        a: left operand, shape (M, K); must be contiguous.
        b: right operand, shape (K, N).
        c: output buffer of shape (M, N).
        a_scale: tensor holding the scale for ``a`` (passed as a pointer).
        w_scale: tensor holding the scale for ``b`` (passed as a pointer).
        config: kernel meta-parameters; must contain ``BLOCK_SIZE_M`` and
            ``BLOCK_SIZE_N``, which size the launch grid.
        profile_only: if True (the default), return right after the warm-up
            launches so that an external profiler captures only those kernels.
        cnt: number of kernel launches captured in the CUDA graph when timing.
        replays: number of times the captured graph is replayed when timing.

    Returns:
        1 when ``profile_only`` is True; otherwise the average wall-clock time
        of a single kernel launch in milliseconds.
    """
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )

    def launch():
        # One kernel launch writing the scaled product into ``c``.
        matmul_kernel[grid](
            a, b, c,  #
            a_scale, w_scale,
            M, N, K,  #
            a.stride(0), a.stride(1),  #
            b.stride(0), b.stride(1),  #
            c.stride(0), c.stride(1),  #
            **config
        )

    # Compile + warm up.
    for _ in range(6):
        launch()
    if profile_only:
        return 1

    import time
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        for _ in range(cnt):
            launch()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(replays):
        g.replay()
    torch.cuda.synchronize()
    end = time.time()
    return 1000 * (end - start) / cnt / replays

def generate_config():
    """Enumerate candidate configurations for the non-split kernel.

    Sweeps BLOCK_SIZE_M, BLOCK_SIZE_N and BLOCK_SIZE_K in {32, 64, 128, 256},
    2-6 pipeline stages and 2-16 warps. For each block shape, the stage sweep
    stops at the first stage count where ``(M*K + K*N) * stages + N*M``
    exceeds 116224 (presumably a shared-memory budget in elements - confirm).

    Returns:
        list of config dicts, M outermost and num_warps innermost.
    """
    budget = 116224
    sizes = (32, 64, 128, 256)
    configs = []
    for bm in sizes:
        for bn in sizes:
            for bk in sizes:
                per_stage = bm * bk + bk * bn
                for stages in range(2, 7):
                    if per_stage * stages + bn * bm > budget:
                        break
                    configs += [
                        {'BLOCK_SIZE_M': bm,
                         'BLOCK_SIZE_N': bn,
                         'BLOCK_SIZE_K': bk,
                         'num_stages': stages,
                         'num_warps': warps}
                        for warps in (2, 4, 8, 16)
                    ]
    return configs

import json
# Example configs; neither is referenced elsewhere in this file.
start_config = {
    'BLOCK_SIZE_M': 256,
    'BLOCK_SIZE_N': 256,
    'BLOCK_SIZE_K': 32,
    'GROUP_SIZE_M': 4,
    'num_stages': 3,
    'num_warps': 2,
}
target_config = {
    'BLOCK_SIZE_M': 128,
    'BLOCK_SIZE_N': 256,
    'BLOCK_SIZE_K': 64,
    'GROUP_SIZE_M': 8,
    'num_stages': 3,
    'num_warps': 8,
}

def torch_fp8(a, b, a_scale, w_scale):
    """Time ``torch._scaled_mm`` on the given FP8 operands as a reference.

    Runs 5 untimed warm-up calls, then times 10 calls using wall-clock time
    bracketed by ``torch.cuda.synchronize()``. The timed calls pass the same
    ``scale_a`` / ``scale_b`` as the warm-up calls. Both the timing and the
    returned output therefore correspond to the scaled product the Triton
    kernel computes.

    NOTE(review): the call result is unpacked as a 2-tuple; this matches only
    PyTorch versions whose ``_scaled_mm`` returns (out, amax) - confirm.

    Returns:
        (average milliseconds per call, output tensor of the last call)
    """
    import time
    cnt = 10

    def run():
        out, _ = torch._scaled_mm(
            a,
            b,
            scale_a=a_scale,
            scale_b=w_scale,
            out_dtype=torch.float16,
        )
        return out

    for _ in range(5):
        ret = run()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(cnt):
        ret = run()
    torch.cuda.synchronize()
    end = time.time()
    duration = 1000 * (end - start) / cnt
    return duration, ret

import torch.nn.functional as F

def compare_result(torch_output, triton_output):
    """Print error statistics of ``triton_output`` against ``torch_output``.

    ``torch_output`` is the reference ("right") output. Both tensors are
    compared in float32. The elementwise relative difference is
    ``|triton - torch| / |torch|``. Exact matches count as 0, which avoids a
    NaN from 0/0 where the reference is 0. A mismatch against a 0 reference
    is still inf.

    Returns:
        (cosine similarity of the flattened tensors, max relative difference)
    """
    torch_output = torch_output.to(torch.float32)
    triton_output = triton_output.to(torch.float32)
    diff = torch.abs(triton_output - torch_output)
    relative_diff = torch.where(diff == 0, torch.zeros_like(diff), torch.abs(diff / torch_output))
    print('abs diff avg max min', diff.mean().item(), diff.max().item(), diff.min().item())
    print('relative diff avg max min', relative_diff.mean().item(), relative_diff.max().item(), relative_diff.min().item())
    print("right:\n", torch_output)
    print("triton:\n", triton_output)
    cos_sim = F.cosine_similarity(torch_output.reshape(-1),
                                  triton_output.reshape(-1), dim=0)
    print("cos_sim", cos_sim.item())
    return cos_sim.item(), relative_diff.max().item()

def tune_gemm(M, N, K):
    """Benchmark the (non-split) kernel for (M, N, K) and compare it with cuBLAS.

    Only the hard-coded ``target`` config is benchmarked; every other
    generated config is skipped. Both scales are 1.0 here. The selected
    config is re-run and its output compared with the ``torch._scaled_mm``
    reference.

    Returns:
        (best_config, best_cost); best_config is None when the target config
        is not generated or fails to run.
    """
    torch.manual_seed(0)
    dtype = torch.float8_e4m3fn
    a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(dtype)
    b = torch.randn((N, K), device='cuda', dtype=torch.float16).T.to(dtype)
    c = torch.randn((M, N), device='cuda', dtype=torch.float16)

    a_scale = torch.ones(1, device='cuda', dtype=torch.float32)
    w_scale = torch.ones(1, device='cuda', dtype=torch.float32)
    configs = generate_config()
    configs.reverse()
    print('begin to tune', M, N, K)
    best_cost = 1000
    best_config = None

    cublas_cost, d = torch_fp8(a, b, a_scale, w_scale)
    print('cublas cost', '%.4f'%cublas_cost)
    target = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "num_stages": 4, "num_warps": 4}
    for i, config in enumerate(configs):
        if config != target:
            continue
        try:
            time_cost = benchmark(a, b, c, a_scale, w_scale, config)
        except Exception as ex:
            # Configs that fail to compile or launch are reported and skipped.
            print(ex)
            continue

        if time_cost < best_cost:
            best_config = config
            best_cost = time_cost
        print(f"{i}/{len(configs)}", '%.4f'%time_cost, '/', '%.4f'%best_cost, config)

    if best_config is None:
        print("no config succeeded for", M, N, K)
        return best_config, best_cost

    benchmark(a, b, c, a_scale, w_scale, best_config)
    compare_result(d, c)
    print("best config for", M, N, K, ":", best_config, '%.4f'%best_cost, "cublas", '%.4f'%cublas_cost, "speedup", '%.4f'%(cublas_cost/best_cost))
    return best_config, best_cost

import sys
if __name__ == '__main__':
    tune_gemm(48, 5120, 13824)
    exit(0)
    # Unreachable while the exit(0) above is in place: sweep M = 8, 16, ..., 256
    # for every (N, K) pair and append the results to one JSON file per pair.
    result = {}
    n_k_list = [(15360, 5120), (5120, 5120), (5120, 13824), (27648, 5120)]
    for N, K in n_k_list:
        for m in range(8, 257, 8):
            result[m] = tune_gemm(m, N, K)
        with open(f'best_config_{N}_{K}.json', 'a+') as f:
            json.dump(result, f)

Those are my complete files. Could anyone give some advice?

Thank you!

sleepwalker2017 commented 1 month ago

OMG, it seems my Triton version was too old. With Triton 0.2.3 the performance is bad, but with the latest Triton it's much faster.