triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

AttributeError when implementing a triton.jit func with multiple decorators #5224

Open NiuMa-1234 opened 3 days ago

NiuMa-1234 commented 3 days ago

Describe the bug

Hi, I encountered an error when implementing a triton.jit function, and it seems to be caused by stacking multiple decorators. Below are the detailed error and the code. Could you please help me?

The error info:

    def _cce_backward_kernel(
  File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 885, in jit
    return decorator(fn)
  File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 874, in decorator
    return JITFunction(
  File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 720, in __init__
    self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]
AttributeError: 'NoneType' object has no attribute 'start'

The code (two parts: the _cce_backward_kernel kernel and one of its decorators, cce_backward_autotune):

@cce_backward_autotune()
@triton.heuristics(
    {
        "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0,
        "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2,
        "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0,
        "HAS_VALIDS": lambda args: args["Valids"] is not None,
        "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None,
        "FILTER_GRAD": lambda args: args["filter_eps"] is not None,
        "HAS_TARGETS": lambda args: args["Targets"] is not None,
        "HAS_SOFTCAP": lambda args: args["softcap"] is not None,
        "ITEM_DO": lambda args: args["dOut"].numel() == 1,
        "GROUP_B": lambda args: 8,
    }
)
@triton.jit
def _cce_backward_kernel(
    E,
    C,
    LSE,
    dOut,
    grad_scale,
    Valids,
    VocabOrdering,
    softcap,
    Targets,
    dE,
    dELocks,
    dC,
    dCLocks,
    B,
    D,
    V,
    n_de_locks_0,
    n_de_locks_1,
    n_dc_locks_0,
    n_dc_locks_1,
    stride_eb,
    stride_ed,
    stride_cv,
    stride_cd,
    stride_vb,
    filter_eps,
    B_BIN,
    BLOCK_B: tl.constexpr,
    BLOCK_V: tl.constexpr,
    BLOCK_D: tl.constexpr,
    MM_BACK_BLOCK_D: tl.constexpr,
    GROUP_B: tl.constexpr,
    EVEN_D: tl.constexpr,
    MM_BACK_EVEN_D: tl.constexpr,
    ITEM_DO: tl.constexpr,
    HAS_VALIDS: tl.constexpr,
    HAS_VOCAB_ORDERING: tl.constexpr,
    FILTER_GRAD: tl.constexpr,
    HAS_TARGETS: tl.constexpr,
    HAS_SOFTCAP: tl.constexpr,
    SHIFT: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_b_chunks = tl.cdiv(B, BLOCK_B)
    num_v_chunks = tl.cdiv(V, BLOCK_V)
    num_v_in_group = GROUP_B * num_v_chunks
    group_id = pid // num_v_in_group
    first_pid_b = group_id * GROUP_B
    group_size_b = min(num_b_chunks - first_pid_b, GROUP_B)
    pid_b = first_pid_b + ((pid % num_v_in_group) % group_size_b)
    pid_v = (pid % num_v_in_group) // group_size_b

    offs_b = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B
    if HAS_VALIDS:
        offs_b = tl.load(Valids + stride_vb * offs_b)

    offs_v = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)) % V
    if HAS_VOCAB_ORDERING:
        offs_v = tl.load(VocabOrdering + offs_v)

    offs_d = tl.arange(0, BLOCK_D)
    e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed)
    c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd)

    accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32)
    for d in range(0, tl.cdiv(D, BLOCK_D)):
        if EVEN_D:
            e = tl.load(e_ptrs)
            c = tl.load(c_ptrs)
        else:
            e = tl.load(e_ptrs, mask=offs_d[None, :] < D - d * BLOCK_D, other=0.0)
            c = tl.load(c_ptrs, mask=offs_d[:, None] < D - d * BLOCK_D, other=0.0)

        accum = tl.dot(e, c, accum)

        e_ptrs += BLOCK_D * stride_ed
        c_ptrs += BLOCK_D * stride_cd

    if HAS_SOFTCAP:
        accum = tl_softcapping(accum, softcap)

    if HAS_VALIDS:
        lse = tl.load(LSE + (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B)
    else:
        lse = tl.load(LSE + offs_b)

    d_accum = tl.exp(accum - lse[:, None])

    if HAS_TARGETS:
        targets = tl.load(Targets + ((offs_b + 1) if SHIFT else offs_b))
        is_target = targets[:, None] == offs_v[None, :]
        d_accum += tl.where(is_target, -1.0, 0.0)
    else:
        is_target = None

    accum_valid_mask = ((pid_b * BLOCK_B + tl.arange(0, BLOCK_B))[:, None] < B) & (
        (pid_v * BLOCK_V + tl.arange(0, BLOCK_V))[None, :] < V
    )
    d_accum = tl.where(accum_valid_mask, d_accum, 0.0)

    if FILTER_GRAD:
        if _block_is_filtered(tl.abs(d_accum), filter_eps):
            return

    if HAS_SOFTCAP:
        d_accum = tl_softcapping_grad(d_accum, accum, softcap)

    if ITEM_DO:
        d_out = tl.load(dOut)
    else:
        d_out = tl.load(dOut + ((offs_b + 1) if SHIFT else offs_b))[:, None]

    d_out = grad_scale * d_out

    d_accum = (d_accum * d_out).to(e_ptrs.dtype.element_ty)

    b_mask = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)[:, None]) < B
    v_mask = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)[:, None]) < V

    lock_offset = (pid_b // tl.cdiv(B, BLOCK_B * n_de_locks_0)) * n_de_locks_1
    dELocks += lock_offset

    _mm_backward(
        d_accum,
        dE + (offs_b[:, None] * stride_eb),
        b_mask,
        dELocks,
        n_de_locks_1,
        C + offs_v[:, None] * stride_cv,
        v_mask,
        stride_ed,
        stride_cd,
        D,
        MM_BACK_BLOCK_D,
        MM_BACK_EVEN_D,
    )

    lock_offset = (pid_v // tl.cdiv(V, BLOCK_V * n_dc_locks_0)) * n_dc_locks_1
    dCLocks += lock_offset

    _mm_backward(
        tl.trans(d_accum),
        dC + (offs_v[:, None] * stride_cv),
        v_mask,
        dCLocks,
        n_dc_locks_1,
        E + (offs_b[:, None] * stride_eb),
        b_mask,
        stride_cd,
        stride_ed,
        D,
        MM_BACK_BLOCK_D,
        MM_BACK_EVEN_D,
    )

cce_backward_autotune:

def cce_backward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]:
    config = Config(dict(BLOCK_B=128, BLOCK_V=128, BLOCK_D=32), num_warps=4, num_stages=4)
    return triton.heuristics({k: (lambda args, _v=v: _v) for k, v in config.all_kwargs().items()})

I checked the content of the self.src reported in the error info, and it contains only the decorators:

@cce_backward_autotune()
@triton.heuristics(
    {
        "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0,
        "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2,
        "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0,
        "HAS_VALIDS": lambda args: args["Valids"] is not None,
        "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None,
        "FILTER_GRAD": lambda args: args["filter_eps"] is not None,
        "HAS_TARGETS": lambda args: args["Targets"] is not None,
        "HAS_SOFTCAP": lambda args: args["softcap"] is not None,
        "ITEM_DO": lambda args: args["dOut"].numel() == 1,
        "GROUP_B": lambda args: 8,
    }
)
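
For illustration, this small snippet (my own, not from Triton) applies the same regex from the traceback to a source string that contains only decorator lines: re.search returns None, so calling .start() on it raises the AttributeError above.

import re

# Pattern copied from triton/runtime/jit.py line 720 in the traceback above.
pattern = r"^def\s+\w+\s*\("

decorators_only = "@cce_backward_autotune()\n@triton.heuristics({...})\n"
with_def = decorators_only + "def _cce_backward_kernel(E, C, LSE):\n    pass\n"

print(re.search(pattern, decorators_only, re.MULTILINE))  # None -> .start() raises AttributeError
print(re.search(pattern, with_def, re.MULTILINE))         # <re.Match ...>, so .start() would succeed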

Environment details

Triton: 3.1.0

GPU: A800-SXM

PyTorch: 2.5.1

python: 3.10.9

peterbell10 commented 3 days ago

Please provide a minimal and complete reproducer that can be run without modification, i.e. it must include all relevant imports and kernel launching code.

NiuMa-1234 commented 22 hours ago

Thank you for your reply; here is a code snippet:

import torch
import triton
import triton.language as tl
from triton import Config, cdiv
from typing import Callable
from triton.runtime import autotuner, driver

def cce_backward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]:
    config=Config(dict(BLOCK_B=128, BLOCK_V=128, BLOCK_D=32), num_warps=4, num_stages=4)
    return triton.heuristics({k: (lambda args, _v=v: _v) for k, v in config.all_kwargs().items()})

@cce_backward_autotune()
@triton.heuristics(
    {
        "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0,
        "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2,
        "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0,
        "HAS_VALIDS": lambda args: args["Valids"] is not None,
        "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None,
        "FILTER_GRAD": lambda args: args["filter_eps"] is not None,
        "HAS_TARGETS": lambda args: args["Targets"] is not None,
        "HAS_SOFTCAP": lambda args: args["softcap"] is not None,
        "ITEM_DO": lambda args: args["dOut"].numel() == 1,
        "GROUP_B": lambda args: 8,
    }
)
@triton.jit
def _cce_backward_kernel(
    E,
    C,
    LSE,
    dOut,
    grad_scale,
    Valids,
    VocabOrdering,
    softcap,
    Targets,
    dE,
    dELocks,
    dC,
    dCLocks,
    B,
    D,
    V,
    n_de_locks_0,
    n_de_locks_1,
    n_dc_locks_0,
    n_dc_locks_1,
    stride_eb,
    stride_ed,
    stride_cv,
    stride_cd,
    stride_vb,
    filter_eps,
    B_BIN,
    BLOCK_B: tl.constexpr,
    BLOCK_V: tl.constexpr,
    BLOCK_D: tl.constexpr,
    MM_BACK_BLOCK_D: tl.constexpr,
    GROUP_B: tl.constexpr,
    EVEN_D: tl.constexpr,
    MM_BACK_EVEN_D: tl.constexpr,
    ITEM_DO: tl.constexpr,
    HAS_VALIDS: tl.constexpr,
    HAS_VOCAB_ORDERING: tl.constexpr,
    FILTER_GRAD: tl.constexpr,
    HAS_TARGETS: tl.constexpr,
    HAS_SOFTCAP: tl.constexpr,
    SHIFT: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_b_chunks = tl.cdiv(B, BLOCK_B)
    num_v_chunks = tl.cdiv(V, BLOCK_V)
    num_v_in_group = GROUP_B * num_v_chunks
    group_id = pid // num_v_in_group
    first_pid_b = group_id * GROUP_B
    group_size_b = min(num_b_chunks - first_pid_b, GROUP_B)
    pid_b = first_pid_b + ((pid % num_v_in_group) % group_size_b)
    pid_v = (pid % num_v_in_group) // group_size_b

    offs_b = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B
    if HAS_VALIDS:
        offs_b = tl.load(Valids + stride_vb * offs_b)

    offs_v = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)) % V
    if HAS_VOCAB_ORDERING:
        offs_v = tl.load(VocabOrdering + offs_v)

    offs_d = tl.arange(0, BLOCK_D)
    e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed)
    c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd)

    accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32)
    for d in range(0, tl.cdiv(D, BLOCK_D)):
        if EVEN_D:
            e = tl.load(e_ptrs)
            c = tl.load(c_ptrs)
        else:
            e = tl.load(e_ptrs, mask=offs_d[None, :] < D - d * BLOCK_D, other=0.0)
            c = tl.load(c_ptrs, mask=offs_d[:, None] < D - d * BLOCK_D, other=0.0)

        accum = tl.dot(e, c, accum)

        e_ptrs += BLOCK_D * stride_ed
        c_ptrs += BLOCK_D * stride_cd

    if HAS_SOFTCAP:
        accum = tl_softcapping(accum, softcap)

    if HAS_VALIDS:
        lse = tl.load(LSE + (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B)
    else:
        lse = tl.load(LSE + offs_b)

    d_accum = tl.exp(accum - lse[:, None])

    if HAS_TARGETS:
        targets = tl.load(Targets + ((offs_b + 1) if SHIFT else offs_b))
        is_target = targets[:, None] == offs_v[None, :]
        d_accum += tl.where(is_target, -1.0, 0.0)
    else:
        is_target = None

    accum_valid_mask = ((pid_b * BLOCK_B + tl.arange(0, BLOCK_B))[:, None] < B) & (
        (pid_v * BLOCK_V + tl.arange(0, BLOCK_V))[None, :] < V
    )
    d_accum = tl.where(accum_valid_mask, d_accum, 0.0)

    if FILTER_GRAD:
        if _block_is_filtered(tl.abs(d_accum), filter_eps):
            return

    if HAS_SOFTCAP:
        d_accum = tl_softcapping_grad(d_accum, accum, softcap)

    if ITEM_DO:
        d_out = tl.load(dOut)
    else:
        d_out = tl.load(dOut + ((offs_b + 1) if SHIFT else offs_b))[:, None]

    d_out = grad_scale * d_out

    d_accum = (d_accum * d_out).to(e_ptrs.dtype.element_ty)

    b_mask = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)[:, None]) < B
    v_mask = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)[:, None]) < V

    lock_offset = (pid_b // tl.cdiv(B, BLOCK_B * n_de_locks_0)) * n_de_locks_1
    dELocks += lock_offset

    _mm_backward(
        d_accum,
        dE + (offs_b[:, None] * stride_eb),
        b_mask,
        dELocks,
        n_de_locks_1,
        C + offs_v[:, None] * stride_cv,
        v_mask,
        stride_ed,
        stride_cd,
        D,
        MM_BACK_BLOCK_D,
        MM_BACK_EVEN_D,
    )

    lock_offset = (pid_v // tl.cdiv(V, BLOCK_V * n_dc_locks_0)) * n_dc_locks_1
    dCLocks += lock_offset

    _mm_backward(
        tl.trans(d_accum),
        dC + (offs_v[:, None] * stride_cv),
        v_mask,
        dCLocks,
        n_dc_locks_1,
        E + (offs_b[:, None] * stride_eb),
        b_mask,
        stride_cd,
        stride_ed,
        D,
        MM_BACK_BLOCK_D,
        MM_BACK_EVEN_D,
    )
_cce_backward_kernel = cce_backward_autotune()(_cce_backward_kernel)
_cce_backward_kernel = triton.heuristics(...)(_cce_backward_kernel)
_cce_backward_kernel = triton.jit(_cce_backward_kernel)
peterbell10 commented 19 hours ago

You're still missing the code to launch the kernel.

NiuMa-1234 commented 18 hours ago

You're still missing the code to launch the kernel.

Sorry, I have supplemented the code to launch the kernel.

import torch
import triton
import triton.language as tl
from triton import Config, cdiv
from typing import Callable
from triton.runtime import autotuner, driver

def cce_backward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]:
    config=Config(dict(BLOCK_B=128, BLOCK_V=128, BLOCK_D=32), num_warps=4, num_stages=4)
    return triton.heuristics({k: (lambda args, _v=v: _v) for k, v in config.all_kwargs().items()})

@cce_backward_autotune()
@triton.heuristics(
    {
        "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0,
        "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2,
        "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0,
        "HAS_VALIDS": lambda args: args["Valids"] is not None,
        "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None,
        "FILTER_GRAD": lambda args: args["filter_eps"] is not None,
        "HAS_TARGETS": lambda args: args["Targets"] is not None,
        "HAS_SOFTCAP": lambda args: args["softcap"] is not None,
        "ITEM_DO": lambda args: args["dOut"].numel() == 1,
        "GROUP_B": lambda args: 8,

        "BLOCK_B": lambda args: 128,
        "BLOCK_V": lambda args: 128,
        "BLOCK_D": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 4,
    }
)
@triton.jit
def _cce_backward_kernel(
    E,
    C,
    LSE,
    dOut,
    grad_scale,
    Valids,
    VocabOrdering,
    softcap,
    Targets,
    dE,
    dELocks,
    dC,
    dCLocks,
    B,
    D,
    V,
    n_de_locks_0,
    n_de_locks_1,
    n_dc_locks_0,
    n_dc_locks_1,
    stride_eb,
    stride_ed,
    stride_cv,
    stride_cd,
    stride_vb,
    filter_eps,
    B_BIN,
    BLOCK_B: tl.constexpr,
    BLOCK_V: tl.constexpr,
    BLOCK_D: tl.constexpr,
    MM_BACK_BLOCK_D: tl.constexpr,
    GROUP_B: tl.constexpr,
    EVEN_D: tl.constexpr,
    MM_BACK_EVEN_D: tl.constexpr,
    ITEM_DO: tl.constexpr,
    HAS_VALIDS: tl.constexpr,
    HAS_VOCAB_ORDERING: tl.constexpr,
    FILTER_GRAD: tl.constexpr,
    HAS_TARGETS: tl.constexpr,
    HAS_SOFTCAP: tl.constexpr,
    SHIFT: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_b_chunks = tl.cdiv(B, BLOCK_B)
    num_v_chunks = tl.cdiv(V, BLOCK_V)
    num_v_in_group = GROUP_B * num_v_chunks
    group_id = pid // num_v_in_group
    first_pid_b = group_id * GROUP_B
    group_size_b = min(num_b_chunks - first_pid_b, GROUP_B)
    pid_b = first_pid_b + ((pid % num_v_in_group) % group_size_b)
    pid_v = (pid % num_v_in_group) // group_size_b

    offs_b = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B
    if HAS_VALIDS:
        offs_b = tl.load(Valids + stride_vb * offs_b)

    offs_v = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)) % V
    if HAS_VOCAB_ORDERING:
        offs_v = tl.load(VocabOrdering + offs_v)

    offs_d = tl.arange(0, BLOCK_D)
    e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed)
    c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd)

    accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32)
    for d in range(0, tl.cdiv(D, BLOCK_D)):
        if EVEN_D:
            e = tl.load(e_ptrs)
            c = tl.load(c_ptrs)
        else:
            e = tl.load(e_ptrs, mask=offs_d[None, :] < D - d * BLOCK_D, other=0.0)
            c = tl.load(c_ptrs, mask=offs_d[:, None] < D - d * BLOCK_D, other=0.0)

        accum = tl.dot(e, c, accum)

        e_ptrs += BLOCK_D * stride_ed
        c_ptrs += BLOCK_D * stride_cd

    if HAS_SOFTCAP:
        accum = tl_softcapping(accum, softcap)

    if HAS_VALIDS:
        lse = tl.load(LSE + (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B)
    else:
        lse = tl.load(LSE + offs_b)

    d_accum = tl.exp(accum - lse[:, None])

    if HAS_TARGETS:
        targets = tl.load(Targets + ((offs_b + 1) if SHIFT else offs_b))
        is_target = targets[:, None] == offs_v[None, :]
        d_accum += tl.where(is_target, -1.0, 0.0)
    else:
        is_target = None

    accum_valid_mask = ((pid_b * BLOCK_B + tl.arange(0, BLOCK_B))[:, None] < B) & (
        (pid_v * BLOCK_V + tl.arange(0, BLOCK_V))[None, :] < V
    )
    d_accum = tl.where(accum_valid_mask, d_accum, 0.0)

    if FILTER_GRAD:
        if _block_is_filtered(tl.abs(d_accum), filter_eps):
            return

    if HAS_SOFTCAP:
        d_accum = tl_softcapping_grad(d_accum, accum, softcap)

    if ITEM_DO:
        d_out = tl.load(dOut)
    else:
        d_out = tl.load(dOut + ((offs_b + 1) if SHIFT else offs_b))[:, None]

    d_out = grad_scale * d_out

    d_accum = (d_accum * d_out).to(e_ptrs.dtype.element_ty)

    b_mask = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)[:, None]) < B
    v_mask = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)[:, None]) < V

    lock_offset = (pid_b // tl.cdiv(B, BLOCK_B * n_de_locks_0)) * n_de_locks_1
    dELocks += lock_offset

    _mm_backward(
        d_accum,
        dE + (offs_b[:, None] * stride_eb),
        b_mask,
        dELocks,
        n_de_locks_1,
        C + offs_v[:, None] * stride_cv,
        v_mask,
        stride_ed,
        stride_cd,
        D,
        MM_BACK_BLOCK_D,
        MM_BACK_EVEN_D,
    )

    lock_offset = (pid_v // tl.cdiv(V, BLOCK_V * n_dc_locks_0)) * n_dc_locks_1
    dCLocks += lock_offset

    _mm_backward(
        tl.trans(d_accum),
        dC + (offs_v[:, None] * stride_cv),
        v_mask,
        dCLocks,
        n_dc_locks_1,
        E + (offs_b[:, None] * stride_eb),
        b_mask,
        stride_cd,
        stride_ed,
        D,
        MM_BACK_BLOCK_D,
        MM_BACK_EVEN_D,
    )
_cce_backward_kernel = cce_backward_autotune()(_cce_backward_kernel)
_cce_backward_kernel = triton.heuristics(...)(_cce_backward_kernel)
_cce_backward_kernel = triton.jit(_cce_backward_kernel)

def launch_cce_backward_kernel(
    E, C, LSE, dOut, grad_scale, Valids, VocabOrdering, softcap, Targets, dE, dELocks, 
    dC, dCLocks, B, D, V, n_de_locks_0, n_de_locks_1, n_dc_locks_0, n_dc_locks_1, 
    stride_eb, stride_ed, stride_cv, stride_cd, filter_eps, B_BIN, 
    BLOCK_B=128, BLOCK_V=128, BLOCK_D=32, SHIFT=False
):
    # Compute grid size
    num_b_chunks = (B + BLOCK_B - 1) // BLOCK_B
    num_v_chunks = (V + BLOCK_V - 1) // BLOCK_V
    GROUP_B = 8
    num_programs = num_b_chunks * num_v_chunks * GROUP_B

    # Launch the kernel
    _cce_backward_kernel[
        num_programs  # Grid size
    ](
        E, C, LSE, dOut, grad_scale, Valids, VocabOrdering, softcap, Targets, dE, dELocks,
        dC, dCLocks, B, D, V, n_de_locks_0, n_de_locks_1, n_dc_locks_0, n_dc_locks_1, 
        stride_eb, stride_ed, stride_cv, stride_cd, filter_eps, B_BIN,
        BLOCK_B=BLOCK_B, BLOCK_V=BLOCK_V, BLOCK_D=BLOCK_D, 
        MM_BACK_BLOCK_D=BLOCK_D * 2, GROUP_B=GROUP_B, 
        EVEN_D=(D % BLOCK_D == 0), MM_BACK_EVEN_D=(D % (BLOCK_D * 2) == 0),
        ITEM_DO=(dOut.numel() == 1), HAS_VALIDS=(Valids is not None), 
        HAS_VOCAB_ORDERING=(VocabOrdering is not None), 
        FILTER_GRAD=(filter_eps is not None), HAS_TARGETS=(Targets is not None), 
        HAS_SOFTCAP=(softcap is not None), SHIFT=SHIFT
    )

# Example inputs
B, D, V = 1024, 64, 512
E = torch.rand(B, D, device='cuda', dtype=torch.float16)
C = torch.rand(V, D, device='cuda', dtype=torch.float16)
LSE = torch.rand(B, device='cuda', dtype=torch.float16)
dOut = torch.rand(B, device='cuda', dtype=torch.float16)
grad_scale = torch.tensor(1.0, device='cuda', dtype=torch.float16)

Valids = None
VocabOrdering = None
softcap = None
Targets = None
dE = torch.zeros_like(E, device='cuda', dtype=torch.float16)
dELocks = torch.zeros((8, 8), device='cuda', dtype=torch.int32)
dC = torch.zeros_like(C, device='cuda', dtype=torch.float16)
dCLocks = torch.zeros((8, 8), device='cuda', dtype=torch.int32)

n_de_locks_0, n_de_locks_1 = 8, 8
n_dc_locks_0, n_dc_locks_1 = 8, 8
stride_eb, stride_ed = D, 1
stride_cv, stride_cd = D, 1
filter_eps = None
B_BIN = None

launch_cce_backward_kernel(
    E, C, LSE, dOut, grad_scale, Valids, VocabOrdering, softcap, Targets, 
    dE, dELocks, dC, dCLocks, B, D, V, n_de_locks_0, n_de_locks_1, 
    n_dc_locks_0, n_dc_locks_1, stride_eb, stride_ed, stride_cv, stride_cd, 
    filter_eps, B_BIN
)
peterbell10 commented 9 hours ago

_cce_backward_kernel = cce_backward_autotune()(_cce_backward_kernel)
_cce_backward_kernel = triton.heuristics(...)(_cce_backward_kernel)
_cce_backward_kernel = triton.jit(_cce_backward_kernel)

Why are you applying each decorator twice: once via the decorator syntax in the order jit, heuristics, cce_backward_autotune, and then again in this block in the reverse order?

When I run this on main, triton.jit complains that it's not being passed a callable, which is probably why the .src attribute is wrong.

Also, if I remove these lines, the code still fails to run with the error

TypeError: dynamic_func() missing 1 required positional argument: 'B_BIN'

which suggests the arguments aren't being passed correctly, so there's likely something wrong with your code.
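
For reference, here is a minimal sketch (a toy kernel, not the CCE kernel above) of applying triton.heuristics and triton.jit exactly once through the decorator syntax, with no manual re-wrapping afterwards, so triton.jit receives the plain Python function:

import torch
import triton
import triton.language as tl

@triton.heuristics({"EVEN": lambda args: (args["N"] % args["BLOCK"]) == 0})
@triton.jit
def _add_one_kernel(X, Y, N, BLOCK: tl.constexpr, EVEN: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    if EVEN:
        x = tl.load(X + offs)
        tl.store(Y + offs, x + 1)
    else:
        mask = offs < N
        x = tl.load(X + offs, mask=mask)
        tl.store(Y + offs, x + 1, mask=mask)

# The decorators above are the only application; there is no later
# `_add_one_kernel = triton.jit(_add_one_kernel)`-style re-wrapping.
N = 1000
x = torch.rand(N, device="cuda")
y = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(N, meta["BLOCK"]),)
_add_one_kernel[grid](x, y, N, BLOCK=128)

With the duplicated wrapping removed, triton.jit sees the function itself rather than an already-decorated object, which should avoid the .src problem pointed out above.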