I'm generating code for matmul (48, 5120, 13824).
I modify the 03-matrix-matmul.py in the tutorial to support split k strategy.
But I find the performance is much slower than before with the same config and split_k = 1.
I use the following config for the split k kernel:
config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "num_stages": 4, "num_warps": 4, "split_k": 1}:
, and I use this for the matmul kernel.
config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "num_stages": 4, "num_warps": 4}:
they act the same since split_k is set to 1, but the performance is 3.x slower. I'm quite confused about this.
I can see the memory pattern has changed, but why does this change happen??
original kernel
my split k kernel
Could anyone give some advice? Thank you!
This is the slow split k kernel:
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
a_scale_ptr, w_scale_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
split_k: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
# ----------------------------------------------------------
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
pid_k = tl.program_id(axis=1)
# num_blocks_k = tl.cdiv(K, BLOCK_SIZE_K * split_k)
total_block_num_k = tl.cdiv(K, BLOCK_SIZE_K)
blocks_per_split = total_block_num_k // split_k
offs_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * blocks_per_split * BLOCK_SIZE_K
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
if pid_k == split_k - 1:
remaining_k = K - (split_k - pid_k - 1) * blocks_per_split * BLOCK_SIZE_K
num_blocks_k = total_block_num_k - (split_k -1) * blocks_per_split
else:
remaining_k = (pid_k + 1) * blocks_per_split * BLOCK_SIZE_K
num_blocks_k = blocks_per_split
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
a_scale = tl.load(a_scale_ptr)
w_scale = tl.load(w_scale_ptr)
for k in range(0, num_blocks_k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < remaining_k - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < remaining_k - k * BLOCK_SIZE_K, other=0.0)
accumulator = tl.dot(a, b, accumulator)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
accumulator = accumulator * a_scale * w_scale
c = accumulator.to(tl.float16)
# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
#tl.store(c_ptrs, c, mask=c_mask)
tl.atomic_add(c_ptrs, c, mask=c_mask)
the whole file is here:
split_k.py
import torch
import json
import triton
import triton.language as tl
def cdiv(a, b):
return (a + b - 1) // b
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
a_scale_ptr, w_scale_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
split_k: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) # 8 rows
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # 4 cols
# pid = 8, we want pid_m = 0, pid_n = 1
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
# ----------------------------------------------------------
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
pid_k = tl.program_id(axis=1) # 0
# num_blocks_k = tl.cdiv(K, BLOCK_SIZE_K * split_k)
total_block_num_k = tl.cdiv(K, BLOCK_SIZE_K)
if total_block_num_k % split_k == 0:
# each split process equal amount of data
blocks_per_split = total_block_num_k // split_k
else:
blocks_per_split = total_block_num_k // split_k
offs_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * blocks_per_split * BLOCK_SIZE_K
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
if pid_k == split_k - 1:
remaining_k = K - (split_k - pid_k - 1) * blocks_per_split * BLOCK_SIZE_K
num_blocks_k = total_block_num_k - (split_k -1) * blocks_per_split
else:
remaining_k = (pid_k + 1) * blocks_per_split * BLOCK_SIZE_K
num_blocks_k = blocks_per_split
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
a_scale = tl.load(a_scale_ptr)
w_scale = tl.load(w_scale_ptr)
for k in range(0, num_blocks_k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < remaining_k - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < remaining_k - k * BLOCK_SIZE_K, other=0.0)
accumulator = tl.dot(a, b, accumulator)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
accumulator = accumulator * a_scale * w_scale
c = accumulator.to(tl.float16)
# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
#tl.store(c_ptrs, c, mask=c_mask)
tl.atomic_add(c_ptrs, c, mask=c_mask)
def benchmark(a, b, d, a_scale, w_scale, config):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
K, N = b.shape
# Allocates output.
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), META['split_k'])
matmul_kernel[grid](
a, b, d, #
a_scale, w_scale,
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
d.stride(0), d.stride(1), #
**config
)
# print(grid(config))
for i in range(5):
matmul_kernel[grid](
a, b, d, #
a_scale, w_scale,
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
d.stride(0), d.stride(1), #
**config
)
return 1
c = torch.zeros_like(d)
cnt = 10
import time
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for i in range(cnt):
matmul_kernel[grid](
a, b, c, #
a_scale, w_scale,
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
**config
)
torch.cuda.synchronize()
start = time.time()
for i in range(10):
g.replay()
torch.cuda.synchronize()
end = time.time()
torch.cuda.nvtx.range_pop()
return 1000 * (end - start)/cnt/10
def generate_config():
configs = []
block_size_m = 32 # 16 is min size
while block_size_m <= 128:
block_size_n = 32
while block_size_n < 512:
block_size_k = 32
while block_size_k < 512:
num_stages = 2
while num_stages <= 6:
if (block_size_m * block_size_k + block_size_k * block_size_n) * (num_stages) + block_size_n * block_size_m > 116224:
break
num_warps = 2
while num_warps < 32:
split_k = 1
while split_k <= 16:
config = {'BLOCK_SIZE_M': block_size_m,
'BLOCK_SIZE_N':block_size_n,
'BLOCK_SIZE_K':block_size_k,
'num_stages':num_stages,
'num_warps':num_warps,
'split_k': split_k
}
configs.append(config)
split_k *= 2
num_warps *= 2
num_stages += 1
block_size_k *= 2
block_size_n *= 2
block_size_m *= 2
return configs
def torch_fp8(a, b, a_scale, w_scale):
import time
cnt = 10
for i in range(5):
ret, _ = torch._scaled_mm(
a,
b,
scale_a = a_scale,
scale_b = w_scale,
out_dtype=torch.float16,
)
torch.cuda.synchronize()
start = time.time()
for i in range(cnt):
ret, _ = torch._scaled_mm(
a,
b,
scale_a = a_scale,
scale_b = w_scale,
out_dtype=torch.float16,
)
torch.cuda.synchronize()
end = time.time()
duration = 1000 * (end - start)/cnt
return duration, ret
def cutlass_fp8(qinput, weight, x_scale, weight_scale):
import time
cnt = 10
for i in range(5):
ret = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=torch.float16,
scale_a=x_scale,
scale_b=weight_scale)
torch.cuda.synchronize()
start = time.time()
for i in range(cnt):
ret = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=torch.float16,
scale_a=x_scale,
scale_b=weight_scale)
torch.cuda.synchronize()
end = time.time()
duration = 1000 * (end - start)/cnt
return duration, ret
import torch.nn.functional as F
def compare_result(torch_output, triton_output):
print(torch_output)
print(triton_output)
torch_output = torch_output.to(torch.float32)
triton_output = triton_output.to(torch.float32)
diff = torch.abs(triton_output - torch_output)
relative_diff = torch.abs(diff / torch_output)
idx = relative_diff.argmax()
print('diff avg max min', "%.4f"%diff.mean().item(), "%.4f"%diff.max().item(), "%.4f"%diff.min().item())
print('relative diff avg max min', "%.4f"%relative_diff.mean().item(), "%.4f"%relative_diff.max().item(), "%.4f"%relative_diff.min().item())
cos_sim = F.cosine_similarity(torch_output.reshape(-1),
triton_output.reshape(-1), dim=0)
print("cos_sim", cos_sim.item())
return cos_sim.item(), relative_diff.max().item()
def check_consistency(M, N, K):
torch.manual_seed(0)
dtype=torch.float8_e4m3fn
a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(dtype)
#a = torch.ones((M, K), device='cuda', dtype=torch.float16).to(dtype)
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T.to(dtype)
#b = torch.ones((N, K), device='cuda', dtype=torch.float16).T.to(dtype)
c = torch.zeros((M, N), device='cuda', dtype=torch.float16)
a_scale = torch.randn((), device='cuda', dtype=torch.float32)
w_scale = torch.randn((), device='cuda', dtype=torch.float32)
configs = generate_config()
configs.reverse()
print('begin to tune', M, N, K)
best_cost = 1000
best_config = None
# cublas_cost, d = torch_fp8(a, b)
# d = torch.matmul(a, b)
cublas_cost = 10
print('cublas cost', '%.4f'%cublas_cost)
for i, config in enumerate(configs):
print(config)
try:
torch.manual_seed(i)
dtype=torch.float16
dtype=torch.float8_e4m3fn
a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(dtype)
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T.to(dtype)
c = torch.zeros((M, N), device='cuda', dtype=torch.float16)
time_cost = benchmark(a, b, c, a_scale, w_scale, config)
# d = torch.matmul(a, b)
_, d = torch_fp8(a, b, a_scale, w_scale)
except Exception as ex:
print(ex)
continue
if time_cost < best_cost:
best_config = config
best_cost = time_cost
print(f"{i}/{len(configs)}", '%.4f'%time_cost, '/', '%.4f'%best_cost, config)
cos_sim, max_diff = compare_result(d, c)
if i == 30:
break
'''
'''
compare_result(d, c)
print("best config for", M, N, K, ":", best_config, '%.4f'%best_cost, "cublas", '%.4f'%cublas_cost, "speedup", '%.4f'%(cublas_cost/best_cost), cos_sim)
# print("dff for ", M, N, K, "%.6f"%cos_sim, "%.6f"%max_diff, cos_sim > 0.9999)
if cos_sim < 0.9999:
print("fuck!!!!")
return best_config, best_cost
def tune_gemm(M, N, K):
torch.manual_seed(0)
dtype=torch.float8_e4m3fn
a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(dtype)
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T.to(dtype)
a_scale = torch.randn((), device='cuda', dtype=torch.float32)
w_scale = torch.randn((), device='cuda', dtype=torch.float32)
configs = generate_config()
configs.reverse()
print('begin to tune', M, N, K)
best_cost = 1000
best_config = None
cublas_cost, d = torch_fp8(a, b, a_scale, w_scale)
print('cublas cost', '%.4f'%cublas_cost)
for i, config in enumerate(configs):
if config != {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "num_stages": 4, "num_warps": 4, "split_k": 1}:
continue
try:
c = torch.zeros((M, N), device='cuda', dtype=torch.float16)
time_cost = benchmark(a, b, c, a_scale, w_scale, config)
except Exception as ex:
# print(ex)
continue
if time_cost < best_cost:
best_config = config
best_cost = time_cost
print(f"{i}/{len(configs)}", '%.4f'%time_cost, '/', '%.4f'%best_cost, config)
c = torch.zeros((M, N), device='cuda', dtype=torch.float16)
time_cost = benchmark(a, b, c, a_scale, w_scale, config)
cos_sim, max_diff = compare_result(d, c)
print("best config for", M, N, K, ":", best_config, '%.4f'%best_cost, "cublas", '%.4f'%cublas_cost, "speedup", '%.4f'%(cublas_cost/best_cost))
if cos_sim < 0.9999:
print("fuck!!!!")
return best_config, best_cost
def tune_random():
import random
for i in range(1):
M = random.randint(512//16,10240//8) * 16
N = random.randint(512//16,10240//8) * 16
K = random.randint(512//16,10240//8) * 16
#tune_gemm(5133, 513, 511 + 32)
check_consistency(M, N, K)
exit(0)
import sys
if __name__ == '__main__':
# tune_random()
tune_gemm(48, 5120, 13824)
exit(0)
result = {}
# n_k_list = [(15360, 5120), (5120, 5120), (5120, 13824), (27648, 5120)]
n_k_list = [(5120, 13824)]
for n_k in n_k_list:
N, K = n_k
for i in range(8, 257, 8):
best_config, best_cost = tune_gemm(i, N, K)
result[i] = (best_config, best_cost)
import json
with open(f'best_config_{N}_{K}.json', 'a+') as f:
json.dump(result, f)
column.py
import torch
import triton
import triton.language as tl
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
a_scale_ptr, w_scale_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See above `L2 Cache Optimizations` section for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) # 8 rows
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # 4 cols
# pid = 8, we want pid_m = 0, pid_n = 1
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
# print(M, BLOCK_SIZE_M)
'''
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
'''
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# See above `Pointer Arithmetic` section for details
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16)
a_scale = tl.load(a_scale_ptr)
w_scale = tl.load(w_scale_ptr)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator = tl.dot(a, b, accumulator)
#accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float16)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
accumulator = accumulator * a_scale * w_scale
c = accumulator.to(tl.float16)
#c = accumulator
# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def benchmark(a, b, c, a_scale, w_scale, config):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
K, N = b.shape
# Allocates output.
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
for i in range(6):
matmul_kernel[grid](
a, b, c, #
a_scale, w_scale,
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
**config
)
return 1
cnt = 10
import time
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for i in range(cnt):
matmul_kernel[grid](
a, b, c, #
a_scale, w_scale,
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
**config
)
torch.cuda.synchronize()
start = time.time()
for i in range(10):
g.replay()
torch.cuda.synchronize()
end = time.time()
torch.cuda.nvtx.range_pop()
return 1000 * (end - start)/cnt/10
def generate_config():
configs = []
block_size_m = 32
while block_size_m < 512:
block_size_n = 32
while block_size_n < 512:
block_size_k = 32
while block_size_k < 512:
num_stages = 2
while num_stages < 7:
if (block_size_m * block_size_k + block_size_k * block_size_n) * (num_stages) + block_size_n * block_size_m > 116224:
break
num_warps = 2
while num_warps < 32:
config = {'BLOCK_SIZE_M': block_size_m,
'BLOCK_SIZE_N':block_size_n,
'BLOCK_SIZE_K':block_size_k,
'num_stages':num_stages,
'num_warps':num_warps
}
configs.append(config)
num_warps *= 2
num_stages += 1
block_size_k *= 2
block_size_n *= 2
block_size_m *= 2
return configs
import json
start_config = {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'num_stages': 3, 'num_warps': 2}
target_config = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'num_stages': 3, 'num_warps': 8}
def torch_fp8(a, b, a_scale, w_scale):
import time
cnt = 10
for i in range(5):
ret, _ = torch._scaled_mm(
a,
b,
scale_a = a_scale,
scale_b = w_scale,
out_dtype=torch.float16,
)
torch.cuda.synchronize()
start = time.time()
for i in range(cnt):
ret, _ = torch._scaled_mm(
a,
b,
out_dtype=torch.float16,
)
torch.cuda.synchronize()
end = time.time()
duration = 1000 * (end - start)/cnt
return duration, ret
import torch.nn.functional as F
# b is right output
def compare_result(torch_output, triton_output):
torch_output = torch_output.to(torch.float32)
triton_output = triton_output.to(torch.float32)
diff = torch.abs(triton_output - torch_output)
relative_diff = torch.abs(diff / torch_output)
idx = relative_diff.argmax()
print('abs diff avg max min', diff.mean().item(), diff.max().item(), diff.min().item())
print('relative diff avg max min', relative_diff.mean().item(), relative_diff.max().item(), relative_diff.min().item())
print("right:\n", torch_output)
print("right:\n", triton_output)
# print('triton', triton_output[row][col].item())
cos_sim = F.cosine_similarity(torch_output.to(torch.float32).reshape(-1),
triton_output.to(torch.float32).reshape(-1), dim=0)
print("cos_sim", cos_sim.item())
def tune_gemm(M, N, K):
torch.manual_seed(0)
dtype=torch.float16
dtype=torch.float8_e4m3fn
a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(dtype)
#a = torch.ones((M, K), device='cuda', dtype=torch.float16).to(dtype)
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T.to(dtype)
#b = torch.ones((N, K), device='cuda', dtype=torch.float16).T.to(dtype)
c = torch.randn((M, N), device='cuda', dtype=torch.float16)
a_scale = torch.ones(1, device='cuda', dtype=torch.float32)
w_scale = torch.ones(1, device='cuda', dtype=torch.float32)
configs = generate_config()
configs.reverse()
print('begin to tune', M, N, K)
best_cost = 1000
best_config = None
cublas_cost, d = torch_fp8(a, b, a_scale, w_scale)
print('cublas cost', '%.4f'%cublas_cost)
for i, config in enumerate(configs):
if config != {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "num_stages": 4, "num_warps": 4}:
continue
try:
time_cost = benchmark(a, b, c, a_scale, w_scale, config)
except Exception as ex:
continue
if time_cost < best_cost:
best_config = config
best_cost = time_cost
print(f"{i}/{len(configs)}", '%.4f'%time_cost, '/', '%.4f'%best_cost, config)
#break
time_cost = benchmark(a, b, c, a_scale, w_scale, best_config)
compare_result(d, c)
print("best config for", M, N, K, ":", best_config, '%.4f'%best_cost, "cublas", '%.4f'%cublas_cost, "speedup", '%.4f'%(cublas_cost/best_cost))
return best_config, best_cost
import sys
if __name__ == '__main__':
tune_gemm(48, 5120, 13824)
exit(0)
result = {}
n_k_list = [(15360, 5120), (5120, 5120), (5120, 13824), (27648, 5120)]
for n_k in n_k_list:
N, K = n_k
for i in range(8, 257, 8):
best_config, best_cost = tune_gemm(i, N, K)
result[i] = (best_config, best_cost)
import json
with open(f'best_config_{N}_{K}.json', 'a+') as f:
json.dump(result, f)
Here is my whole file, Could anyone give some advice?
I'm generating code for matmul (48, 5120, 13824). I modify the
03-matrix-matmul.py
in the tutorial to support split k strategy.But I find the performance is much slower than before with the same config and split_k = 1.
I use the following config for the split k kernel:
config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "num_stages": 4, "num_warps": 4, "split_k": 1}:
, and I use this for the matmul kernel.config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "num_stages": 4, "num_warps": 4}:
they act the same since split_k is set to 1, but the performance is 3.x slower. I'm quite confused about this.
I can see the memory pattern has changed, but why does this change happen??
original kernel
my split k kernel
Could anyone give some advice? Thank you!
This is the slow split k kernel:
the whole file is here: split_k.py
column.py
Here is my whole file, Could anyone give some advice?
Thank you!