mgoin opened this issue 5 months ago
Hello! Which Triton version are you using?
Can you try to build triton from source? Latest main commit should be fine.
We expect the N and K dimensions to be multiples of the block sizes. Try with N=K=8192. I'll share an example usage script soon.
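A minimal sketch of that shape check, assuming the block sizes used in the scripts later in this thread (block_n = 64, block_k = 512); the exact constraint may be looser since the kernel masks the K remainder:
n, k = 8192, 8192
block_n, block_k = 64, 512  # illustrative values taken from the scripts below, not a spec
assert n % block_n == 0 and k % block_k == 0, "N and K should be multiples of the block sizes"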
I am using torch==2.3.0 and triton==2.3.0, which is the default for the latest non-nightly torch. EDIT: I will try the nightly.
I did mean to use 8192, not 8096, so thanks for that correction. However, I still see NaNs after making that change.
I also tried M=N=K=512, which did give some real numbers in between the NaNs:
Running with M=512, N=512, K=512
y_torch: tensor([[-15.2812, -14.4297, -6.2031, ..., 4.1797, -20.0000, 8.0625],
[ 70.8125, -12.3281, -6.3125, ..., -20.4062, -11.0312, 13.5625],
[-16.0781, -17.8906, 28.0469, ..., -54.7500, 3.0391, -21.0781],
...,
[-27.1250, -7.9062, -10.9375, ..., 29.9219, 30.6250, 1.7432],
[ 15.5312, -29.6719, 15.0703, ..., -41.6875, -36.7188, 41.8125],
[ 7.0078, 21.3906, -12.7578, ..., -50.5938, 26.6094, 42.5625]],
device='cuda:0', dtype=torch.float16)
y_triton: tensor([[ -inf, -inf, -inf, ..., inf, inf, 11288.],
[ inf, inf, -inf, ..., 3748., -inf, inf],
[-32608., -47968., -40800., ..., -37184., inf, -47424.],
...,
[ 33568., -inf, -inf, ..., -inf, -54528., -45696.],
[ -inf, inf, -32992., ..., inf, inf, inf],
[ -inf, 34592., inf, ..., inf, -4096., -24720.]],
device='cuda:0', dtype=torch.float16)
y_fp16: tensor([[-15.3984, -14.9141, -5.8555, ..., 4.6094, -22.0469, 6.8867],
[ 71.1875, -12.4531, -6.8594, ..., -19.3438, -11.3984, 13.5078],
[-15.0234, -17.7656, 27.5625, ..., -54.5938, 2.2539, -22.8594],
...,
[-27.7031, -7.2266, -11.0703, ..., 30.2812, 29.4531, 1.6523],
[ 14.7812, -30.0625, 14.6328, ..., -42.6562, -35.9375, 43.0000],
[ 7.7227, 21.9375, -11.2109, ..., -50.2812, 26.8750, 43.1875]],
device='cuda:0', dtype=torch.float16)
fp16 vs torch cos_sim: tensor(0.9990, device='cuda:0', dtype=torch.float16)
fp16 vs triton cos_sim: tensor(nan, device='cuda:0', dtype=torch.float16)
Unfortunately I got the same NaN results using a fresh install of pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121, which resulted in these versions being installed:
pytorch-triton 3.0.0+45fff310c8
torch 2.4.0.dev20240502+cu121
Follow the instructions to build Triton from source here: https://github.com/openai/triton?tab=readme-ov-file#install-from-source
The results look pretty close on my end, with some margin of error expected because of the downcast.
import torch
import triton
import triton.language as tl
import time
import os
os.environ['ENABLE_TMA'] = '1'
@triton.jit
def grouped_launch(pid,
m, n,
block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):
grid_m = tl.cdiv(m, block_m)
grid_n = tl.cdiv(n, block_n)
width = group_m * grid_n
group_id = pid // width
group_size = tl.minimum(grid_m - group_id * group_m, group_m)
pid_m = group_id * group_m + (pid % group_size)
pid_n = (pid % width) // group_size
return pid_m, pid_n
@triton.jit
def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
m, n, k,
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
split_k: tl.constexpr, group_m: tl.constexpr):
pid = tl.program_id(0)
pid_k = tl.program_id(1)
grid_k = tl.cdiv(k, block_k*split_k)
pid_m, pid_n = grouped_launch(pid,
m, n,
block_m, block_n, group_m)
offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
offs_k = pid_k*block_k + tl.arange(0, block_k)
offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
acc = tl.zeros((block_m, block_n), dtype=tl.float32)
for k_ in range(0, grid_k):
k_remaining = k - k_ * (block_k * split_k)
a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
acc = tl.dot(a, b, acc, out_dtype=tl.float32)
a_ptrs += block_k * split_k * stride_ak
b_ptrs += block_k * split_k * stride_bk
acc = acc.to(tl.float16)
offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
mask = (offs_m < m)[:, None] & (offs_n < n)[None, :]
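# split-k: each program along grid axis 1 computes a partial product over its
# slice of K and atomically accumulates it into c, which is why the caller
# pre-fills c with zeros (see gemm_split_k below).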
tl.atomic_add(c_ptrs, acc, mask=mask)
def gemm_split_k(a, b):
m, k = a.shape
_, n = b.shape
block_m = 64
block_n = 64
block_k = 512
num_stages = 3
num_warps = 8
split_k = 4
group_m = 8
total_blocks_m = triton.cdiv(m, block_m)
total_blocks_n = triton.cdiv(n, block_n)
total_programs_mn = total_blocks_m * total_blocks_n
total_programs_k = split_k
grid = (total_programs_mn, total_programs_k)
c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
k = gemm_split_k_kernel[grid](a, b, c,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
m, n, k,
block_m, block_n, block_k,
split_k, group_m, num_stages=num_stages, num_warps=num_warps)
return c
if __name__ == '__main__':
torch.cuda.manual_seed(0)
m = 16
k = 8192
n = 8192
a = torch.randn((m, k), device="cuda", dtype=torch.float16)
b = torch.randn((k, n), device="cuda", dtype=torch.float16)
a = a.to(torch.float8_e4m3fn)
# pre-transpose b for efficiency.
b = b.T
b = b.to(torch.float8_e4m3fn)
triton_output = gemm_split_k(a, b)
torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16))
print(f"triton_output_with_fp8_inputs={triton_output}")
print(f"torch_output={torch_output}")
>>>
triton_output_with_fp8_inputs=tensor([[ -4.2812, 31.1094, -75.3750, ..., -32.0000, 57.0625,
-37.2812],
[ 63.8125, -54.1250, -37.6875, ..., 54.9375, -5.5312,
24.8906],
[ 10.1250, 117.5000, 9.4688, ..., -15.3906, 89.5000,
8.0781],
...,
[ 17.4688, 58.4375, -118.2500, ..., -53.0625, 143.2500,
-62.5000],
[ 61.7188, 101.3125, 54.8750, ..., 100.0000, -2.4785,
-69.1250],
[ -84.3750, -44.5312, -86.8750, ..., -57.8750, -95.7500,
71.3125]], device='cuda:0', dtype=torch.float16)
torch_output=tensor([[ -4.2656, 31.2656, -75.2500, ..., -32.0312, 57.0938,
-37.3438],
[ 63.7188, -54.1875, -37.6875, ..., 55.0625, -5.5078,
24.8750],
[ 10.1250, 117.3125, 9.5469, ..., -15.3125, 89.4375,
8.1172],
...,
[ 17.7188, 58.4688, -118.2500, ..., -53.1562, 143.2500,
-62.5000],
[ 61.7812, 101.3750, 54.9062, ..., 100.0000, -2.4609,
-69.1250],
[ -84.5000, -44.5312, -86.8125, ..., -58.0000, -95.9375,
71.2500]], device='cuda:0', dtype=torch.float16)
Using your script I get the same correct result on triton==2.3.0, so it seems building Triton from source is not necessary. Hopefully it is just an issue with input strides/transposing; I'm looking into it.
It seems that the issue was just the scaling. If I replace my scaling function to_float8(), which generates per-tensor scales (needed for torch._scaled_mm and generally for better accuracy), with a plain .to(torch.float8_e4m3fn) like in your script, then I get proper output. Thanks for the help in debugging this.
It would be great if you could add support for per-tensor scaling, as it is common and is what we are supporting in vLLM.
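For reference, a minimal sketch contrasting the two conversion paths (the scaled path mirrors the to_float8() helper in the full script below; the overflow explanation is a plausible reading of the Infs above, not something stated in the thread):

import torch

x = torch.randn((16, 8192), dtype=torch.float16, device="cuda")

# Path 1: plain downcast. Values keep their original magnitude, so the
# kernel's fp32 accumulator stays comfortably within fp16 range.
x_fp8_naive = x.to(torch.float8_e4m3fn)

# Path 2: per-tensor scaling (what to_float8() in the script below does).
# Values are stretched toward the fp8 max (~448); if the inverse scale is
# never applied, the accumulated GEMM output can exceed the fp16 max
# (65504), which would show up as the Infs seen earlier.
finfo = torch.finfo(torch.float8_e4m3fn)
scale = finfo.max / x.abs().max().clamp(min=1e-12)
x_fp8_scaled = (x * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
x_inv_scale = scale.float().reciprocal()  # must be multiplied back into the GEMM output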
For other folks who may stumble across this, here is my updated output and script.
Output:
Running with M=16, N=8192, K=8192
y_torch: tensor([[ 19.3906, -43.7188, 30.7188, ..., -63.3125, 32.9062,
16.8750],
[ -24.6406, 16.6562, 23.1875, ..., 29.7500, 33.2188,
210.7500],
[ -59.0000, -50.8438, -147.7500, ..., 95.1875, -36.2188,
-110.2500],
...,
[ 89.1875, 85.7500, -121.5625, ..., 18.7656, 5.5312,
-128.7500],
[ -38.2812, 173.8750, -144.7500, ..., -53.8750, 14.0078,
62.0000],
[ 174.0000, -58.0000, -57.7812, ..., 111.6250, -75.3750,
-135.0000]], device='cuda:0', dtype=torch.float16)
y_triton: tensor([[ 25.6875, -47.6562, 19.9844, ..., -55.6250, 28.8438,
17.3594],
[ -25.2812, 15.5938, 16.7188, ..., 26.0625, 29.9531,
209.2500],
[ -59.9062, -54.6250, -148.1250, ..., 102.0000, -32.7500,
-114.4375],
...,
[ 93.0625, 82.5000, -122.0000, ..., 24.6250, 6.4375,
-128.7500],
[ -42.8125, 168.5000, -140.8750, ..., -52.7500, 11.7188,
63.4375],
[ 181.7500, -54.2500, -52.5625, ..., 110.1875, -80.5000,
-138.6250]], device='cuda:0', dtype=torch.float16)
y_fp16: tensor([[ 24.0469, -46.9375, 27.9219, ..., -61.7188, 28.3750,
19.2344],
[ -25.3594, 14.8750, 21.5625, ..., 23.1094, 30.6406,
208.5000],
[ -60.1250, -55.7500, -146.5000, ..., 97.6250, -36.5625,
-110.6250],
...,
[ 86.7500, 81.5625, -122.1875, ..., 18.4062, 5.9336,
-129.6250],
[ -41.8750, 166.7500, -142.5000, ..., -51.7188, 13.3203,
61.6875],
[ 179.1250, -54.0938, -56.9688, ..., 112.3125, -79.8125,
-139.5000]], device='cuda:0', dtype=torch.float16)
fp16 vs torch cos_sim: tensor(0.9995, device='cuda:0', dtype=torch.float16)
fp16 vs triton cos_sim: tensor(0.9990, device='cuda:0', dtype=torch.float16)
Script:
import torch
import triton
import triton.language as tl
import time
import os
os.environ['ENABLE_TMA'] = '1'
@triton.jit
def grouped_launch(pid,
m, n,
block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):
grid_m = tl.cdiv(m, block_m)
grid_n = tl.cdiv(n, block_n)
width = group_m * grid_n
group_id = pid // width
group_size = tl.minimum(grid_m - group_id * group_m, group_m)
pid_m = group_id * group_m + (pid % group_size)
pid_n = (pid % width) // group_size
return pid_m, pid_n
@triton.jit()
def col_major(pid,
m, n,
block_m: tl.constexpr, block_n: tl.constexpr):
grid_m = tl.cdiv(m, block_m)
pid_m = pid % grid_m
pid_n = pid // grid_m
return pid_m, pid_n
@triton.jit
def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
m, n, k,
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
split_k: tl.constexpr, group_m: tl.constexpr):
pid = tl.program_id(0)
pid_k = tl.program_id(1)
grid_k = tl.cdiv(k, block_k*split_k)
pid_m, pid_n = grouped_launch(pid,
m, n,
block_m, block_n, group_m)
offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
offs_k = pid_k*block_k + tl.arange(0, block_k)
offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
acc = tl.zeros((block_m, block_n), dtype=tl.float32)
for k_ in range(0, grid_k):
k_remaining = k - k_ * (block_k * split_k)
a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
acc = tl.dot(a, b, acc, out_dtype=tl.float32)
a_ptrs += block_k * split_k * stride_ak
b_ptrs += block_k * split_k * stride_bk
acc = acc.to(tl.float16)
offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
mask = (offs_m < m)[:, None] & (offs_n < n)[None, :]
tl.atomic_add(c_ptrs, acc, mask=mask)
def gemm_split_k(a, b):
m, k = a.shape
_, n = b.shape
block_m = 64
block_n = 64
block_k = 512
num_stages = 3
num_warps = 8
split_k = 4
group_m = 8
total_blocks_m = triton.cdiv(m, block_m)
total_blocks_n = triton.cdiv(n, block_n)
total_programs_mn = total_blocks_m * total_blocks_n
total_programs_k = split_k
grid = (total_programs_mn, total_programs_k)
c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
k = gemm_split_k_kernel[grid](a, b, c,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
m, n, k,
block_m, block_n, block_k,
split_k, group_m, num_stages=num_stages, num_warps=num_warps)
return c
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
scale = finfo.max / x.abs().max().clamp(min=1e-12)
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
dtype = torch.float16
qdtype = torch.float8_e4m3fn
m = 16
n = 8192
k = 8192
print(f"Running with M={m}, N={n}, K={k}")
# create test inputs
x = torch.randn((m, k), dtype=dtype, device='cuda')
w = torch.randn((k, n), dtype=dtype, device='cuda')
x_fp8_scaled, x_inv_s = to_float8(x, dtype=qdtype)
w_fp8_scaled, w_inv_s = to_float8(w, dtype=qdtype)
x_fp8 = x.to(qdtype)
w_fp8 = w.T.to(qdtype)
y_torch, _ = torch._scaled_mm(x_fp8_scaled, w_fp8_scaled.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)
y_triton = gemm_split_k(x_fp8, w_fp8)
y_fp16 = torch.nn.functional.linear(x, w)
print("y_torch:", y_torch)
print("y_triton:", y_triton)
print("y_fp16:", y_fp16)
print("fp16 vs torch cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_torch.reshape(-1), dim=0))
print("fp16 vs triton cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_triton.reshape(-1), dim=0))
While I have you here @AdnanHoque, could you possibly share your setup for tuning the Triton kernel? The blog says:
After carefully tuning the other relevant hyperparameters for our kernel such as tile sizes, number of warps and the number of pipeline stages to Llama3-70B problem sizes we were able to produce up to 1.94x speedup over the Triton base implementation.
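The tuning setup itself isn't shown in this thread, but one simple way to sweep hyperparameters like these is to time the kernel over a grid of candidate constexpr values with triton.testing.do_bench. A rough sketch under that assumption, reusing gemm_split_k_kernel, x_fp8, and w_fp8 from the script above; the candidate values are purely illustrative, not the ones used for the blog results:

import itertools
import torch
import triton

def bench_config(a, b, block_m, block_n, block_k, split_k, group_m, num_warps, num_stages):
    m, k = a.shape
    _, n = b.shape
    c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
    grid = (triton.cdiv(m, block_m) * triton.cdiv(n, block_n), split_k)
    # do_bench handles warmup and device synchronization; note that c is not
    # re-zeroed between timed runs, so this measures speed only, not accuracy.
    return triton.testing.do_bench(lambda: gemm_split_k_kernel[grid](
        a, b, c,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        m, n, k,
        block_m, block_n, block_k,
        split_k, group_m,
        num_stages=num_stages, num_warps=num_warps))

# Illustrative search space only; a real sweep for Llama3-70B shapes would
# presumably cover a different (and larger) set of candidates.
for block_k, split_k, num_warps, num_stages in itertools.product([256, 512], [2, 4], [4, 8], [2, 3]):
    try:
        ms = bench_config(x_fp8, w_fp8, 64, 64, block_k, split_k, 8, num_warps, num_stages)
    except Exception as e:  # some combinations can exceed shared-memory limits on a given GPU
        print(f"block_k={block_k} split_k={split_k} warps={num_warps} stages={num_stages}: skipped ({e})")
        continue
    print(f"block_k={block_k} split_k={split_k} warps={num_warps} stages={num_stages}: {ms:.4f} ms")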
Hey, thanks for the suggestion! Try this script for per-tensor scale support. Thanks @cyang49 for getting to this so quickly; we'll push this into main soon.
We haven't done any performance analysis yet, but since the scaling is done in SRAM this shouldn't add too much overhead:
Code:
import torch
import triton
import triton.language as tl
import time
import os
os.environ['ENABLE_TMA'] = '1'
@triton.jit
def grouped_launch(pid,
m, n,
block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):
grid_m = tl.cdiv(m, block_m)
grid_n = tl.cdiv(n, block_n)
width = group_m * grid_n
group_id = pid // width
group_size = tl.minimum(grid_m - group_id * group_m, group_m)
pid_m = group_id * group_m + (pid % group_size)
pid_n = (pid % width) // group_size
return pid_m, pid_n
@triton.jit
def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
scale_a, scale_b,
m, n, k,
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
split_k: tl.constexpr, group_m: tl.constexpr):
pid = tl.program_id(0)
pid_k = tl.program_id(1)
grid_k = tl.cdiv(k, block_k*split_k)
pid_m, pid_n = grouped_launch(pid,
m, n,
block_m, block_n, group_m)
offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
offs_k = pid_k*block_k + tl.arange(0, block_k)
offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
acc = tl.zeros((block_m, block_n), dtype=tl.float32)
for k_ in range(0, grid_k):
k_remaining = k - k_ * (block_k * split_k)
a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
acc = tl.dot(a, b, acc, out_dtype=tl.float32)
a_ptrs += block_k * split_k * stride_ak
b_ptrs += block_k * split_k * stride_bk
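# apply the per-tensor scales to the fp32 accumulator on-chip (the "scaling
# done in SRAM" mentioned above) before the downcast to fp16.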
acc = scale_a * scale_b * acc
acc = acc.to(tl.float16)
offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
mask = (offs_m < m)[:, None] & (offs_n < n)[None, :]
tl.atomic_add(c_ptrs, acc, mask=mask)
def gemm_split_k(a, b, scale_a:float=1.0, scale_b:float=1.0):
assert a.shape[1] == b.shape[0]
m, k = a.shape
_, n = b.shape
block_m = 64
block_n = 64
block_k = 512
num_stages = 3
num_warps = 8
split_k = 4
group_m = 8
total_blocks_m = triton.cdiv(m, block_m)
total_blocks_n = triton.cdiv(n, block_n)
total_programs_mn = total_blocks_m * total_blocks_n
total_programs_k = split_k
grid = (total_programs_mn, total_programs_k)
c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
k = gemm_split_k_kernel[grid](a, b, c,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
scale_a, scale_b,
m, n, k,
block_m, block_n, block_k,
split_k, group_m, num_stages=num_stages, num_warps=num_warps)
return c
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
scale = finfo.max / x.abs().max().clamp(min=1e-12)
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
dtype = torch.float16
qdtype = torch.float8_e4m3fn
torch.cuda.manual_seed(0)
m = 64
n = 4096
k = 4096
# create test inputs
x = torch.randn((m, k), dtype=dtype, device='cuda')
w = torch.randn((n, k), dtype=dtype, device='cuda')
x_fp8, x_inv_s = to_float8(x, dtype=qdtype)
w_fp8, w_inv_s = to_float8(w, dtype=qdtype)
y_torch, _ = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)
y_triton = gemm_split_k(x_fp8, w_fp8.t(), scale_a=x_inv_s.item(), scale_b=w_inv_s.item())
y_fp16 = torch.nn.functional.linear(x, w)
print("y_torch:", y_torch)
print("y_triton:", y_triton)
print("y_fp16:", y_fp16)
print("fp16 vs torch cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_torch.reshape(-1), dim=0))
print("fp16 vs triton cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_triton.reshape(-1), dim=0))
y_torch: tensor([[ 51.3125, -48.5312, -18.3906, ..., -42.6562, -54.5938,
129.0000],
[ -34.3125, -60.8750, 25.5469, ..., 53.0312, 77.7500,
-21.8750],
[ 14.8750, 53.6875, -19.1875, ..., -3.6992, 64.8750,
102.0625],
...,
[ 163.2500, -10.9375, 33.8438, ..., -44.3438, -1.5117,
64.6250],
[ -15.1562, -2.1172, 14.7812, ..., -122.5625, -42.7500,
29.0469],
[ 1.1836, 55.1875, 68.0000, ..., -123.8125, 38.1250,
-48.8750]], device='cuda:0', dtype=torch.float16)
y_triton: tensor([[ 51.3125, -48.5000, -18.5000, ..., -42.6250, -54.5625,
128.8750],
[ -34.3125, -60.9062, 25.5625, ..., 53.0312, 77.7500,
-21.9062],
[ 14.8906, 53.6875, -19.2031, ..., -3.7031, 64.8750,
102.0625],
...,
[ 163.2500, -10.9141, 33.8750, ..., -44.3125, -1.4844,
64.6875],
[ -15.1719, -2.1543, 14.7812, ..., -122.6250, -42.7500,
29.0781],
[ 1.2188, 55.2500, 68.0000, ..., -123.7500, 38.1250,
-48.8750]], device='cuda:0', dtype=torch.float16)
y_fp16: tensor([[ 52.4688, -52.3438, -21.5625, ..., -41.5000, -54.0312,
127.6250],
[ -36.9375, -59.6250, 25.1406, ..., 54.1250, 74.6875,
-19.4062],
[ 13.5859, 50.5625, -23.1875, ..., -3.5859, 67.1250,
101.7500],
...,
[ 160.6250, -10.4297, 37.5000, ..., -41.0625, -1.8691,
69.6250],
[ -19.2344, -0.9331, 15.5234, ..., -123.1250, -40.7812,
31.0625],
[ -0.4199, 55.9062, 67.0625, ..., -123.2500, 36.1875,
-49.5625]], device='cuda:0', dtype=torch.float16)
fp16 vs torch cos_sim: tensor(0.9995, device='cuda:0', dtype=torch.float16)
fp16 vs triton cos_sim: tensor(0.9995, device='cuda:0', dtype=torch.float16)
Hi @AdnanHoque, I ran some benchmarks based on the script you posted in https://github.com/pytorch-labs/applied-ai/issues/21#issuecomment-2093716016, but I can't reproduce the numbers mentioned in the blog https://pytorch.org/blog/accelerating-llama3/. The Triton kernel is slower than torch._scaled_mm. I ran the benchmark on an H800 machine. For m=1, n=k=8192, the Triton kernel takes 0.00017833 s, while torch._scaled_mm only takes 3.337860107421875e-05 s. Below is my benchmark script. Is there anything wrong with it?
import torch
import triton
import triton.language as tl
import time
import os
from time import time
os.environ['ENABLE_TMA'] = '1'
@triton.jit
def grouped_launch(pid,
m, n,
block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):
grid_m = tl.cdiv(m, block_m)
grid_n = tl.cdiv(n, block_n)
width = group_m * grid_n
group_id = pid // width
group_size = tl.minimum(grid_m - group_id * group_m, group_m)
pid_m = group_id * group_m + (pid % group_size)
pid_n = (pid % width) // group_size
return pid_m, pid_n
@triton.jit
def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
scale_a, scale_b,
m, n, k,
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,
split_k: tl.constexpr, group_m: tl.constexpr):
pid = tl.program_id(0)
pid_k = tl.program_id(1)
grid_k = tl.cdiv(k, block_k*split_k)
pid_m, pid_n = grouped_launch(pid,
m, n,
block_m, block_n, group_m)
offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
offs_k = pid_k*block_k + tl.arange(0, block_k)
offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
acc = tl.zeros((block_m, block_n), dtype=tl.float32)
for k_ in range(0, grid_k):
k_remaining = k - k_ * (block_k * split_k)
a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
acc = tl.dot(a, b, acc, out_dtype=tl.float32)
a_ptrs += block_k * split_k * stride_ak
b_ptrs += block_k * split_k * stride_bk
acc = scale_a * scale_b * acc
acc = acc.to(tl.float16)
offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
mask = (offs_m < m)[:, None] & (offs_n < n)[None, :]
tl.atomic_add(c_ptrs, acc, mask=mask)
def gemm_split_k(a, b, scale_a:float=1.0, scale_b:float=1.0):
assert a.shape[1] == b.shape[0]
m, k = a.shape
_, n = b.shape
block_m = 64
block_n = 64
block_k = 512
num_stages = 3
num_warps = 8
split_k = 4
group_m = 8
total_blocks_m = triton.cdiv(m, block_m)
total_blocks_n = triton.cdiv(n, block_n)
total_programs_mn = total_blocks_m * total_blocks_n
total_programs_k = split_k
grid = (total_programs_mn, total_programs_k)
c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
k = gemm_split_k_kernel[grid](a, b, c,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
scale_a, scale_b,
m, n, k,
block_m, block_n, block_k,
split_k, group_m, num_stages=num_stages, num_warps=num_warps)
return c
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
scale = finfo.max / x.abs().max().clamp(min=1e-12)
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
dtype = torch.float16
qdtype = torch.float8_e4m3fn
torch.cuda.manual_seed(0)
m = 1
n = 8192
k = 8192
# create test inputs
x = torch.randn((m, k), dtype=dtype, device='cuda')
w = torch.randn((n, k), dtype=dtype, device='cuda')
x_fp8, x_inv_s = to_float8(x, dtype=qdtype)
w_fp8, w_inv_s = to_float8(w, dtype=qdtype)
for _ in range(10):
y_triton = gemm_split_k(x_fp8, w_fp8.t(), scale_a=x_inv_s.item(), scale_b=w_inv_s.item())
y_torch, _ = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)
torch_start_time = time()
y_torch, _ = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)
torch_end_time = time()
print(f"torch duration: {torch_end_time - torch_start_time}")
triton_start_time = time()
y_triton = gemm_split_k(x_fp8, w_fp8.t(), scale_a=x_inv_s.item(), scale_b=w_inv_s.item())
triton_end_time = time()
print(f"triton duration: {triton_end_time - triton_start_time}")
y_fp16 = torch.nn.functional.linear(x, w)
#print("y_torch:", y_torch)
#print("y_triton:", y_triton)
#print("y_fp16:", y_fp16)
#print("fp16 vs torch cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_torch.reshape(-1), dim=0))
#print("fp16 vs triton cos_sim:", torch.nn.functional.cosine_similarity(y_fp16.reshape(-1), y_triton.reshape(-1), dim=0))
Hello @AdnanHoque, I am trying to recreate the results from the blog Accelerating Llama3 FP8 Inference with Triton Kernels. I haven't been able to get the splitk_gemm_fp8.py kernels to work properly, as they seem to produce NaNs and Infs. I am using an H100 80GB (Driver 535.129.03, CUDA 12.2) with PyTorch 2.3.0. Do you have an example of the benchmark or accuracy eval used?
Here is the example and output demonstrating my issue.
Output:
Script used: