tlogn opened this issue 1 month ago (status: Open)
Dear @tlogn,
Hello, I checked your code and found an issue in one part.
a_tile = tl.make_block_ptr(a_ptr, shape=(M, K), strides=(K, 1), offsets=(pid_m*BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile = tl.make_block_ptr(b_ptr, shape=(K, N), strides=(1, K), offsets=(0, pid_n*BLOCK_N), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
o_tile = tl.make_block_ptr(o_ptr, shape=(M, N), strides=(N, 1), offsets=(pid_m*BLOCK_M, pid_n*BLOCK_N), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
In this section, the strides are hard-coded as (K, 1), (1, K), and (N, 1), respectively. However, the strides should be obtained from the tensors and passed to the kernel as separate arguments, i.e. (stride_am, stride_ak), (stride_bk, stride_bn), and (stride_cm, stride_cn), as in the tutorial code. When I modified this part and re-ran the benchmark, the speed became almost the same as the tutorial's, as shown below. 'Triton_Block_Stride1' is your original version (with the hard-coded strides).
M N K Torch Triton Triton_Block Triton_Block_Stride1
0 256.0 256.0 256.0 4.096000 4.096000 4.096000 0.780190
1 384.0 384.0 384.0 13.824000 12.288000 12.288000 11.059200
2 512.0 512.0 512.0 26.214401 23.831273 23.831273 21.845333
3 640.0 640.0 640.0 42.666665 36.571428 36.571428 34.133334
4 768.0 768.0 768.0 52.043293 55.296000 52.043293 44.236801
5 896.0 896.0 896.0 63.860363 73.943582 70.246402 61.000942
6 1024.0 1024.0 1024.0 83.886082 99.864382 95.325090 74.731472
7 1152.0 1152.0 1152.0 85.313826 87.823057 85.313826 87.823057
8 1280.0 1280.0 1280.0 85.333330 107.789478 107.789478 93.090908
9 1408.0 1408.0 1408.0 111.260738 132.970149 132.970149 109.035523
10 1536.0 1536.0 1536.0 110.592000 115.971541 112.347429 99.688560
11 1664.0 1664.0 1664.0 124.984884 134.312118 132.336939 116.868992
12 1792.0 1792.0 1792.0 119.568337 122.167649 120.854018 118.309723
13 1920.0 1920.0 1920.0 138.196817 139.636368 139.636368 132.923078
14 2048.0 2048.0 2048.0 156.796411 158.275623 156.796411 149.796569
15 2176.0 2176.0 2176.0 138.783781 146.887946 144.774450 128.997748
16 2304.0 2304.0 2304.0 137.286620 139.695152 138.882977 131.977196
17 2432.0 2432.0 2432.0 153.521664 154.365184 154.365184 138.396372
18 2560.0 2560.0 2560.0 129.007867 148.945453 146.941707 137.680676
19 2688.0 2688.0 2688.0 142.071373 162.802816 161.417260 146.459680
20 2816.0 2816.0 2816.0 155.765024 158.022489 158.022489 150.393823
21 2944.0 2944.0 2944.0 115.628846 154.770282 153.341637 140.383190
22 3072.0 3072.0 3072.0 126.100589 167.523970 165.564625 152.212641
23 3200.0 3200.0 3200.0 136.460557 164.102564 164.948460 150.234737
24 3328.0 3328.0 3328.0 147.523150 162.142557 160.694855 148.435663
25 3456.0 3456.0 3456.0 159.331159 160.600739 160.921302 153.309371
26 3584.0 3584.0 3584.0 130.312159 159.707629 158.580935 147.161035
27 3712.0 3712.0 3712.0 138.939282 159.588392 160.091903 147.124220
28 3840.0 3840.0 3840.0 148.845220 159.815035 158.441257 147.455999
29 3968.0 3968.0 3968.0 158.872396 160.136403 160.136403 153.483193
30 4096.0 4096.0 4096.0 169.466833 170.111186 170.978001 162.897953
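Regarding the hard-coded strides: the small pure-Python sketch below (not Triton code) shows how a (shape, strides) pair, as passed to `tl.make_block_ptr`, indexes flat storage. It assumes b is a contiguous row-major (K, N) tensor, as in the tutorial's benchmark, so its real strides are (N, 1), which is what `b.stride()` reports. Reading the same buffer with the hard-coded (1, K) strides yields b.T when K == N, so the kernel would compute a @ b.T rather than a @ b. The helpers `strided_view` and `matmul` are hypothetical names used only for illustration.

```python
def strided_view(buf, shape, strides):
    """Materialize a 2-D view of a flat buffer from (shape, strides), like a block pointer."""
    rows, cols = shape
    s0, s1 = strides
    return [[buf[r * s0 + c * s1] for c in range(cols)] for r in range(rows)]


def matmul(a, b):
    """Naive reference matmul on nested lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


M = N = K = 3
# Flat, contiguous row-major storage (assumed layout of a and b).
a_buf = list(range(M * K))
b_buf = list(range(10, 10 + K * N))

a = strided_view(a_buf, (M, K), (K, 1))            # correct row-major strides for a
b_correct = strided_view(b_buf, (K, N), (N, 1))     # strides b.stride() would report
b_hardcoded = strided_view(b_buf, (K, N), (1, K))   # strides hard-coded in the kernel

b_transposed = [list(col) for col in zip(*b_correct)]
print(b_hardcoded == b_transposed)                    # (1, K) reads b.T for square b
print(matmul(a, b_correct) == matmul(a, b_hardcoded)) # so the products differ
```

A benchmark that only measures throughput will not notice this, which is why passing `b.stride(0), b.stride(1)` from the host is the safer choice.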
Dear @CODEJIN, I updated my code as you suggested, but got even worse performance. Could you give me a clearer view of your code and environment, please? Here are the revised code and results.
import torch
import triton
import triton.language as tl


# get_cuda_autotune_config() is assumed to be the helper from the official matmul
# tutorial, adapted to the BLOCK_M / BLOCK_N / BLOCK_K / GROUP_SIZE_M names.
@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def _block_matmul_kernel(a_ptr, b_ptr, o_ptr,
                         M, N, K,
                         stride_am, stride_ak,
                         stride_bk, stride_bn,
                         stride_cm, stride_cn,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    # pid_m = tl.program_id(0)
    # pid_n = tl.program_id(1)
    # Map the 1-D program id to an output tile using grouped ordering (for L2 reuse).
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    a_tile = tl.make_block_ptr(a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid_m*BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_tile = tl.make_block_ptr(b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, pid_n*BLOCK_N), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Walk along K one BLOCK_K slice at a time; out-of-range elements are zero-padded.
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a_value = tl.load(a_tile, boundary_check=(0, 1), padding_option="zero")
        b_value = tl.load(b_tile, boundary_check=(0, 1), padding_option="zero")
        accumulator += tl.dot(a_value, b_value)
        a_tile = tl.advance(a_tile, (0, BLOCK_K))
        b_tile = tl.advance(b_tile, (BLOCK_K, 0))
    o_tile = tl.make_block_ptr(o_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(pid_m*BLOCK_M, pid_n*BLOCK_N), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tl.store(o_tile, accumulator.to(o_tile.dtype.element_ty), boundary_check=(0, 1))


def matmul_block(a, b):
    M, K = a.shape
    N = b.shape[1]
    o = torch.zeros(M, N, device=a.device, dtype=a.dtype)
    # 1-D launch grid: one program per BLOCK_M x BLOCK_N output tile.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), )
    _block_matmul_kernel[grid](a, b, o,
                               M, N, K,
                               a.stride(0), a.stride(1),  #
                               b.stride(0), b.stride(1),  #
                               o.stride(0), o.stride(1),  #
                               )
    return o
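For checking correctness (as opposed to speed), the tile loop of the kernel above can be mirrored on the CPU. The pure-Python sketch below is not Triton code; it assumes that `tl.load(..., boundary_check=(0, 1), padding_option="zero")` reads out-of-range elements as zero and that a boundary-checked `tl.store` skips out-of-range elements. `load_block` and `block_matmul` are hypothetical names, and the sizes are deliberately not multiples of the block sizes so the padding path is exercised.

```python
def cdiv(x, y):
    """Ceiling division, like triton.cdiv / tl.cdiv."""
    return (x + y - 1) // y


def load_block(mat, shape, offsets, block_shape):
    """Emulate a zero-padded, boundary-checked block load."""
    rows, cols = shape
    r0, c0 = offsets
    bm, bn = block_shape
    return [[mat[r][c] if r < rows and c < cols else 0
             for c in range(c0, c0 + bn)]
            for r in range(r0, r0 + bm)]


def block_matmul(a, b, BLOCK_M=2, BLOCK_N=2, BLOCK_K=2):
    M, K, N = len(a), len(b), len(b[0])
    o = [[0] * N for _ in range(M)]
    for pid_m in range(cdiv(M, BLOCK_M)):          # one "program" per output tile
        for pid_n in range(cdiv(N, BLOCK_N)):
            acc = [[0] * BLOCK_N for _ in range(BLOCK_M)]
            for kb in range(cdiv(K, BLOCK_K)):     # the tl.advance loop over K
                a_blk = load_block(a, (M, K), (pid_m * BLOCK_M, kb * BLOCK_K), (BLOCK_M, BLOCK_K))
                b_blk = load_block(b, (K, N), (kb * BLOCK_K, pid_n * BLOCK_N), (BLOCK_K, BLOCK_N))
                for i in range(BLOCK_M):           # acc += tl.dot(a_blk, b_blk)
                    for j in range(BLOCK_N):
                        acc[i][j] += sum(a_blk[i][t] * b_blk[t][j] for t in range(BLOCK_K))
            for i in range(BLOCK_M):               # boundary-checked store
                for j in range(BLOCK_N):
                    r, c = pid_m * BLOCK_M + i, pid_n * BLOCK_N + j
                    if r < M and c < N:
                        o[r][c] = acc[i][j]
    return o


a = [[i + 2 * j for j in range(5)] for i in range(3)]   # M=3, K=5
b = [[i - j for j in range(4)] for i in range(5)]       # K=5, N=4
ref = [[sum(a[i][t] * b[t][j] for t in range(5)) for j in range(4)] for i in range(3)]
print(block_matmul(a, b) == ref)
```

Integer inputs keep the comparison exact; on the GPU the equivalent check would compare against `torch.matmul` with a tolerance, since the kernel accumulates in float32.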
triton version 2.3.1:
matmul-performance-fp16:
M N K torch Triton triton_block
0 256.0 256.0 256.0 3.771856 3.615779 1.849340
1 384.0 384.0 384.0 10.025337 11.059200 5.019779
2 512.0 512.0 512.0 23.629882 22.133530 9.927347
3 640.0 640.0 640.0 37.321185 35.121115 16.532796
4 768.0 768.0 768.0 58.254222 53.017888 24.661630
5 896.0 896.0 896.0 72.747082 75.813993 34.636130
6 1024.0 1024.0 1024.0 104.530941 95.122413 45.221606
7 1152.0 1152.0 1152.0 132.710398 127.913641 57.113858
8 1280.0 1280.0 1280.0 154.202351 162.017307 65.601599
9 1408.0 1408.0 1408.0 154.386578 130.435009 71.822489
10 1536.0 1536.0 1536.0 179.471019 156.742163 63.980909
11 1664.0 1664.0 1664.0 182.141161 179.865814 73.563727
12 1792.0 1792.0 1792.0 171.922348 213.956912 81.169391
13 1920.0 1920.0 1920.0 205.752551 172.261682 82.163447
14 2048.0 2048.0 2048.0 233.625284 199.136092 87.367114
15 2176.0 2176.0 2176.0 221.519344 222.437566 78.140610
16 2304.0 2304.0 2304.0 244.299100 246.108155 86.560062
17 2432.0 2432.0 2432.0 216.007410 214.820276 90.874642
18 2560.0 2560.0 2560.0 236.272205 233.120509 88.937745
19 2688.0 2688.0 2688.0 211.473485 212.845488 88.719322
20 2816.0 2816.0 2816.0 228.496180 228.571024 91.368556
21 2944.0 2944.0 2944.0 244.856905 246.122858 93.309519
22 3072.0 3072.0 3072.0 226.889478 227.730724 90.780797
23 3200.0 3200.0 3200.0 240.178245 244.435158 90.927254
24 3328.0 3328.0 3328.0 228.793481 230.073050 90.962703
25 3456.0 3456.0 3456.0 242.152268 241.653251 93.437041
26 3584.0 3584.0 3584.0 243.487567 230.459950 94.570013
27 3712.0 3712.0 3712.0 231.914907 243.244180 97.588762
28 3840.0 3840.0 3840.0 231.016646 232.825259 92.446485
29 3968.0 3968.0 3968.0 230.887293 241.370149 93.729380
30 4096.0 4096.0 4096.0 245.342590 237.632358 98.889472
triton version 3.0.0
matmul-performance-fp16:
M N K torch Triton triton_block
0 256.0 256.0 256.0 3.785473 3.718355 2.818753
1 384.0 384.0 384.0 10.011157 10.922667 8.466373
2 512.0 512.0 512.0 23.563505 22.310127 16.777215
3 640.0 640.0 640.0 39.290168 36.008791 29.627485
4 768.0 768.0 768.0 60.624308 55.404212 46.110018
5 896.0 896.0 896.0 73.101942 75.942051 64.967771
6 1024.0 1024.0 1024.0 103.723132 97.826335 77.492913
7 1152.0 1152.0 1152.0 129.561335 123.611238 99.740591
8 1280.0 1280.0 1280.0 158.108560 157.728042 126.030769
9 1408.0 1408.0 1408.0 154.113809 131.467095 116.537631
10 1536.0 1536.0 1536.0 175.575510 157.286398 137.936912
11 1664.0 1664.0 1664.0 178.749333 179.027156 153.417788
12 1792.0 1792.0 1792.0 171.799173 213.069643 179.920737
13 1920.0 1920.0 1920.0 207.004214 172.395948 152.751385
14 2048.0 2048.0 2048.0 231.011580 199.431987 171.524248
15 2176.0 2176.0 2176.0 221.671854 225.002361 191.425902
16 2304.0 2304.0 2304.0 243.909341 247.863785 206.040940
17 2432.0 2432.0 2432.0 218.740347 215.231714 187.374491
18 2560.0 2560.0 2560.0 242.165363 238.529575 203.173030
19 2688.0 2688.0 2688.0 212.957504 213.820285 185.832487
20 2816.0 2816.0 2816.0 228.048145 228.664652 196.820563
21 2944.0 2944.0 2944.0 245.026188 246.065880 189.625821
22 3072.0 3072.0 3072.0 227.816589 229.868609 202.927459
23 3200.0 3200.0 3200.0 242.797874 243.288183 210.126706
24 3328.0 3328.0 3328.0 226.822374 230.464339 203.194836
25 3456.0 3456.0 3456.0 238.845553 243.087739 208.013721
26 3584.0 3584.0 3584.0 243.178888 231.089281 204.957272
27 3712.0 3712.0 3712.0 234.690191 241.589704 214.537433
28 3840.0 3840.0 3840.0 232.122784 234.405964 209.603423
29 3968.0 3968.0 3968.0 231.017081 242.080959 203.055945
30 4096.0 4096.0 4096.0 240.965404 239.420671 210.661534
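The benchmark columns are not labelled with units. Assuming they come from the tutorial's `perf` formula, `2 * M * N * K * 1e-12 / (ms * 1e-3)` (i.e. TFLOPS), they can be converted back into time per call, which makes the two Triton versions easier to compare. `tflops` and `ms_from_tflops` are hypothetical helper names.

```python
def tflops(M, N, K, ms):
    """Throughput (TFLOPS) of one (M x K) @ (K x N) matmul that took `ms` milliseconds.

    Assumes the tutorial's convention of 2*M*N*K floating-point operations.
    """
    return 2 * M * N * K * 1e-12 / (ms * 1e-3)


def ms_from_tflops(M, N, K, tf):
    """Inverse of tflops(): runtime in milliseconds for a reported TFLOPS value."""
    return 2 * M * N * K * 1e-12 / (tf * 1e-3)


# triton_block at M = N = K = 4096, taken from the Triton 2.3.1 and 3.0.0 tables above
for version, tf in [("2.3.1", 98.889472), ("3.0.0", 210.661534)]:
    print(version, f"{ms_from_tflops(4096, 4096, 4096, tf):.3f} ms")
```

Under that assumption, triton_block at 4096 x 4096 x 4096 takes about 1.39 ms per call with Triton 2.3.1 versus about 0.65 ms with Triton 3.0.0.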
Hi, here is my code:
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % group_size_m) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
x_block_pointer = tl.make_block_ptr(
    base=x_pointer,
    shape=(M, K),
    strides=(stride_x_m, stride_x_k),
    offsets=(pid_m * BLOCK_SIZE_M, 0),
    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
    order=(1, 0)
)
weights_block_pointer = tl.make_block_ptr(
    base=weights_pointer,
    shape=(K, N),
    strides=(stride_weight_k, stride_weight_n),
    offsets=(0, pid_n * BLOCK_SIZE_N),
    block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
    order=(0, 1)
)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    x = tl.load(x_block_pointer, boundary_check=(1, 0), padding_option='zero')  # check 2 padding zero
    weights = tl.load(weights_block_pointer, boundary_check=(0, 1), padding_option='zero')
    accumulator += tl.dot(x, weights)
    x_block_pointer = tl.advance(x_block_pointer, (0, BLOCK_SIZE_K))
    weights_block_pointer = tl.advance(weights_block_pointer, (BLOCK_SIZE_K, 0))
y = accumulator.to(tl.float16)
y_block_pointer = tl.make_block_ptr(
    base=y_pointer,
    shape=(M, N),
    strides=(stride_y_m, stride_y_n),
    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
    order=(1, 0)
)
tl.store(y_block_pointer, y, boundary_check=(1, 0))
When I compare our code, the pid_m equation and the boundary_check order for a are different, but I am not sure these factors affect the performance.
My environment is WSL, RTX4090, torch == 2.3.1, triton == 2.3.1.
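To check whether the different pid_m equation matters, the sketch below re-implements both index mappings in plain Python: the tutorial's `(pid % num_pid_in_group) % group_size_m` and the posted `(pid % group_size_m) % group_size_m`. Under these assumptions both assign every output tile to exactly one program; they differ only in the order of tiles inside a partial last group. The sketch also counts how many A row-panels and B column-panels the first 9 programs of a 9 x 9 tile grid touch: 10 with plain row-major order versus 6 with GROUP_SIZE_M=3, which, times 9 K-blocks, matches the tutorial's 90 vs 54 block loads. The boundary_check order is probably immaterial as well, since as far as I understand the `tl.load` documentation the argument lists which dimensions to check, and (1, 0) and (0, 1) name the same two. All function names here are hypothetical.

```python
def grouped_pid(pid, num_pid_m, num_pid_n, group_m, variant="tutorial"):
    """Map a 1-D program id to (pid_m, pid_n) with grouped ordering."""
    num_pid_in_group = group_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_m
    group_size_m = min(num_pid_m - first_pid_m, group_m)
    if variant == "tutorial":
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    elif variant == "posted":
        pid_m = first_pid_m + ((pid % group_size_m) % group_size_m)
    else:
        raise ValueError(variant)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


def is_bijection(num_pid_m, num_pid_n, group_m, variant):
    """True if every output tile is assigned to exactly one program."""
    tiles = [grouped_pid(p, num_pid_m, num_pid_n, group_m, variant)
             for p in range(num_pid_m * num_pid_n)]
    return sorted(tiles) == [(m, n) for m in range(num_pid_m) for n in range(num_pid_n)]


def panels_touched(first_programs, num_pid_m, num_pid_n, group_m=None):
    """Distinct A row-panels plus B column-panels read by the first programs."""
    if group_m is None:  # plain row-major launch order
        tiles = [(p // num_pid_n, p % num_pid_n) for p in range(first_programs)]
    else:
        tiles = [grouped_pid(p, num_pid_m, num_pid_n, group_m) for p in range(first_programs)]
    return len({m for m, _ in tiles}) + len({n for _, n in tiles})


print(is_bijection(7, 5, 4, "tutorial"), is_bijection(7, 5, 4, "posted"))
print(panels_touched(9, 9, 9), panels_touched(9, 9, 9, group_m=3))
```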
Hi there!
I finally found that the performance problem was probably caused by torch.zeros. When I replace torch.zeros with torch.empty, it runs much faster. This is because torch.empty simply allocates memory without performing a memset (zero-fill) operation; skipping the fill is safe here because the kernel overwrites every element of the output with tl.store.
Nonetheless, it is still slower than the tutorial on A100, yet faster than the tutorial on H100. Maybe make_block_ptr can make good use of TMA on H100? I am uncertain whether there are any differences between our code or environments. Perhaps IR analysis would be a more efficient approach, but I am not well-versed in it. Do you have any advice?
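Replacing torch.zeros with torch.empty is only safe if every output element is stored exactly once, so that the uninitialized contents are never observed. The pure-Python sketch below models the 1-D launch grid and the boundary-checked store to confirm this for sizes that are not multiples of the block shape. It assumes out-of-bounds lanes are masked by `boundary_check` and uses plain row-major program order for brevity (grouped ordering only permutes which program handles which tile). `write_counts` is a hypothetical name.

```python
def cdiv(x, y):
    """Ceiling division, like triton.cdiv."""
    return (x + y - 1) // y


def write_counts(M, N, BLOCK_M, BLOCK_N):
    """Count how many times each output element is stored by the launch grid."""
    counts = [[0] * N for _ in range(M)]
    num_pid_n = cdiv(N, BLOCK_N)
    grid = cdiv(M, BLOCK_M) * num_pid_n          # same 1-D grid size as matmul_block
    for pid in range(grid):
        pid_m, pid_n = divmod(pid, num_pid_n)     # row-major tile order
        for i in range(BLOCK_M):
            for j in range(BLOCK_N):
                r, c = pid_m * BLOCK_M + i, pid_n * BLOCK_N + j
                if r < M and c < N:               # masked store for edge tiles
                    counts[r][c] += 1
    return counts


counts = write_counts(100, 33, 32, 16)
print(all(v == 1 for row in counts for v in row))
```

If any count were 0, the corresponding element of a torch.empty output would keep whatever garbage was in memory, so a check like this is worth doing before dropping the zero-fill.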
Hi everyone! I implemented matmul with make_block_ptr, but it performs worse than the official tutorial example. At first, I thought this must be caused by the L2 cache optimization, but after I computed pid_m and pid_n with GROUP_SIZE_M (grouped ordering), it still did not help. Could you please help analyze this? Below is my Triton code.
env: A100-80G-SXM, triton==2.3.1, torch==2.3.1
Below are the results.