triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

FlexAttention Segmentation Fault #4521

Open drisspg opened 2 months ago

drisspg commented 2 months ago

Summary

I have a PR that enables non-power-of-2 head_dim for FlexAttention: https://github.com/pytorch/pytorch/pull/133495. The implementation pads head_dim up to the next power of 2 and masks the loads and stores.
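
The pad-and-mask strategy looks roughly like the following minimal sketch (not the PR code; the kernel and parameter names here are illustrative, with BLOCK_DMODEL/BLOCK_DMODEL_ROUNDED mirroring the generated kernel further below):

import triton
import triton.language as tl

@triton.jit
def padded_copy(X, OUT, stride_m, stride_d,
                BLOCK_M: tl.constexpr,
                BLOCK_DMODEL: tl.constexpr,           # real head_dim, e.g. 126
                BLOCK_DMODEL_ROUNDED: tl.constexpr):  # next power of 2, e.g. 128
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_DMODEL_ROUNDED)       # tl.arange needs a power-of-2 size
    d_mask = offs_d[None, :] < BLOCK_DMODEL           # mask off the padded tail lanes
    ptrs = X + offs_m[:, None] * stride_m + offs_d[None, :] * stride_d
    x = tl.load(ptrs, mask=d_mask, other=0.0)         # padded lanes read as zero
    # ... compute on the [BLOCK_M, BLOCK_DMODEL_ROUNDED] tile ...
    out_ptrs = OUT + offs_m[:, None] * stride_m + offs_d[None, :] * stride_d
    tl.store(out_ptrs, x, mask=d_mask)                # padded lanes are never written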

I am encountering a segfault only when head_dim is even; when it is odd, this strategy works correctly. This is on an H100 machine. lldb stack trace:

* thread #1, name = 'pt_main_thread', stop reason = signal SIGSEGV: address not mapped to object (fault address: 0x10)
    frame #0: 0x00007ffed367fbfe libtriton.so`scheduleRemainingToLastStage(forOp=ForOp @ 0x00007fffffffbcf8, schedule=0x00007fffffffc080, afterPrologue=<unavailable>, numStages=2) at MatmulLoopPipeline.cpp:893:9
(lldb) bt
* thread #1, name = 'pt_main_thread', stop reason = signal SIGSEGV: address not mapped to object (fault address: 0x10)
  * frame #0: 0x00007ffed367fbfe libtriton.so`scheduleRemainingToLastStage(forOp=ForOp @ 0x00007fffffffbcf8, schedule=0x00007fffffffc080, afterPrologue=<unavailable>, numStages=2) at MatmulLoopPipeline.cpp:893:9
    frame #1: 0x00007ffed368d970 libtriton.so`mlir::triton::preProcessLoopAndGetSchedule(forOp=0x00007fffffffc460, numStages=2, options=0x00007fffffffc520) at MatmulLoopPipeline.cpp:1230:31
    frame #2: 0x00007ffed36a6a43 libtriton.so`mlir::triton::gpu::PipelinePass::runOnOperation() [inlined] pipelineLoop(numStages=2, forOp=ForOp @ 0x00007fffffffc460) at SoftwarePipeliner.cpp:79:47
    frame #3: 0x00007ffed36a6998 libtriton.so`mlir::triton::gpu::PipelinePass::runOnOperation(this=0x000000000a165e60) at SoftwarePipeliner.cpp:125:36
    frame #4: 0x00007ffed3c5147c libtriton.so`mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 700
    frame #5: 0x00007ffed3c51df2 libtriton.so`mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 354
    frame #6: 0x00007ffed3c5481c libtriton.so`mlir::PassManager::run(mlir::Operation*) + 876
    frame #7: 0x00007ffed3942bad libtriton.so`<lambda(mlir::PassManager&, mlir::ModuleOp&)>::operator(self=<unavailable>, mod=0x000000000ad8d920, __closure=<unavailable>)(mlir::PassManager &, mlir::ModuleOp &) at ir.cc:1625:19
    frame #8: 0x00007ffed3960108 libtriton.so`_FUN [inlined] operator(this=0x0000000000000000, call=0x00007fffffffcd80) at cast.h:1480:37
    frame #9: 0x00007ffed39600f0 libtriton.so`_FUN((null)=0x00007fffffffcd80) at pybind11.h:224:21
    frame #10: 0x00007ffed9ee5590 libtriton.so`typeinfo for pybind11::handle + 24
    frame #11: 0x00007ffed9ee5590 libtriton.so`typeinfo for pybind11::handle + 24

Repro, on PyTorch nightly:

# AOT ID: ['0_backward']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool

# kernel path: /tmp/torchinductor_drisspg/76/c76vxhoeyr34oca27x7ybxq6orqpjucbsyxyxgtfkdyuibultz5r.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
# Source node to ATen node mapping:
# Graph fragment:
#   %full_default_4 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([4, 8, 128], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
#   %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem_2, %getitem_3, %tangents_1, %full_default_4, %fw_graph, %joint_graph, (%full, %full_default, None, None, %convert_element_type, %convert_element_type_1, None, None, 128, 128, %mask_graph), 0.08944271909999159, {ROWS_GUARANTEED_SAFE: False, PRESCALE_QK: False, OUTPUT_LOGSUMEXP: True, IS_DIVISIBLE: True}, (), ()), kwargs = {})
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton.jit
def triton_per_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 4096
    rnumel = 126
    RBLOCK: tl.constexpr = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (126*x0)), rmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r1 + (126*x0)), rmask, other=0.0).to(tl.float32)
    tmp2 = tmp0 * tmp1
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
    tmp5 = tl.where(rmask, tmp3, 0)
    tmp6 = tl.sum(tmp5, 1)[:, None]
    tmp7 = tmp6.to(tl.float32)
    tmp8 = 0.0
    tmp9 = tmp7 - tmp8
    tl.store(out_ptr1 + (x0), tmp9, None)

import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

# kernel path: /tmp/torchinductor_drisspg/ho/choddawatwsj2tif3t6ixv5pzlpm4fxw3mn35w6ytgd24xjg6mrj.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
# Source node to ATen node mapping:
# Graph fragment:
#   %full_default_4 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([4, 8, 128], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
#   %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem_2, %getitem_3, %tangents_1, %full_default_4, %fw_graph, %joint_graph, (%full, %full_default, None, None, %convert_element_type, %convert_element_type_1, None, None, 128, 128, %mask_graph), 0.08944271909999159, {ROWS_GUARANTEED_SAFE: False, PRESCALE_QK: False, OUTPUT_LOGSUMEXP: True, IS_DIVISIBLE: True}, (), ()), kwargs = {})
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton.jit
def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0):
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    PRESCALE_QK : tl.constexpr = False
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08944271909999159
    GQA_SHARED_HEADS : tl.constexpr = 1
    HAS_FULL_BLOCKS : tl.constexpr = False
    BLOCK_DMODEL : tl.constexpr = 126
    BLOCK_DMODEL_ROUNDED : tl.constexpr = 128
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 64
    BLOCK_M2 : tl.constexpr = 64
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    DELTA = arg_DELTA
    DO = arg_DO
    DQ = arg_DQ
    DV = arg_DV
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    Q_NUM_BLKS = arg_Q_NUM_BLKS
    Q_IDX = arg_Q_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX
    FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
    FULL_Q_IDX = arg_FULL_Q_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
    # DELTA: Precomputed sum(OUT*DO, axis=-1)
    # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
    # DK: Derivative of Key, which is written via the store_output call due to some limitations with
    # inductor codegen
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    # (Modifiable) Performance tuning options
    # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
    # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
    # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
    # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
    #
    # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
    # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
    # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qd = 128000, 16128, 126, 1
    stride_kz, stride_kh, stride_kn, stride_kd = 128000, 16128, 126, 1
    stride_vz, stride_vh, stride_vn, stride_vd = 128000, 16128, 126, 1
    stride_doz, stride_doh, stride_dom, stride_dod = 128000, 16128, 126, 1

    stride_dqz, stride_dqh, stride_dqm, stride_dqd = 128000, 16128, 126, 1
    stride_dvz, stride_dvh, stride_dvm, stride_dvd = 128000, 16128, 126, 1

    Z = 4
    HQ = 8
    HKV = 8
    Q_LEN = 128
    KV_LEN = 128

    MATMUL_PRECISION = Q.dtype.element_ty

    pid = tl.program_id(0)
    NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
    NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)

    off_hz = tl.program_id(2)
    off_z = off_hz // HKV # batch idx
    off_hkv = off_hz % HKV # kv head idx

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_z % SPARSE_Z

    k_adj = (stride_kh * off_hkv + stride_kz * off_z).to(tl.int64)
    v_adj = (stride_vh * off_hkv + stride_vz * off_z).to(tl.int64)
    dv_adj = (stride_dvh * off_hkv + stride_dvz * off_z).to(tl.int64)

    # offset K, V, DV pointers for batch/kv-head
    K += k_adj
    V += v_adj
    DV += dv_adj

    if IS_DIVISIBLE:
        tl.static_assert(BLOCK_DMODEL == BLOCK_DMODEL_ROUNDED)

    RCP_LN2 = 1.44269504
    offs_k = tl.arange(0, BLOCK_DMODEL_ROUNDED)

    if pid >= NUM_KV_BLOCKS:
        off_pid = pid - NUM_KV_BLOCKS
        # THIS BLOCK DOES DQ
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
        off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
        start_m2_block = off_pid % NUM_Q_BLOCKS
        off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
        stride_kv_num_blks_h = 1
        stride_kv_idx_h = 1
        stride_kv_idx_m = 1

        sparse_idx_hq2 = off_hq2 % SPARSE_HQ
        sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2

        sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
        sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m  # noqa: B950

        # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query head.
        q_adj2 = (stride_qh * off_hq2 + stride_qz * off_z).to(tl.int64)
        do_adj2 = (stride_doh * off_hq2 + stride_doz * off_z).to(tl.int64)
        dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_z).to(tl.int64)
        off_chz2 = ((off_z * HQ + off_hq2) * Q_LEN).to(tl.int64)

        Q2 = Q + q_adj2
        DO2 = DO + do_adj2
        # TODO: This does not work if DQ is not the same layout as Q (for example,
        # if Q is broadcasted)
        DQ2 = DQ + dq_adj2
        LSE2 = LSE + off_chz2
        DELTA2 = DELTA + off_chz2

        dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL_ROUNDED], dtype=tl.float32)

        start_m2 = start_m2_block * BLOCK_M2
        offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)

        # load Q and do: they stay in SRAM throughout the inner loop.
        if IS_DIVISIBLE:
            q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd)
            do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_k[None, :] * stride_dod)
        else:
            q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < BLOCK_DMODEL))
            do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_k[None, :] * stride_dod, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < BLOCK_DMODEL))

        if PRESCALE_QK:
            q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        if IS_DIVISIBLE:
            Di = tl.load(DELTA2 + offs_m2)
            lse = tl.load(LSE2 + offs_m2)
        else:
            Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
            lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
        lse = lse[:, None]

        # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # KV_IDX and KV_NUM_BLKS are always contiguous.
        kv_indices = KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
        sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

        offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
        dq = bwd_dq_inner(
            K, V,
            dq, q, do, Di, lse,
            off_z, off_hq2, offs_m2, offs_n2,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION,
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
            IS_FULL_BLOCKS=False
        )

        if HAS_FULL_BLOCKS:
            # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
            kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
            kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
            sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

            offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
            dq = bwd_dq_inner(
                K, V,
                dq, q, do, Di, lse,
                off_z, off_hq2, offs_m2, offs_n2,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION,
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
                IS_FULL_BLOCKS=True
            )

        # Write back dQ.
        dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
        dq *= SM_SCALE
        if IS_DIVISIBLE:
            tl.store(dq_ptrs, dq)
        else:
            tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < BLOCK_DMODEL))
    else:
        # THIS BLOCK DOES DK & DV
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)

        pid_mask = pid // SPARSE_KV_MULTIPLE

        stride_q_num_blks_h = 1
        stride_q_idx_h = 1
        stride_q_idx_n = 1

        dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL_ROUNDED], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL_ROUNDED], dtype=tl.float32)

        start_n1 = pid * BLOCK_N1
        offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)

        # load K and V: they stay in SRAM throughout the inner loop.
        if IS_DIVISIBLE:
            k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd)
            v = tl.load(V + offs_n1[:, None] * stride_vn + offs_k[None, :] * stride_vd)
        else:
            k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd, mask=(offs_n1[:, None] < KV_LEN) & (offs_k[None, :] < BLOCK_DMODEL))
            v = tl.load(V + offs_n1[:, None] * stride_vn + offs_k[None, :] * stride_vd, mask=(offs_n1[:, None] < KV_LEN) & (offs_k[None, :] < BLOCK_DMODEL))
        if PRESCALE_QK:
            k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        for off_g in range(0, GQA_SHARED_HEADS):
            off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g

            # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query head.
            q_adj1 = (stride_qh * off_hq1 + stride_qz * off_z).to(tl.int64)
            do_adj1 = (stride_doh * off_hq1 + stride_doz * off_z).to(tl.int64)
            dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_z).to(tl.int64)
            off_chz1 = ((off_z * HQ + off_hq1) * Q_LEN).to(tl.int64)

            Q1 = Q + q_adj1
            DO1 = DO + do_adj1
            # TODO: This does not work if DQ is not the same layout as Q (for example,
            # if Q is broadcasted)
            LSE1 = LSE + off_chz1
            DELTA1 = DELTA + off_chz1

            sparse_idx_hq1 = off_hq1 % SPARSE_HQ
            sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1

            sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
            sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n  # noqa: B950

            # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Q_IDX and Q_NUM_BLKS are always contiguous.
            q_indices = Q_IDX + sparse_q_idx_offset
            q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
            sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)

            offs_m1 = q_start + tl.arange(0, BLOCK_M1)
            dk, dv = bwd_dkdv_inner(
                Q1, DO1, DELTA1, LSE1,
                dk, dv, k, v,
                off_z, off_hq1, offs_n1, offs_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION,
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
                IS_FULL_BLOCKS=False
            )

            if HAS_FULL_BLOCKS:
                # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
                q_indices = FULL_Q_IDX + sparse_q_idx_offset
                q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
                sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)

                offs_m1 = q_start + tl.arange(0, BLOCK_M1)
                dk, dv = bwd_dkdv_inner(
                    Q1, DO1, DELTA1, LSE1,
                    dk, dv, k, v,
                    off_z, off_hq1, offs_n1, offs_m1,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION,
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
                    IS_FULL_BLOCKS=True
                )

        # Write back dV and dK.
        dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_k[None, :] * stride_dvd

        index_n = offs_n1[:, None]
        index_k = offs_k[None, :]

        if IS_DIVISIBLE:
            tl.store(dv_ptrs, dv)
        else:
            tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_k < BLOCK_DMODEL))

        dk *= SM_SCALE
        mask = (index_n < KV_LEN) & (index_k < BLOCK_DMODEL)
        xindex = index_k + (126*index_n) + (16128*off_hkv) + (128000*off_z)
        tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)

@triton.jit
def bwd_dq_inner(
    K, V,  # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS
):
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    PRESCALE_QK : tl.constexpr = False
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08944271909999159
    GQA_SHARED_HEADS : tl.constexpr = 1
    HAS_FULL_BLOCKS : tl.constexpr = False
    BLOCK_DMODEL : tl.constexpr = 126
    BLOCK_DMODEL_ROUNDED : tl.constexpr = 128
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 64
    BLOCK_M2 : tl.constexpr = 64
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 128
    KV_LEN = 128

    offs_k = tl.arange(0, BLOCK_DMODEL_ROUNDED)

    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_k[:, None] * stride_vd
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    hi = sparse_kv_num_blocks * SPARSE_KV_MULTIPLE
    if not IS_DIVISIBLE:
        if hi >= 1:
            for start_n in range(0, hi - 1):
                dq = bwd_dq_compute_block_mn(
                    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
                    off_z, off_hq, offs_m2, offs_n2, offs_k,
                    stride_kn, stride_kd, stride_vn, stride_vd,
                    kv_indices, sparse_kv_num_blocks,
                    MATMUL_PRECISION, RCP_LN2,
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS
                )

                # Increment pointers.
                offset = get_offset_for_next_block(
                    start_n, kv_indices, sparse_kv_num_blocks,
                    SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2
                )

                kT_ptrs += offset * stride_kn
                vT_ptrs += offset * stride_vn

                offs_n2 += offset

            dq = bwd_dq_compute_block_mn(
                dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
                off_z, off_hq, offs_m2, offs_n2, offs_k,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION, RCP_LN2,
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS, True,
            )
    else:
        for start_n in range(0, hi):
            dq = bwd_dq_compute_block_mn(
                dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
                off_z, off_hq, offs_m2, offs_n2, offs_k,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION, RCP_LN2,
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS
            )

            # Increment pointers.
            offset = get_offset_for_next_block(
                start_n, kv_indices, sparse_kv_num_blocks,
                SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2
            )

            kT_ptrs += offset * stride_kn
            vT_ptrs += offset * stride_vn

            offs_n2 += offset

    return dq

@triton.jit
def bwd_dq_compute_block_mn(
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS, is_last_block=False,
):
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    PRESCALE_QK : tl.constexpr = False
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08944271909999159
    GQA_SHARED_HEADS : tl.constexpr = 1
    HAS_FULL_BLOCKS : tl.constexpr = False
    BLOCK_DMODEL : tl.constexpr = 126
    BLOCK_DMODEL_ROUNDED : tl.constexpr = 128
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 64
    BLOCK_M2 : tl.constexpr = 64
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128

    if IS_DIVISIBLE:
        kT = tl.load(kT_ptrs)
    else:
        kT = tl.load(kT_ptrs, mask=(offs_n2[None, :] < KV_LEN) & (offs_k[:, None] < BLOCK_DMODEL))
    qk = tl.dot(q, kT)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification  ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    if is_last_block:
        m = offs_m2[:, None] % Q_LEN
        n = offs_n2[None, :] % KV_LEN
    else:
        m = offs_m2[:, None]
        n = offs_n2[None, :]
    tmp0 = (qk).to(tl.float32)
    tmp1 = (m) - (n)
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tl.full([1], 1, tl.int32)
    tmp4 = (off_hq) + tmp3
    tmp5 = tmp4.to(tl.float32)
    tmp6 = 8.0
    tmp7 = tmp5 * tmp6
    tmp8 = 0.125
    tmp9 = tmp7 * tmp8
    tmp10 = -tmp9
    tmp11 = libdevice.exp2(tmp10)
    tmp12 = tmp2 * tmp11
    tmp13 = tmp0 + tmp12
    post_mod_scores = tmp13

    if is_last_block:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp14 = tl.full([1], True, tl.int1)
        mask_mod_output = tmp14

        if is_last_block:
            mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
        # apply mask for partial masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    if IS_DIVISIBLE:
        vT = tl.load(vT_ptrs)
    else:
        vT = tl.load(vT_ptrs, mask=(offs_n2[None, :] < KV_LEN) & (offs_k[:, None] < BLOCK_DMODEL))
    dp = tl.dot(do, vT)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification  ~~~~~~~~~~~~~~~~~~~
    grad_scores = (ds)

    if is_last_block:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    ds = grad_scores

    if not IS_FULL_BLOCKS:
        if is_last_block:
            mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT))

    return dq

@triton.jit
def bwd_dkdv_inner(
    Q, DO, DELTA, LSE, # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS
):
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    PRESCALE_QK : tl.constexpr = False
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08944271909999159
    GQA_SHARED_HEADS : tl.constexpr = 1
    HAS_FULL_BLOCKS : tl.constexpr = False
    BLOCK_DMODEL : tl.constexpr = 126
    BLOCK_DMODEL_ROUNDED : tl.constexpr = 128
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 64
    BLOCK_M2 : tl.constexpr = 64
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 128
    KV_LEN = 128

    offs_k = tl.arange(0, BLOCK_DMODEL_ROUNDED)

    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_k[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    hi = sparse_q_num_blocks * SPARSE_Q_MULTIPLE

    if not IS_DIVISIBLE:
        if hi >= 1:
            for start_m in range(0, hi - 1):
                dk, dv = bwd_dkdv_compute_block_mn(
                    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
                    off_z, off_hq, offs_n1, offs_m1, offs_k,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION, RCP_LN2,
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS
                )
                # Increment pointers.
                offset = get_offset_for_next_block(
                    start_m, q_indices, sparse_q_num_blocks,
                    SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1
                )

                qT_ptrs += offset * stride_qm
                do_ptrs += offset * stride_dom

                offs_m1 += offset

            dk, dv = bwd_dkdv_compute_block_mn(
                dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
                off_z, off_hq, offs_n1, offs_m1, offs_k,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION, RCP_LN2,
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS, True,
            )
    else:
        for start_m in range(0, hi):
            dk, dv = bwd_dkdv_compute_block_mn(
                dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
                off_z, off_hq, offs_n1, offs_m1, offs_k,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION, RCP_LN2,
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS
            )
            # Increment pointers.
            offset = get_offset_for_next_block(
                start_m, q_indices, sparse_q_num_blocks,
                SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1
            )

            qT_ptrs += offset * stride_qm
            do_ptrs += offset * stride_dom

            offs_m1 += offset

    return dk, dv

@triton.jit
def bwd_dkdv_compute_block_mn(
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS, is_last_block=False,
):
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    PRESCALE_QK : tl.constexpr = False
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08944271909999159
    GQA_SHARED_HEADS : tl.constexpr = 1
    HAS_FULL_BLOCKS : tl.constexpr = False
    BLOCK_DMODEL : tl.constexpr = 126
    BLOCK_DMODEL_ROUNDED : tl.constexpr = 128
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 64
    BLOCK_M2 : tl.constexpr = 64
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128

    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        qT = tl.load(qT_ptrs)
        lse = tl.load(LSE + offs_m1)
    else:
        qT = tl.load(qT_ptrs, mask=(offs_m1[None, :] < Q_LEN) & (offs_k[:, None] < BLOCK_DMODEL))
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    qkT = tl.dot(k, qT)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification  ~~~~~~~~~~~~~~~~~~~
    if is_last_block:
        m = offs_m1[None, :] % Q_LEN
        n = offs_n1[:, None] % KV_LEN
    else:
        m = offs_m1[None, :]
        n = offs_n1[:, None]
    pre_mod_scores = qkT
    tmp15 = (qkT).to(tl.float32)
    tmp16 = (m) - (n)
    tmp17 = tmp16.to(tl.float32)
    tmp18 = tl.full([1], 1, tl.int32)
    tmp19 = (off_hq) + tmp18
    tmp20 = tmp19.to(tl.float32)
    tmp21 = 8.0
    tmp22 = tmp20 * tmp21
    tmp23 = 0.125
    tmp24 = tmp22 * tmp23
    tmp25 = -tmp24
    tmp26 = libdevice.exp2(tmp25)
    tmp27 = tmp17 * tmp26
    tmp28 = tmp15 + tmp27
    post_mod_scores = tmp28

    if is_last_block:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp29 = tl.full([1], True, tl.int1)
        mask_mod_output = tmp29

        if is_last_block:
            mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    if IS_DIVISIBLE:
        do = tl.load(do_ptrs)
    else:
        do = tl.load(do_ptrs, mask=(offs_m1[:, None] < Q_LEN) & (offs_k[None, :] < BLOCK_DMODEL))
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do))
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification  ~~~~~~~~~~~~~~~~~~~
    grad_scores = (dsT)

    if is_last_block:
        grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0)

    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        if is_last_block:
            mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT))

    return dk, dv

@triton.jit
def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK

    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

import torch._inductor.kernel.flex_attention
meta0 = {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True, 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08944271909999159, 'GQA_SHARED_HEADS': 1, 'HAS_FULL_BLOCKS': False, 'BLOCK_DMODEL': 126, 'BLOCK_DMODEL_ROUNDED': 128, 'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}

def call(args):
    primals_1, primals_2, primals_3, full, full_default, convert_element_type, convert_element_type_1, getitem_2, getitem_3, tangents_1 = args
    args.clear()
    assert_size_stride(primals_1, (4, 8, 128, 126), (128000, 16128, 126, 1))
    assert_size_stride(primals_2, (4, 8, 128, 126), (128000, 16128, 126, 1))
    assert_size_stride(primals_3, (4, 8, 128, 126), (128000, 16128, 126, 1))
    assert_size_stride(full, (1, 1, 1), (1, 1, 1))
    assert_size_stride(full_default, (1, 1, 1, 1), (1, 1, 1, 1))
    assert_size_stride(convert_element_type, (1, 1, 1), (1, 1, 1))
    assert_size_stride(convert_element_type_1, (1, 1, 1, 1), (1, 1, 1, 1))
    assert_size_stride(getitem_2, (4, 8, 128, 126), (128000, 16128, 126, 1))
    assert_size_stride(getitem_3, (4, 8, 128), (1024, 128, 1))
    assert_size_stride(tangents_1, (4, 8, 128, 126), (128000, 16128, 126, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf1 = empty_strided_cuda((4, 8, 128), (1024, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
        stream0 = get_raw_stream(0)
        triton_per_fused_zeros_0[grid(4096)](getitem_2, tangents_1, buf1, 4096, 126, XBLOCK=8, num_warps=8, num_stages=1)
        del getitem_2
        buf2 = empty_strided_cuda((4, 8, 128, 126), (128000, 16128, 126, 1), torch.float16)
        buf3 = empty_strided_cuda((4, 8, 128, 126), (128000, 16128, 126, 1), torch.float16)
        buf4 = empty_strided_cuda((0, ), (1, ), torch.float32)
        buf5 = empty_strided_cuda((0, ), (1, ), torch.float32)
        buf6 = empty_strided_cuda((0, ), (1, ), torch.float32)
        buf7 = empty_strided_cuda((0, ), (1, ), torch.float32)
        buf8 = empty_strided_cuda((4, 8, 128, 126), (128000, 16128, 126, 1), torch.float16)
        # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
        triton_tem_fused_zeros_1[torch._inductor.kernel.flex_attention.flex_attention_backward_grid(4, 8, 128, 126, 8, 128, meta0)](primals_1, primals_2, primals_3, getitem_3, buf1, tangents_1, buf2, buf3, full, full_default, convert_element_type, convert_element_type_1, buf4, buf5, buf6, buf7, buf8, num_warps=4, num_stages=2)
        del buf1
        del buf4
        del buf5
        del buf6
        del buf7
        del convert_element_type
        del convert_element_type_1
        del full
        del full_default
        del getitem_3
        del primals_1
        del primals_2
        del primals_3
        del tangents_1
    return (buf2, buf8, buf3, )

def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = rand_strided((4, 8, 128, 126), (128000, 16128, 126, 1), device='cuda:0', dtype=torch.float16)
    primals_2 = rand_strided((4, 8, 128, 126), (128000, 16128, 126, 1), device='cuda:0', dtype=torch.float16)
    primals_3 = rand_strided((4, 8, 128, 126), (128000, 16128, 126, 1), device='cuda:0', dtype=torch.float16)
    full = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32)
    full_default = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32)
    convert_element_type = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32)
    convert_element_type_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32)
    getitem_2 = rand_strided((4, 8, 128, 126), (128000, 16128, 126, 1), device='cuda:0', dtype=torch.float16)
    getitem_3 = rand_strided((4, 8, 128), (1024, 128, 1), device='cuda:0', dtype=torch.float32)
    tangents_1 = rand_strided((4, 8, 128, 126), (128000, 16128, 126, 1), device='cuda:0', dtype=torch.float16)
    fn = lambda: call([primals_1, primals_2, primals_3, full, full_default, convert_element_type, convert_element_type_1, getitem_2, getitem_3, tangents_1])
    return print_performance(fn, times=times, repeat=repeat)

if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
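
For reference, a compact source-level sketch of the configuration that the generated code above was produced from (shapes and dtype are read off the dump; the score_mod is an illustrative reconstruction of the bias in the kernel, and running it assumes the non-power-of-2 head_dim support from pytorch/pytorch#133495):

import torch
from torch.nn.attention.flex_attention import flex_attention

def rel_bias(score, b, h, q_idx, kv_idx):
    # ALiBi-style bias mirroring the exp2(-(h + 1) * 8 * 0.125) slope in the kernel above
    return score + (q_idx - kv_idx) * torch.exp2(-((h + 1) * 8 * 0.125))

q, k, v = [
    torch.randn(4, 8, 128, 126, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
]
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, score_mod=rel_bias)  # head_dim = 126: even, not a power of 2
out.sum().backward()                              # the backward pass is where the segfault occurs
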
manman-ren commented 2 months ago

Seems to run fine with OSS Triton + latest PyTorch:

python test.py
0.000050

manman-ren commented 2 months ago

PR #4247 fixed the issue on top of the current PT2 Triton pin (dedb7bdf339a3546896d4820366ca562c586bfa0).