triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

[Performance] A matmul implemented with make_block_ptr appears to perform worse than the official tutorial implementation. #4424

Open tlogn opened 1 month ago

tlogn commented 1 month ago

Hi everyone! I implemented matmul with make_block_ptr, but it performs worse than the official tutorial example. At first I thought this must be caused by missing the L2 cache optimization, but after applying the grouped pid_m / pid_n ordering with GROUP_SIZE_M it still didn't help. Could you please help analyze this? Below is my Triton code.

env: A100-80G-SXM, triton==2.3.1, torch==2.3.1


import triton
import triton.language as tl

import torch

def get_cuda_autotune_config():
    return [
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
    ]
@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def _block_matmul_kernel(a_ptr, b_ptr, o_ptr, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    # pid_m = tl.program_id(0)
    # pid_n = tl.program_id(1)

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    a_tile = tl.make_block_ptr(a_ptr, shape=(M, K), strides=(K, 1), offsets=(pid_m*BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_tile = tl.make_block_ptr(b_ptr, shape=(K, N), strides=(1, K), offsets=(0, pid_n*BLOCK_N), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for i in range(0, tl.cdiv(K, BLOCK_K)):
        a_value = tl.load(a_tile, boundary_check=(0, 1), padding_option="zero")
        b_value = tl.load(b_tile, boundary_check=(0, 1), padding_option="zero")
        accumulator += tl.dot(a_value, b_value)
        a_tile = tl.advance(a_tile, (0, BLOCK_K))
        b_tile = tl.advance(b_tile, (BLOCK_K, 0))

    o_tile = tl.make_block_ptr(o_ptr, shape=(M, N), strides=(N, 1), offsets=(pid_m*BLOCK_M, pid_n*BLOCK_N), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tl.store(o_tile, accumulator.to(o_tile.dtype.element_ty), boundary_check=(0, 1))

def matmul_block(a, b):
    M, K = a.shape
    N = b.shape[1]
    o = torch.zeros(M, N, device=a.device, dtype=a.dtype)
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), )
    _block_matmul_kernel[grid](a, b, o, M, N, K)
    return o

@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def _matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
        GROUP_SIZE_M: tl.constexpr,  #
        ACTIVATION: tl.constexpr  #
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_M, BLOCK_K] pointers
    # `b_ptrs` is a block of [BLOCK_K, BLOCK_N] pointers
    # See above `Pointer Arithmetic` section for details
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_M, BLOCK_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator = tl.dot(a, b, accumulator)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
    _matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation  #
    )
    return c
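Note: the kernel copied from the tutorial references leaky_relu in the ACTIVATION branch, which is not included in this snippet (it is never executed here because ACTIVATION is left empty). For completeness, the tutorial defines it roughly as the following @triton.jit helper:

@triton.jit
def leaky_relu(x):
    # Element-wise leaky ReLU, applied while the accumulator is still in fp32.
    return tl.where(x >= 0, x, 0.01 * x)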

M, N, K = 513, 513, 64
a = torch.randn((M, K), device='cuda', dtype=torch.bfloat16)
b = torch.randn((N, K), device='cuda', dtype=torch.bfloat16)
b = b.transpose(0, 1)
o_triton = matmul_block(a, b)
o_triton_2 = matmul(a, b)
o_torch = a @ b

print(torch.equal(o_triton, o_torch))
print(torch.max(abs(o_torch-o_triton)))

@triton.testing.perf_report([
    triton.testing.Benchmark(
            x_names=["M", "N", "K"],  # Argument names to use as an x-axis for the plot
            x_vals=[128 * i for i in range(2, 33)],  # Different possible values for `x_name`
            line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
            # Possible values for `line_arg`
            # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
            line_vals=["torch", "triton", "triton_block"],  # Label name for the lines
            line_names=["torch", "Triton", "triton_block"],  # Line styles
            styles=[("green", "-"), ("blue", "-"), ("yellow", "-")],
            ylabel="TFLOPS",  # Label name for the y-axis
            plot_name="matmul-performance-" +
            ("fp16"),  # Name for the plot, used also as a file name for saving the plot.
            args={},
        )])
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
    if provider == 'triton_block':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_block(a, b), quantiles=quantiles)
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)

benchmark.run(print_data=True)

Below are the results.

True
tensor(0., device='cuda:0', dtype=torch.bfloat16)
matmul-performance-fp16:
         M       N       K       torch      Triton  triton_block
0    256.0   256.0   256.0    3.698681    3.799188      2.744963
1    384.0   384.0   384.0   10.347789   11.379241      8.891819
2    512.0   512.0   512.0   23.899168   21.024080     16.946683
3    640.0   640.0   640.0   40.454321   37.491990     31.326958
4    768.0   768.0   768.0   63.195428   55.567321     47.225272
5    896.0   896.0   896.0   74.804821   77.781484     67.911930
6   1024.0  1024.0  1024.0  106.184915   97.048245     83.991072
7   1152.0  1152.0  1152.0  130.534822  125.395657    108.828575
8   1280.0  1280.0  1280.0  161.617755  164.456710    137.680676
9   1408.0  1408.0  1408.0  152.563914  131.814763    111.831302
10  1536.0  1536.0  1536.0  180.904487  158.608129    136.605795
11  1664.0  1664.0  1664.0  181.452545  184.003307    162.508570
12  1792.0  1792.0  1792.0  174.085956  215.495245    189.694920
13  1920.0  1920.0  1920.0  207.004214  171.926931    156.258571
14  2048.0  2048.0  2048.0  231.310175  197.306473    178.778190
15  2176.0  2176.0  2176.0  219.780459  220.080913    198.629468
16  2304.0  2304.0  2304.0  241.902495  243.056252    217.595192
17  2432.0  2432.0  2432.0  216.684216  212.635486    194.551578
18  2560.0  2560.0  2560.0  237.664548  233.978801    212.563548
19  2688.0  2688.0  2688.0  209.321911  209.267784    192.584139
20  2816.0  2816.0  2816.0  225.907195  226.236767    208.151326
21  2944.0  2944.0  2944.0  241.464607  244.182053    225.311253
22  3072.0  3072.0  3072.0  225.758697  227.516236    209.424339
23  3200.0  3200.0  3200.0  237.586995  239.644268    222.657106
24  3328.0  3328.0  3328.0  224.578026  226.788892    211.185902
25  3456.0  3456.0  3456.0  238.779219  239.477407    223.715758
26  3584.0  3584.0  3584.0  243.322838  231.052159    215.850896
27  3712.0  3712.0  3712.0  230.261107  238.081114    226.452386
28  3840.0  3840.0  3840.0  230.798191  229.846332    217.493410
29  3968.0  3968.0  3968.0  229.516598  235.468015    226.823459
30  4096.0  4096.0  4096.0  239.514130  238.516526    223.847781
CODEJIN commented 1 month ago

Dear @tlogn ,

Hello, I checked your code and found an issue in one part.

a_tile = tl.make_block_ptr(a_ptr, shape=(M, K), strides=(K, 1), offsets=(pid_m*BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile = tl.make_block_ptr(b_ptr, shape=(K, N), strides=(1, K), offsets=(0, pid_n*BLOCK_N), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
o_tile = tl.make_block_ptr(o_ptr, shape=(M, N), strides=(N, 1), offsets=(pid_m*BLOCK_M, pid_n*BLOCK_N), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))

In this section, the strides are hard-coded as (K, 1), (1, K), and (N, 1), respectively. However, the strides should be computed from the actual tensors and passed in separately; they should be (stride_am, stride_ak), (stride_bk, stride_bn), and (stride_cm, stride_cn), as in the tutorial code. When I modified and re-checked this part, the speed became almost the same as the tutorial, as shown in the table below. 'Triton_Block_Stride1' is your original version.
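A minimal sketch of the suggested change, assuming the kernel signature is extended with the stride arguments (stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn) and the wrapper passes a.stride(0), a.stride(1), and so on, as in the tutorial:

a_tile = tl.make_block_ptr(a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                           offsets=(pid_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile = tl.make_block_ptr(b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                           offsets=(0, pid_n * BLOCK_N), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
o_tile = tl.make_block_ptr(o_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                           offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))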

         M       N       K       Torch      Triton  Triton_Block  Triton_Block_Stride1
0    256.0   256.0   256.0    4.096000    4.096000      4.096000              0.780190
1    384.0   384.0   384.0   13.824000   12.288000     12.288000             11.059200
2    512.0   512.0   512.0   26.214401   23.831273     23.831273             21.845333
3    640.0   640.0   640.0   42.666665   36.571428     36.571428             34.133334
4    768.0   768.0   768.0   52.043293   55.296000     52.043293             44.236801
5    896.0   896.0   896.0   63.860363   73.943582     70.246402             61.000942
6   1024.0  1024.0  1024.0   83.886082   99.864382     95.325090             74.731472
7   1152.0  1152.0  1152.0   85.313826   87.823057     85.313826             87.823057
8   1280.0  1280.0  1280.0   85.333330  107.789478    107.789478             93.090908
9   1408.0  1408.0  1408.0  111.260738  132.970149    132.970149            109.035523
10  1536.0  1536.0  1536.0  110.592000  115.971541    112.347429             99.688560
11  1664.0  1664.0  1664.0  124.984884  134.312118    132.336939            116.868992
12  1792.0  1792.0  1792.0  119.568337  122.167649    120.854018            118.309723
13  1920.0  1920.0  1920.0  138.196817  139.636368    139.636368            132.923078
14  2048.0  2048.0  2048.0  156.796411  158.275623    156.796411            149.796569
15  2176.0  2176.0  2176.0  138.783781  146.887946    144.774450            128.997748
16  2304.0  2304.0  2304.0  137.286620  139.695152    138.882977            131.977196
17  2432.0  2432.0  2432.0  153.521664  154.365184    154.365184            138.396372
18  2560.0  2560.0  2560.0  129.007867  148.945453    146.941707            137.680676
19  2688.0  2688.0  2688.0  142.071373  162.802816    161.417260            146.459680
20  2816.0  2816.0  2816.0  155.765024  158.022489    158.022489            150.393823
21  2944.0  2944.0  2944.0  115.628846  154.770282    153.341637            140.383190
22  3072.0  3072.0  3072.0  126.100589  167.523970    165.564625            152.212641
23  3200.0  3200.0  3200.0  136.460557  164.102564    164.948460            150.234737
24  3328.0  3328.0  3328.0  147.523150  162.142557    160.694855            148.435663
25  3456.0  3456.0  3456.0  159.331159  160.600739    160.921302            153.309371
26  3584.0  3584.0  3584.0  130.312159  159.707629    158.580935            147.161035
27  3712.0  3712.0  3712.0  138.939282  159.588392    160.091903            147.124220
28  3840.0  3840.0  3840.0  148.845220  159.815035    158.441257            147.455999
29  3968.0  3968.0  3968.0  158.872396  160.136403    160.136403            153.483193
30  4096.0  4096.0  4096.0  169.466833  170.111186    170.978001            162.897953
tlogn commented 1 month ago

Dear @CODEJIN , I updated my code as you suggested, but got even worse performance. Could you provide a clearer view of your code and environment, please? Here are the revised code and results.

@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def _block_matmul_kernel(a_ptr, b_ptr, o_ptr,
                        M, N, K, 
                        stride_am, stride_ak, 
                        stride_bk, stride_bn, 
                        stride_cm, stride_cn,
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    # pid_m = tl.program_id(0)
    # pid_n = tl.program_id(1)

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    a_tile = tl.make_block_ptr(a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid_m*BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_tile = tl.make_block_ptr(b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, pid_n*BLOCK_N), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a_value = tl.load(a_tile, boundary_check=(0, 1), padding_option="zero")
        b_value = tl.load(b_tile, boundary_check=(0, 1), padding_option="zero")
        accumulator += tl.dot(a_value, b_value)
        a_tile = tl.advance(a_tile, (0, BLOCK_K))
        b_tile = tl.advance(b_tile, (BLOCK_K, 0))

    o_tile = tl.make_block_ptr(o_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(pid_m*BLOCK_M, pid_n*BLOCK_N), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tl.store(o_tile, accumulator.to(o_tile.dtype.element_ty), boundary_check=(0, 1))

def matmul_block(a, b):
    M, K = a.shape
    N = b.shape[1]
    o = torch.zeros(M, N, device=a.device, dtype=a.dtype)
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), )
    _block_matmul_kernel[grid](a, b, o,
                               M, N, K,
                               a.stride(0), a.stride(1),
                               b.stride(0), b.stride(1),
                               o.stride(0), o.stride(1))
    return o

triton version 2.3.1:

matmul-performance-fp16:
         M       N       K       torch      Triton  triton_block
0    256.0   256.0   256.0    3.771856    3.615779      1.849340
1    384.0   384.0   384.0   10.025337   11.059200      5.019779
2    512.0   512.0   512.0   23.629882   22.133530      9.927347
3    640.0   640.0   640.0   37.321185   35.121115     16.532796
4    768.0   768.0   768.0   58.254222   53.017888     24.661630
5    896.0   896.0   896.0   72.747082   75.813993     34.636130
6   1024.0  1024.0  1024.0  104.530941   95.122413     45.221606
7   1152.0  1152.0  1152.0  132.710398  127.913641     57.113858
8   1280.0  1280.0  1280.0  154.202351  162.017307     65.601599
9   1408.0  1408.0  1408.0  154.386578  130.435009     71.822489
10  1536.0  1536.0  1536.0  179.471019  156.742163     63.980909
11  1664.0  1664.0  1664.0  182.141161  179.865814     73.563727
12  1792.0  1792.0  1792.0  171.922348  213.956912     81.169391
13  1920.0  1920.0  1920.0  205.752551  172.261682     82.163447
14  2048.0  2048.0  2048.0  233.625284  199.136092     87.367114
15  2176.0  2176.0  2176.0  221.519344  222.437566     78.140610
16  2304.0  2304.0  2304.0  244.299100  246.108155     86.560062
17  2432.0  2432.0  2432.0  216.007410  214.820276     90.874642
18  2560.0  2560.0  2560.0  236.272205  233.120509     88.937745
19  2688.0  2688.0  2688.0  211.473485  212.845488     88.719322
20  2816.0  2816.0  2816.0  228.496180  228.571024     91.368556
21  2944.0  2944.0  2944.0  244.856905  246.122858     93.309519
22  3072.0  3072.0  3072.0  226.889478  227.730724     90.780797
23  3200.0  3200.0  3200.0  240.178245  244.435158     90.927254
24  3328.0  3328.0  3328.0  228.793481  230.073050     90.962703
25  3456.0  3456.0  3456.0  242.152268  241.653251     93.437041
26  3584.0  3584.0  3584.0  243.487567  230.459950     94.570013
27  3712.0  3712.0  3712.0  231.914907  243.244180     97.588762
28  3840.0  3840.0  3840.0  231.016646  232.825259     92.446485
29  3968.0  3968.0  3968.0  230.887293  241.370149     93.729380
30  4096.0  4096.0  4096.0  245.342590  237.632358     98.889472

triton version 3.0.0:

matmul-performance-fp16:
         M       N       K       torch      Triton  triton_block
0    256.0   256.0   256.0    3.785473    3.718355      2.818753
1    384.0   384.0   384.0   10.011157   10.922667      8.466373
2    512.0   512.0   512.0   23.563505   22.310127     16.777215
3    640.0   640.0   640.0   39.290168   36.008791     29.627485
4    768.0   768.0   768.0   60.624308   55.404212     46.110018
5    896.0   896.0   896.0   73.101942   75.942051     64.967771
6   1024.0  1024.0  1024.0  103.723132   97.826335     77.492913
7   1152.0  1152.0  1152.0  129.561335  123.611238     99.740591
8   1280.0  1280.0  1280.0  158.108560  157.728042    126.030769
9   1408.0  1408.0  1408.0  154.113809  131.467095    116.537631
10  1536.0  1536.0  1536.0  175.575510  157.286398    137.936912
11  1664.0  1664.0  1664.0  178.749333  179.027156    153.417788
12  1792.0  1792.0  1792.0  171.799173  213.069643    179.920737
13  1920.0  1920.0  1920.0  207.004214  172.395948    152.751385
14  2048.0  2048.0  2048.0  231.011580  199.431987    171.524248
15  2176.0  2176.0  2176.0  221.671854  225.002361    191.425902
16  2304.0  2304.0  2304.0  243.909341  247.863785    206.040940
17  2432.0  2432.0  2432.0  218.740347  215.231714    187.374491
18  2560.0  2560.0  2560.0  242.165363  238.529575    203.173030
19  2688.0  2688.0  2688.0  212.957504  213.820285    185.832487
20  2816.0  2816.0  2816.0  228.048145  228.664652    196.820563
21  2944.0  2944.0  2944.0  245.026188  246.065880    189.625821
22  3072.0  3072.0  3072.0  227.816589  229.868609    202.927459
23  3200.0  3200.0  3200.0  242.797874  243.288183    210.126706
24  3328.0  3328.0  3328.0  226.822374  230.464339    203.194836
25  3456.0  3456.0  3456.0  238.845553  243.087739    208.013721
26  3584.0  3584.0  3584.0  243.178888  231.089281    204.957272
27  3712.0  3712.0  3712.0  234.690191  241.589704    214.537433
28  3840.0  3840.0  3840.0  232.122784  234.405964    209.603423
29  3968.0  3968.0  3968.0  231.017081  242.080959    203.055945
30  4096.0  4096.0  4096.0  240.965404  239.420671    210.661534
CODEJIN commented 1 month ago

Hi, here is my code:

        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((pid % group_size_m) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        x_block_pointer = tl.make_block_ptr(
            base= x_pointer,
            shape= (M, K),
            strides= (stride_x_m, stride_x_k),
            offsets= (pid_m * BLOCK_SIZE_M, 0),
            block_shape= (BLOCK_SIZE_M, BLOCK_SIZE_K),
            order= (1, 0)
            )
        weights_block_pointer = tl.make_block_ptr(
            base= weights_pointer,
            shape= (K, N),
            strides= (stride_weight_k, stride_weight_n),
            offsets= (0, pid_n * BLOCK_SIZE_N),
            block_shape= (BLOCK_SIZE_K, BLOCK_SIZE_N),
            order= (0, 1)
            )

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype= tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            x = tl.load(x_block_pointer, boundary_check= (1, 0), padding_option= 'zero')    # check 2 padding zero
            weights = tl.load(weights_block_pointer, boundary_check= (0, 1), padding_option= 'zero')
            accumulator += tl.dot(x, weights)
            x_block_pointer = tl.advance(x_block_pointer, (0, BLOCK_SIZE_K))
            weights_block_pointer = tl.advance(weights_block_pointer, (BLOCK_SIZE_K, 0))

        y = accumulator.to(tl.float16)
        y_block_pointer = tl.make_block_ptr(
            base= y_pointer,
            shape= (M, N),
            strides= (stride_y_m, stride_y_n),
            offsets= (pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
            block_shape= (BLOCK_SIZE_M, BLOCK_SIZE_N),
            order= (1, 0)
            )
        tl.store(y_block_pointer, y, boundary_check= (1, 0))

Comparing the two kernels, the pid_m equation and the boundary_check order for a are different, but I am not sure whether these factors affect the performance.
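For reference, the two pid_m expressions being compared:

# tutorial / tlogn's kernel
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
# my kernel as posted above
pid_m = first_pid_m + ((pid % group_size_m) % group_size_m)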

My environment is WSL, RTX4090, torch == 2.3.1, triton == 2.3.1

tlogn commented 1 month ago

Hi there! I finally found out that the performance problem was probably caused by torch.zeros. When I replaced torch.zeros with torch.empty, it ran much faster, because torch.empty simply allocates memory without performing a memset operation. Nonetheless, its performance is still slower than the tutorial on A100, yet faster than the tutorial on H100. Maybe make_block_ptr could make good use of TMA there? I am uncertain whether there are any differences between our code or environments. Perhaps IR analysis would be a more informative approach, but I am not well-versed in it. Do you have any advice to offer?
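For reference, a minimal sketch of that change in the matmul_block wrapper above; the only difference from the revised version posted earlier is the output allocation:

def matmul_block(a, b):
    M, K = a.shape
    N = b.shape[1]
    # torch.empty only allocates memory; torch.zeros additionally launches a memset,
    # which gets timed together with the kernel inside do_bench.
    o = torch.empty((M, N), device=a.device, dtype=a.dtype)
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), )
    _block_matmul_kernel[grid](a, b, o,
                               M, N, K,
                               a.stride(0), a.stride(1),
                               b.stride(0), b.stride(1),
                               o.stride(0), o.stride(1))
    return o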