NiuMa-1234 opened this issue 3 days ago
Please provide a minimal and complete reproducer that can be run without modification, i.e. it must include all relevant imports and kernel launching code.
Thank you for your reply. Here is a code snippet:
```python
import torch
import triton
import triton.language as tl
from triton import Config, cdiv
from typing import Callable
from triton.runtime import autotuner, driver


def cce_backward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]:
    config = Config(dict(BLOCK_B=128, BLOCK_V=128, BLOCK_D=32), num_warps=4, num_stages=4)
    return triton.heuristics({k: (lambda args, _v=v: _v) for k, v in config.all_kwargs().items()})


@cce_backward_autotune()
@triton.heuristics(
    {
        "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0,
        "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2,
        "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0,
        "HAS_VALIDS": lambda args: args["Valids"] is not None,
        "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None,
        "FILTER_GRAD": lambda args: args["filter_eps"] is not None,
        "HAS_TARGETS": lambda args: args["Targets"] is not None,
        "HAS_SOFTCAP": lambda args: args["softcap"] is not None,
        "ITEM_DO": lambda args: args["dOut"].numel() == 1,
        "GROUP_B": lambda args: 8,
    }
)
@triton.jit
def _cce_backward_kernel(
    E,
    C,
    LSE,
    dOut,
    grad_scale,
    Valids,
    VocabOrdering,
    softcap,
    Targets,
    dE,
    dELocks,
    dC,
    dCLocks,
    B,
    D,
    V,
    n_de_locks_0,
    n_de_locks_1,
    n_dc_locks_0,
    n_dc_locks_1,
    stride_eb,
    stride_ed,
    stride_cv,
    stride_cd,
    stride_vb,
    filter_eps,
    B_BIN,
    BLOCK_B: tl.constexpr,
    BLOCK_V: tl.constexpr,
    BLOCK_D: tl.constexpr,
    MM_BACK_BLOCK_D: tl.constexpr,
    GROUP_B: tl.constexpr,
    EVEN_D: tl.constexpr,
    MM_BACK_EVEN_D: tl.constexpr,
    ITEM_DO: tl.constexpr,
    HAS_VALIDS: tl.constexpr,
    HAS_VOCAB_ORDERING: tl.constexpr,
    FILTER_GRAD: tl.constexpr,
    HAS_TARGETS: tl.constexpr,
    HAS_SOFTCAP: tl.constexpr,
    SHIFT: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_b_chunks = tl.cdiv(B, BLOCK_B)
    num_v_chunks = tl.cdiv(V, BLOCK_V)
    num_v_in_group = GROUP_B * num_v_chunks
    group_id = pid // num_v_in_group
    first_pid_b = group_id * GROUP_B
    group_size_b = min(num_b_chunks - first_pid_b, GROUP_B)
    pid_b = first_pid_b + ((pid % num_v_in_group) % group_size_b)
    pid_v = (pid % num_v_in_group) // group_size_b
    offs_b = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B
    if HAS_VALIDS:
        offs_b = tl.load(Valids + stride_vb * offs_b)
    offs_v = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)) % V
    if HAS_VOCAB_ORDERING:
        offs_v = tl.load(VocabOrdering + offs_v)
    offs_d = tl.arange(0, BLOCK_D)
    e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed)
    c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd)
    accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32)
    for d in range(0, tl.cdiv(D, BLOCK_D)):
        if EVEN_D:
            e = tl.load(e_ptrs)
            c = tl.load(c_ptrs)
        else:
            e = tl.load(e_ptrs, mask=offs_d[None, :] < D - d * BLOCK_D, other=0.0)
            c = tl.load(c_ptrs, mask=offs_d[:, None] < D - d * BLOCK_D, other=0.0)
        accum = tl.dot(e, c, accum)
        e_ptrs += BLOCK_D * stride_ed
        c_ptrs += BLOCK_D * stride_cd
    if HAS_SOFTCAP:
        accum = tl_softcapping(accum, softcap)
    if HAS_VALIDS:
        lse = tl.load(LSE + (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B)
    else:
        lse = tl.load(LSE + offs_b)
    d_accum = tl.exp(accum - lse[:, None])
    if HAS_TARGETS:
        targets = tl.load(Targets + ((offs_b + 1) if SHIFT else offs_b))
        is_target = targets[:, None] == offs_v[None, :]
        d_accum += tl.where(is_target, -1.0, 0.0)
    else:
        is_target = None
    accum_valid_mask = ((pid_b * BLOCK_B + tl.arange(0, BLOCK_B))[:, None] < B) & (
        (pid_v * BLOCK_V + tl.arange(0, BLOCK_V))[None, :] < V
    )
    d_accum = tl.where(accum_valid_mask, d_accum, 0.0)
    if FILTER_GRAD:
        if _block_is_filtered(tl.abs(d_accum), filter_eps):
            return
    if HAS_SOFTCAP:
        d_accum = tl_softcapping_grad(d_accum, accum, softcap)
    if ITEM_DO:
        d_out = tl.load(dOut)
    else:
        d_out = tl.load(dOut + ((offs_b + 1) if SHIFT else offs_b))[:, None]
    d_out = grad_scale * d_out
    d_accum = (d_accum * d_out).to(e_ptrs.dtype.element_ty)
    b_mask = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)[:, None]) < B
    v_mask = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)[:, None]) < V
    lock_offset = (pid_b // tl.cdiv(B, BLOCK_B * n_de_locks_0)) * n_de_locks_1
    dELocks += lock_offset
    _mm_backward(
        d_accum,
        dE + (offs_b[:, None] * stride_eb),
        b_mask,
        dELocks,
        n_de_locks_1,
        C + offs_v[:, None] * stride_cv,
        v_mask,
        stride_ed,
        stride_cd,
        D,
        MM_BACK_BLOCK_D,
        MM_BACK_EVEN_D,
    )
    lock_offset = (pid_v // tl.cdiv(V, BLOCK_V * n_dc_locks_0)) * n_dc_locks_1
    dCLocks += lock_offset
    _mm_backward(
        tl.trans(d_accum),
        dC + (offs_v[:, None] * stride_cv),
        v_mask,
        dCLocks,
        n_dc_locks_1,
        E + (offs_b[:, None] * stride_eb),
        b_mask,
        stride_cd,
        stride_ed,
        D,
        MM_BACK_BLOCK_D,
        MM_BACK_EVEN_D,
    )


_cce_backward_kernel = cce_backward_autotune()(_cce_backward_kernel)
_cce_backward_kernel = triton.heuristics(...)(_cce_backward_kernel)
_cce_backward_kernel = triton.jit(_cce_backward_kernel)
```
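One note on the `cce_backward_autotune` helper above: the `_v=v` default argument binds each config value at the time the lambda is defined; a plain `lambda args: v` would make every heuristic return the loop's final value. A tiny illustration with plain ints (not the real config):

```python
# Standard Python closure behavior: default arguments are evaluated once per
# iteration, while free variables are looked up when the lambda is called.
fns_bad = [(lambda: v) for v in (1, 2, 3)]
fns_good = [(lambda _v=v: _v) for v in (1, 2, 3)]
print([f() for f in fns_bad])   # [3, 3, 3] - every lambda sees the final v
print([f() for f in fns_good])  # [1, 2, 3] - each default pins its own value
```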
You're still missing the code to launch the kernel.
Sorry about that. I have added the code to launch the kernel:
```python
import torch
import triton
import triton.language as tl
from triton import Config, cdiv
from typing import Callable
from triton.runtime import autotuner, driver


def cce_backward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]:
    config = Config(dict(BLOCK_B=128, BLOCK_V=128, BLOCK_D=32), num_warps=4, num_stages=4)
    return triton.heuristics({k: (lambda args, _v=v: _v) for k, v in config.all_kwargs().items()})


@cce_backward_autotune()
@triton.heuristics(
    {
        "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0,
        "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2,
        "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0,
        "HAS_VALIDS": lambda args: args["Valids"] is not None,
        "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None,
        "FILTER_GRAD": lambda args: args["filter_eps"] is not None,
        "HAS_TARGETS": lambda args: args["Targets"] is not None,
        "HAS_SOFTCAP": lambda args: args["softcap"] is not None,
        "ITEM_DO": lambda args: args["dOut"].numel() == 1,
        "GROUP_B": lambda args: 8,
        "BLOCK_B": lambda args: 128,
        "BLOCK_V": lambda args: 128,
        "BLOCK_D": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 4,
    }
)
@triton.jit
def _cce_backward_kernel(
    E,
    C,
    LSE,
    dOut,
    grad_scale,
    Valids,
    VocabOrdering,
    softcap,
    Targets,
    dE,
    dELocks,
    dC,
    dCLocks,
    B,
    D,
    V,
    n_de_locks_0,
    n_de_locks_1,
    n_dc_locks_0,
    n_dc_locks_1,
    stride_eb,
    stride_ed,
    stride_cv,
    stride_cd,
    stride_vb,
    filter_eps,
    B_BIN,
    BLOCK_B: tl.constexpr,
    BLOCK_V: tl.constexpr,
    BLOCK_D: tl.constexpr,
    MM_BACK_BLOCK_D: tl.constexpr,
    GROUP_B: tl.constexpr,
    EVEN_D: tl.constexpr,
    MM_BACK_EVEN_D: tl.constexpr,
    ITEM_DO: tl.constexpr,
    HAS_VALIDS: tl.constexpr,
    HAS_VOCAB_ORDERING: tl.constexpr,
    FILTER_GRAD: tl.constexpr,
    HAS_TARGETS: tl.constexpr,
    HAS_SOFTCAP: tl.constexpr,
    SHIFT: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_b_chunks = tl.cdiv(B, BLOCK_B)
    num_v_chunks = tl.cdiv(V, BLOCK_V)
    num_v_in_group = GROUP_B * num_v_chunks
    group_id = pid // num_v_in_group
    first_pid_b = group_id * GROUP_B
    group_size_b = min(num_b_chunks - first_pid_b, GROUP_B)
    pid_b = first_pid_b + ((pid % num_v_in_group) % group_size_b)
    pid_v = (pid % num_v_in_group) // group_size_b
    offs_b = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B
    if HAS_VALIDS:
        offs_b = tl.load(Valids + stride_vb * offs_b)
    offs_v = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)) % V
    if HAS_VOCAB_ORDERING:
        offs_v = tl.load(VocabOrdering + offs_v)
    offs_d = tl.arange(0, BLOCK_D)
    e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed)
    c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd)
    accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32)
    for d in range(0, tl.cdiv(D, BLOCK_D)):
        if EVEN_D:
            e = tl.load(e_ptrs)
            c = tl.load(c_ptrs)
        else:
            e = tl.load(e_ptrs, mask=offs_d[None, :] < D - d * BLOCK_D, other=0.0)
            c = tl.load(c_ptrs, mask=offs_d[:, None] < D - d * BLOCK_D, other=0.0)
        accum = tl.dot(e, c, accum)
        e_ptrs += BLOCK_D * stride_ed
        c_ptrs += BLOCK_D * stride_cd
    if HAS_SOFTCAP:
        accum = tl_softcapping(accum, softcap)
    if HAS_VALIDS:
        lse = tl.load(LSE + (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B)
    else:
        lse = tl.load(LSE + offs_b)
    d_accum = tl.exp(accum - lse[:, None])
    if HAS_TARGETS:
        targets = tl.load(Targets + ((offs_b + 1) if SHIFT else offs_b))
        is_target = targets[:, None] == offs_v[None, :]
        d_accum += tl.where(is_target, -1.0, 0.0)
    else:
        is_target = None
    accum_valid_mask = ((pid_b * BLOCK_B + tl.arange(0, BLOCK_B))[:, None] < B) & (
        (pid_v * BLOCK_V + tl.arange(0, BLOCK_V))[None, :] < V
    )
    d_accum = tl.where(accum_valid_mask, d_accum, 0.0)
    if FILTER_GRAD:
        if _block_is_filtered(tl.abs(d_accum), filter_eps):
            return
    if HAS_SOFTCAP:
        d_accum = tl_softcapping_grad(d_accum, accum, softcap)
    if ITEM_DO:
        d_out = tl.load(dOut)
    else:
        d_out = tl.load(dOut + ((offs_b + 1) if SHIFT else offs_b))[:, None]
    d_out = grad_scale * d_out
    d_accum = (d_accum * d_out).to(e_ptrs.dtype.element_ty)
    b_mask = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)[:, None]) < B
    v_mask = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)[:, None]) < V
    lock_offset = (pid_b // tl.cdiv(B, BLOCK_B * n_de_locks_0)) * n_de_locks_1
    dELocks += lock_offset
    _mm_backward(
        d_accum,
        dE + (offs_b[:, None] * stride_eb),
        b_mask,
        dELocks,
        n_de_locks_1,
        C + offs_v[:, None] * stride_cv,
        v_mask,
        stride_ed,
        stride_cd,
        D,
        MM_BACK_BLOCK_D,
        MM_BACK_EVEN_D,
    )
    lock_offset = (pid_v // tl.cdiv(V, BLOCK_V * n_dc_locks_0)) * n_dc_locks_1
    dCLocks += lock_offset
    _mm_backward(
        tl.trans(d_accum),
        dC + (offs_v[:, None] * stride_cv),
        v_mask,
        dCLocks,
        n_dc_locks_1,
        E + (offs_b[:, None] * stride_eb),
        b_mask,
        stride_cd,
        stride_ed,
        D,
        MM_BACK_BLOCK_D,
        MM_BACK_EVEN_D,
    )


_cce_backward_kernel = cce_backward_autotune()(_cce_backward_kernel)
_cce_backward_kernel = triton.heuristics(...)(_cce_backward_kernel)
_cce_backward_kernel = triton.jit(_cce_backward_kernel)


def launch_cce_backward_kernel(
    E, C, LSE, dOut, grad_scale, Valids, VocabOrdering, softcap, Targets, dE, dELocks,
    dC, dCLocks, B, D, V, n_de_locks_0, n_de_locks_1, n_dc_locks_0, n_dc_locks_1,
    stride_eb, stride_ed, stride_cv, stride_cd, filter_eps, B_BIN,
    BLOCK_B=128, BLOCK_V=128, BLOCK_D=32, SHIFT=False
):
    # Compute grid size
    num_b_chunks = (B + BLOCK_B - 1) // BLOCK_B
    num_v_chunks = (V + BLOCK_V - 1) // BLOCK_V
    GROUP_B = 8
    num_programs = num_b_chunks * num_v_chunks * GROUP_B
    # Launch the kernel
    _cce_backward_kernel[
        num_programs  # Grid size
    ](
        E, C, LSE, dOut, grad_scale, Valids, VocabOrdering, softcap, Targets, dE, dELocks,
        dC, dCLocks, B, D, V, n_de_locks_0, n_de_locks_1, n_dc_locks_0, n_dc_locks_1,
        stride_eb, stride_ed, stride_cv, stride_cd, filter_eps, B_BIN,
        BLOCK_B=BLOCK_B, BLOCK_V=BLOCK_V, BLOCK_D=BLOCK_D,
        MM_BACK_BLOCK_D=BLOCK_D * 2, GROUP_B=GROUP_B,
        EVEN_D=(D % BLOCK_D == 0), MM_BACK_EVEN_D=(D % (BLOCK_D * 2) == 0),
        ITEM_DO=(dOut.numel() == 1), HAS_VALIDS=(Valids is not None),
        HAS_VOCAB_ORDERING=(VocabOrdering is not None),
        FILTER_GRAD=(filter_eps is not None), HAS_TARGETS=(Targets is not None),
        HAS_SOFTCAP=(softcap is not None), SHIFT=SHIFT
    )


# Example inputs
B, D, V = 1024, 64, 512
E = torch.rand(B, D, device='cuda', dtype=torch.float16)
C = torch.rand(V, D, device='cuda', dtype=torch.float16)
LSE = torch.rand(B, device='cuda', dtype=torch.float16)
dOut = torch.rand(B, device='cuda', dtype=torch.float16)
grad_scale = torch.tensor(1.0, device='cuda', dtype=torch.float16)
Valids = None
VocabOrdering = None
softcap = None
Targets = None
dE = torch.zeros_like(E, device='cuda', dtype=torch.float16)
dELocks = torch.zeros((8, 8), device='cuda', dtype=torch.int32)
dC = torch.zeros_like(C, device='cuda', dtype=torch.float16)
dCLocks = torch.zeros((8, 8), device='cuda', dtype=torch.int32)
n_de_locks_0, n_de_locks_1 = 8, 8
n_dc_locks_0, n_dc_locks_1 = 8, 8
stride_eb, stride_ed = D, 1
stride_cv, stride_cd = D, 1
filter_eps = None
B_BIN = None
launch_cce_backward_kernel(
    E, C, LSE, dOut, grad_scale, Valids, VocabOrdering, softcap, Targets,
    dE, dELocks, dC, dCLocks, B, D, V, n_de_locks_0, n_de_locks_1,
    n_dc_locks_0, n_dc_locks_1, stride_eb, stride_ed, stride_cv, stride_cd,
    filter_eps, B_BIN
)
```
```python
_cce_backward_kernel = cce_backward_autotune()(_cce_backward_kernel)
_cce_backward_kernel = triton.heuristics(...)(_cce_backward_kernel)
_cce_backward_kernel = triton.jit(_cce_backward_kernel)
```
Why are you applying each decorator twice? The `@` decorators already apply `jit`, `heuristics`, then `cce_backward_autotune`, and then this block applies the same decorators again in the reverse order.
When I run this on main, `triton.jit` complains that it's not being passed a callable, which is probably why the `.src` attribute is wrong.
Also, if I remove these lines the code still fails to run with the error `TypeError: dynamic_func() missing 1 required positional argument: 'B_BIN'`, which suggests the arguments aren't being passed correctly, so there's likely something wrong with your code.
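For reference, here is a minimal sketch of the pattern I would expect, using a toy kernel of my own (`_toy_kernel` and its arguments are placeholders, not your code): each decorator is applied exactly once, via the `@` syntax with `triton.jit` innermost, and the launch grid is a tuple rather than a bare int.

```python
import torch
import triton
import triton.language as tl


# triton.heuristics wraps the already-jitted function and fills in EVEN_N
# from the runtime arguments; no explicit reapplication block is needed.
@triton.heuristics({"EVEN_N": lambda args: args["N"] % args["BLOCK"] == 0})
@triton.jit
def _toy_kernel(X, Y, N, BLOCK: tl.constexpr, EVEN_N: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    if EVEN_N:
        x = tl.load(X + offs)
    else:
        x = tl.load(X + offs, mask=offs < N, other=0.0)
    tl.store(Y + offs, x * 2, mask=offs < N)


x = torch.randn(1000, device="cuda")
y = torch.empty_like(x)
# The grid is a tuple (or a callable returning one), not a bare int.
_toy_kernel[(triton.cdiv(x.numel(), 128),)](x, y, x.numel(), BLOCK=128)
```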
Describe the bug
Hi, I encountered an error when implementing a `triton.jit` function, and it seems to be caused by multiple decorators. Below are the detailed error and the code. Could you please help me?
The error info:
The code (two parts: the `_cce_backward_kernel` kernel and one of its decorators, `cce_backward_autotune`):

`cce_backward_autotune`:
I checked the content of the `self.src` reported in the error info, and it looks like this, which only contains the decorator:
Environment details
Triton: 3.1.0
GPU: A800-SXM
PyTorch: 2.5.1
python: 3.10.9