I have a PR to enable non-power-of-2 head_dim for FlexAttention: https://github.com/pytorch/pytorch/pull/133495. The implementation pads head_dim up to the next power of 2 and masks the loads and stores.
I am encountering a segfault, but only when head_dim is even; when it is odd, the same strategy works correctly. This is on an H100 machine.
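For context, a minimal sketch of the padding pattern (a hypothetical standalone kernel, not the PR's actual code): tl.arange requires a power-of-2 length, so the head-dim range is rounded up and the padded lanes are masked off on every load and store.
import triton
import triton.language as tl

@triton.jit
def copy_head_dim_padded(x_ptr, y_ptr, stride_row, HEAD_DIM: tl.constexpr, HEAD_DIM_ROUNDED: tl.constexpr):
    # HEAD_DIM_ROUNDED is the next power of 2 >= HEAD_DIM (e.g. 126 -> 128).
    # Lanes with offs_d >= HEAD_DIM are masked on both the load and the store.
    row = tl.program_id(0)
    offs_d = tl.arange(0, HEAD_DIM_ROUNDED)
    d_mask = offs_d < HEAD_DIM
    vals = tl.load(x_ptr + row * stride_row + offs_d, mask=d_mask, other=0.0)
    tl.store(y_ptr + row * stride_row + offs_d, vals, mask=d_mask)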
lldb stacktrace:
* thread #1, name = 'pt_main_thread', stop reason = signal SIGSEGV: address not mapped to object (fault address: 0x10)
frame #0: 0x00007ffed367fbfe libtriton.so`scheduleRemainingToLastStage(forOp=ForOp @ 0x00007fffffffbcf8, schedule=0x00007fffffffc080, afterPrologue=<unavailable>, numStages=2) at MatmulLoopPipeline.cpp:893:9
(lldb) bt
* thread #1, name = 'pt_main_thread', stop reason = signal SIGSEGV: address not mapped to object (fault address: 0x10)
* frame #0: 0x00007ffed367fbfe libtriton.so`scheduleRemainingToLastStage(forOp=ForOp @ 0x00007fffffffbcf8, schedule=0x00007fffffffc080, afterPrologue=<unavailable>, numStages=2) at MatmulLoopPipeline.cpp:893:9
frame #1: 0x00007ffed368d970 libtriton.so`mlir::triton::preProcessLoopAndGetSchedule(forOp=0x00007fffffffc460, numStages=2, options=0x00007fffffffc520) at MatmulLoopPipeline.cpp:1230:31
frame #2: 0x00007ffed36a6a43 libtriton.so`mlir::triton::gpu::PipelinePass::runOnOperation() [inlined] pipelineLoop(numStages=2, forOp=ForOp @ 0x00007fffffffc460) at SoftwarePipeliner.cpp:79:47
frame #3: 0x00007ffed36a6998 libtriton.so`mlir::triton::gpu::PipelinePass::runOnOperation(this=0x000000000a165e60) at SoftwarePipeliner.cpp:125:36
frame #4: 0x00007ffed3c5147c libtriton.so`mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 700
frame #5: 0x00007ffed3c51df2 libtriton.so`mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 354
frame #6: 0x00007ffed3c5481c libtriton.so`mlir::PassManager::run(mlir::Operation*) + 876
frame #7: 0x00007ffed3942bad libtriton.so`<lambda(mlir::PassManager&, mlir::ModuleOp&)>::operator(self=<unavailable>, mod=0x000000000ad8d920, __closure=<unavailable>)(mlir::PassManager &, mlir::ModuleOp &) at ir.cc:1625:19
frame #8: 0x00007ffed3960108 libtriton.so`_FUN [inlined] operator(this=0x0000000000000000, call=0x00007fffffffcd80) at cast.h:1480:37
frame #9: 0x00007ffed39600f0 libtriton.so`_FUN((null)=0x00007fffffffcd80) at pybind11.h:224:21
frame #10: 0x00007ffed9ee5590 libtriton.so`typeinfo for pybind11::handle + 24
frame #11: 0x00007ffed9ee5590 libtriton.so`typeinfo for pybind11::handle + 24
Repro, on a PyTorch nightly build:
# AOT ID: ['0_backward']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# kernel path: /tmp/torchinductor_drisspg/76/c76vxhoeyr34oca27x7ybxq6orqpjucbsyxyxgtfkdyuibultz5r.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
# Source node to ATen node mapping:
# Graph fragment:
# %full_default_4 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([4, 8, 128], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem_2, %getitem_3, %tangents_1, %full_default_4, %fw_graph, %joint_graph, (%full, %full_default, None, None, %convert_element_type, %convert_element_type_1, None, None, 128, 128, %mask_graph), 0.08944271909999159, {ROWS_GUARANTEED_SAFE: False, PRESCALE_QK: False, OUTPUT_LOGSUMEXP: True, IS_DIVISIBLE: True}, (), ()), kwargs = {})
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_per_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 4096
rnumel = 126
RBLOCK: tl.constexpr = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (126*x0)), rmask, other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + (126*x0)), rmask, other=0.0).to(tl.float32)
tmp2 = tmp0 * tmp1
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
tmp5 = tl.where(rmask, tmp3, 0)
tmp6 = tl.sum(tmp5, 1)[:, None]
tmp7 = tmp6.to(tl.float32)
tmp8 = 0.0
tmp9 = tmp7 - tmp8
tl.store(out_ptr1 + (x0), tmp9, None)
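# Note: the reduction above computes DELTA = sum(out * dO) over the head dim (rnumel = 126);
# its output (buf1) is passed as the DELTA input of the backward template kernel below.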
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_drisspg/ho/choddawatwsj2tif3t6ixv5pzlpm4fxw3mn35w6ytgd24xjg6mrj.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
# Source node to ATen node mapping:
# Graph fragment:
# %full_default_4 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([4, 8, 128], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem_2, %getitem_3, %tangents_1, %full_default_4, %fw_graph, %joint_graph, (%full, %full_default, None, None, %convert_element_type, %convert_element_type_1, None, None, 128, 128, %mask_graph), 0.08944271909999159, {ROWS_GUARANTEED_SAFE: False, PRESCALE_QK: False, OUTPUT_LOGSUMEXP: True, IS_DIVISIBLE: True}, (), ()), kwargs = {})
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0):
ROWS_GUARANTEED_SAFE : tl.constexpr = False
PRESCALE_QK : tl.constexpr = False
OUTPUT_LOGSUMEXP : tl.constexpr = True
IS_DIVISIBLE : tl.constexpr = False
SM_SCALE : tl.constexpr = 0.08944271909999159
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = False
BLOCK_DMODEL : tl.constexpr = 126
BLOCK_DMODEL_ROUNDED : tl.constexpr = 128
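    # Note: head_dim = 126 (even, non power of 2) is padded to BLOCK_DMODEL_ROUNDED = 128;
    # with IS_DIVISIBLE = False, the non-divisible branches below mask the tail via offs_k < BLOCK_DMODEL.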
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 64
BLOCK_M2 : tl.constexpr = 64
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
Q = arg_Q
K = arg_K
V = arg_V
LSE = arg_LSE
DELTA = arg_DELTA
DO = arg_DO
DQ = arg_DQ
DV = arg_DV
KV_NUM_BLKS = arg_KV_NUM_BLKS
KV_IDX = arg_KV_IDX
Q_NUM_BLKS = arg_Q_NUM_BLKS
Q_IDX = arg_Q_IDX
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
FULL_KV_IDX = arg_FULL_KV_IDX
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
FULL_Q_IDX = arg_FULL_Q_IDX
# Sub notation for this kernel:
#
# Q: Query, K: Key, V: Value
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
# DELTA: Precomputed sum(OUT*DO, axis=-1)
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
# DK: Derivative of Key, which is written out via the store_output call due to some limitations with
# inductor codegen
# M: Number of queries, N: Number of keys/values, D: Model dimension
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
# (Modifiable) Performance tuning options
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
#
# The following FULL_* and PARTIAL_* are defined in the block-sparse mask grid, rather than the thread block grid.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
# The below are kernel options that can be applied for certain score_mods,
# or involve a numerics vs. perf tradeoff
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
# about 20% more numerical error, but slightly faster.
# Define strides of inputs
stride_qz, stride_qh, stride_qm, stride_qd = 128000, 16128, 126, 1
stride_kz, stride_kh, stride_kn, stride_kd = 128000, 16128, 126, 1
stride_vz, stride_vh, stride_vn, stride_vd = 128000, 16128, 126, 1
stride_doz, stride_doh, stride_dom, stride_dod = 128000, 16128, 126, 1
stride_dqz, stride_dqh, stride_dqm, stride_dqd = 128000, 16128, 126, 1
stride_dvz, stride_dvh, stride_dvm, stride_dvd = 128000, 16128, 126, 1
Z = 4
HQ = 8
HKV = 8
Q_LEN = 128
KV_LEN = 128
MATMUL_PRECISION = Q.dtype.element_ty
pid = tl.program_id(0)
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
off_hz = tl.program_id(2)
off_z = off_hz // HKV # batch idx
off_hkv = off_hz % HKV # kv head idx
SPARSE_Z = 1
SPARSE_HQ = 1
sparse_idx_z = off_z % SPARSE_Z
k_adj = (stride_kh * off_hkv + stride_kz * off_z).to(tl.int64)
v_adj = (stride_vh * off_hkv + stride_vz * off_z).to(tl.int64)
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_z).to(tl.int64)
# offset K, V, DV pointers for batch/kv-head
K += k_adj
V += v_adj
DV += dv_adj
if IS_DIVISIBLE:
tl.static_assert(BLOCK_DMODEL == BLOCK_DMODEL_ROUNDED)
RCP_LN2 = 1.44269504
offs_k = tl.arange(0, BLOCK_DMODEL_ROUNDED)
if pid >= NUM_KV_BLOCKS:
off_pid = pid - NUM_KV_BLOCKS
# THIS BLOCK DOES DQ
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
start_m2_block = off_pid % NUM_Q_BLOCKS
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
stride_kv_num_blks_h = 1
stride_kv_idx_h = 1
stride_kv_idx_m = 1
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by the query head.
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_z).to(tl.int64)
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_z).to(tl.int64)
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_z).to(tl.int64)
off_chz2 = ((off_z * HQ + off_hq2) * Q_LEN).to(tl.int64)
Q2 = Q + q_adj2
DO2 = DO + do_adj2
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
DQ2 = DQ + dq_adj2
LSE2 = LSE + off_chz2
DELTA2 = DELTA + off_chz2
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL_ROUNDED], dtype=tl.float32)
start_m2 = start_m2_block * BLOCK_M2
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
# load Q and do: they stay in SRAM throughout the inner loop.
if IS_DIVISIBLE:
q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd)
do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_k[None, :] * stride_dod)
else:
q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < BLOCK_DMODEL))
do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_k[None, :] * stride_dod, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < BLOCK_DMODEL))
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
if IS_DIVISIBLE:
Di = tl.load(DELTA2 + offs_m2)
lse = tl.load(LSE2 + offs_m2)
else:
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
lse = lse[:, None]
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# KV_IDX and KV_NUM_BLKS are always contiguous.
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
K, V,
dq, q, do, Di, lse,
off_z, off_hq2, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
IS_FULL_BLOCKS=False
)
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
K, V,
dq, q, do, Di, lse,
off_z, off_hq2, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
IS_FULL_BLOCKS=True
)
# Write back dQ.
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
dq *= SM_SCALE
if IS_DIVISIBLE:
tl.store(dq_ptrs, dq)
else:
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < BLOCK_DMODEL))
else:
# THIS BLOCK DOES DK & DV
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
pid_mask = pid // SPARSE_KV_MULTIPLE
stride_q_num_blks_h = 1
stride_q_idx_h = 1
stride_q_idx_n = 1
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL_ROUNDED], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL_ROUNDED], dtype=tl.float32)
start_n1 = pid * BLOCK_N1
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
# load K and V: they stay in SRAM throughout the inner loop.
if IS_DIVISIBLE:
k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd)
v = tl.load(V + offs_n1[:, None] * stride_vn + offs_k[None, :] * stride_vd)
else:
k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd, mask=(offs_n1[:, None] < KV_LEN) & (offs_k[None, :] < BLOCK_DMODEL))
v = tl.load(V + offs_n1[:, None] * stride_vn + offs_k[None, :] * stride_vd, mask=(offs_n1[:, None] < KV_LEN) & (offs_k[None, :] < BLOCK_DMODEL))
if PRESCALE_QK:
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
for off_g in range(0, GQA_SHARED_HEADS):
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by the query head.
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_z).to(tl.int64)
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_z).to(tl.int64)
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_z).to(tl.int64)
off_chz1 = ((off_z * HQ + off_hq1) * Q_LEN).to(tl.int64)
Q1 = Q + q_adj1
DO1 = DO + do_adj1
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
LSE1 = LSE + off_chz1
DELTA1 = DELTA + off_chz1
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Q_IDX and Q_NUM_BLKS are always contiguous.
q_indices = Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
dk, dv = bwd_dkdv_inner(
Q1, DO1, DELTA1, LSE1,
dk, dv, k, v,
off_z, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
IS_FULL_BLOCKS=False
)
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
q_indices = FULL_Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
dk, dv = bwd_dkdv_inner(
Q1, DO1, DELTA1, LSE1,
dk, dv, k, v,
off_z, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
IS_FULL_BLOCKS=True
)
# Write back dV and dK.
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_k[None, :] * stride_dvd
index_n = offs_n1[:, None]
index_k = offs_k[None, :]
if IS_DIVISIBLE:
tl.store(dv_ptrs, dv)
else:
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_k < BLOCK_DMODEL))
dk *= SM_SCALE
mask = (index_n < KV_LEN) & (index_k < BLOCK_DMODEL)
xindex = index_k + (126*index_n) + (16128*off_hkv) + (128000*off_z)
tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
@triton.jit
def bwd_dq_inner(
K, V, # pointers
dq, q, do, Di, lse,
off_z, off_hq, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS
):
ROWS_GUARANTEED_SAFE : tl.constexpr = False
PRESCALE_QK : tl.constexpr = False
OUTPUT_LOGSUMEXP : tl.constexpr = True
IS_DIVISIBLE : tl.constexpr = False
SM_SCALE : tl.constexpr = 0.08944271909999159
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = False
BLOCK_DMODEL : tl.constexpr = 126
BLOCK_DMODEL_ROUNDED : tl.constexpr = 128
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 64
BLOCK_M2 : tl.constexpr = 64
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = 128
KV_LEN = 128
offs_k = tl.arange(0, BLOCK_DMODEL_ROUNDED)
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_k[:, None] * stride_vd
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
hi = sparse_kv_num_blocks * SPARSE_KV_MULTIPLE
if not IS_DIVISIBLE:
if hi >= 1:
for start_n in range(0, hi - 1):
dq = bwd_dq_compute_block_mn(
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS
)
# Increment pointers.
offset = get_offset_for_next_block(
start_n, kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2
)
kT_ptrs += offset * stride_kn
vT_ptrs += offset * stride_vn
offs_n2 += offset
dq = bwd_dq_compute_block_mn(
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS, True,
)
else:
for start_n in range(0, hi):
dq = bwd_dq_compute_block_mn(
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS
)
# Increment pointers.
offset = get_offset_for_next_block(
start_n, kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2
)
kT_ptrs += offset * stride_kn
vT_ptrs += offset * stride_vn
offs_n2 += offset
return dq
@triton.jit
def bwd_dq_compute_block_mn(
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS, is_last_block=False,
):
ROWS_GUARANTEED_SAFE : tl.constexpr = False
PRESCALE_QK : tl.constexpr = False
OUTPUT_LOGSUMEXP : tl.constexpr = True
IS_DIVISIBLE : tl.constexpr = False
SM_SCALE : tl.constexpr = 0.08944271909999159
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = False
BLOCK_DMODEL : tl.constexpr = 126
BLOCK_DMODEL_ROUNDED : tl.constexpr = 128
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 64
BLOCK_M2 : tl.constexpr = 64
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
if IS_DIVISIBLE:
kT = tl.load(kT_ptrs)
else:
kT = tl.load(kT_ptrs, mask=(offs_n2[None, :] < KV_LEN) & (offs_k[:, None] < BLOCK_DMODEL))
qk = tl.dot(q, kT)
if not PRESCALE_QK:
qk *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
pre_mod_scores = qk
if is_last_block:
m = offs_m2[:, None] % Q_LEN
n = offs_n2[None, :] % KV_LEN
else:
m = offs_m2[:, None]
n = offs_n2[None, :]
tmp0 = (qk).to(tl.float32)
tmp1 = (m) - (n)
tmp2 = tmp1.to(tl.float32)
tmp3 = tl.full([1], 1, tl.int32)
tmp4 = (off_hq) + tmp3
tmp5 = tmp4.to(tl.float32)
tmp6 = 8.0
tmp7 = tmp5 * tmp6
tmp8 = 0.125
tmp9 = tmp7 * tmp8
tmp10 = -tmp9
tmp11 = libdevice.exp2(tmp10)
tmp12 = tmp2 * tmp11
tmp13 = tmp0 + tmp12
post_mod_scores = tmp13
if is_last_block:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
tmp14 = tl.full([1], True, tl.int1)
mask_mod_output = tmp14
if is_last_block:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
# apply mask for partial masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
p = tl.math.exp2(post_mod_scores - lse)
# Compute dP and dS.
if IS_DIVISIBLE:
vT = tl.load(vT_ptrs)
else:
vT = tl.load(vT_ptrs, mask=(offs_n2[None, :] < KV_LEN) & (offs_k[:, None] < BLOCK_DMODEL))
dp = tl.dot(do, vT)
ds = p * (dp - Di[:, None])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
grad_scores = (ds)
if is_last_block:
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
ds = grad_scores
if not IS_FULL_BLOCKS:
if is_last_block:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
# (grads) apply mask for partially unmasked block
ds = tl.where(mask_mod_output, ds, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ds = ds.to(MATMUL_PRECISION)
# Compute dQ.
dq += tl.dot(ds, tl.trans(kT))
return dq
@triton.jit
def bwd_dkdv_inner(
Q, DO, DELTA, LSE, # pointers
dk, dv, k, v,
off_z, off_hq, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS
):
ROWS_GUARANTEED_SAFE : tl.constexpr = False
PRESCALE_QK : tl.constexpr = False
OUTPUT_LOGSUMEXP : tl.constexpr = True
IS_DIVISIBLE : tl.constexpr = False
SM_SCALE : tl.constexpr = 0.08944271909999159
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = False
BLOCK_DMODEL : tl.constexpr = 126
BLOCK_DMODEL_ROUNDED : tl.constexpr = 128
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 64
BLOCK_M2 : tl.constexpr = 64
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = 128
KV_LEN = 128
offs_k = tl.arange(0, BLOCK_DMODEL_ROUNDED)
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_k[None, :] * stride_dod
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
hi = sparse_q_num_blocks * SPARSE_Q_MULTIPLE
if not IS_DIVISIBLE:
if hi >= 1:
for start_m in range(0, hi - 1):
dk, dv = bwd_dkdv_compute_block_mn(
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS
)
# Increment pointers.
offset = get_offset_for_next_block(
start_m, q_indices, sparse_q_num_blocks,
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1
)
qT_ptrs += offset * stride_qm
do_ptrs += offset * stride_dom
offs_m1 += offset
dk, dv = bwd_dkdv_compute_block_mn(
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS, True,
)
else:
for start_m in range(0, hi):
dk, dv = bwd_dkdv_compute_block_mn(
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS
)
# Increment pointers.
offset = get_offset_for_next_block(
start_m, q_indices, sparse_q_num_blocks,
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1
)
qT_ptrs += offset * stride_qm
do_ptrs += offset * stride_dom
offs_m1 += offset
return dk, dv
@triton.jit
def bwd_dkdv_compute_block_mn(
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, IS_FULL_BLOCKS, is_last_block=False,
):
ROWS_GUARANTEED_SAFE : tl.constexpr = False
PRESCALE_QK : tl.constexpr = False
OUTPUT_LOGSUMEXP : tl.constexpr = True
IS_DIVISIBLE : tl.constexpr = False
SM_SCALE : tl.constexpr = 0.08944271909999159
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = False
BLOCK_DMODEL : tl.constexpr = 126
BLOCK_DMODEL_ROUNDED : tl.constexpr = 128
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 64
BLOCK_M2 : tl.constexpr = 64
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
# Load LSE before computing qk to reduce pipeline stall.
if IS_DIVISIBLE:
qT = tl.load(qT_ptrs)
lse = tl.load(LSE + offs_m1)
else:
qT = tl.load(qT_ptrs, mask=(offs_m1[None, :] < Q_LEN) & (offs_k[:, None] < BLOCK_DMODEL))
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
qkT = tl.dot(k, qT)
if not PRESCALE_QK:
qkT *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
if is_last_block:
m = offs_m1[None, :] % Q_LEN
n = offs_n1[:, None] % KV_LEN
else:
m = offs_m1[None, :]
n = offs_n1[:, None]
pre_mod_scores = qkT
tmp15 = (qkT).to(tl.float32)
tmp16 = (m) - (n)
tmp17 = tmp16.to(tl.float32)
tmp18 = tl.full([1], 1, tl.int32)
tmp19 = (off_hq) + tmp18
tmp20 = tmp19.to(tl.float32)
tmp21 = 8.0
tmp22 = tmp20 * tmp21
tmp23 = 0.125
tmp24 = tmp22 * tmp23
tmp25 = -tmp24
tmp26 = libdevice.exp2(tmp25)
tmp27 = tmp17 * tmp26
tmp28 = tmp15 + tmp27
post_mod_scores = tmp28
if is_last_block:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
tmp29 = tl.full([1], True, tl.int1)
mask_mod_output = tmp29
if is_last_block:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
# (grads) apply mask for fully masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
pT = tl.math.exp2(post_mod_scores - lse[None, :])
if IS_DIVISIBLE:
do = tl.load(do_ptrs)
else:
do = tl.load(do_ptrs, mask=(offs_m1[:, None] < Q_LEN) & (offs_k[None, :] < BLOCK_DMODEL))
# Compute dV.
ppT = pT
dv += tl.dot(ppT.to(MATMUL_PRECISION), do)
if IS_DIVISIBLE:
Di = tl.load(DELTA + offs_m1)
else:
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do))
dsT = pT * (dpT - Di[None, :])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
grad_scores = (dsT)
if is_last_block:
grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0)
dsT = grad_scores
if not IS_FULL_BLOCKS:
if is_last_block:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
# (grads) apply mask for partially unmasked block
dsT = tl.where(mask_mod_output, dsT, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT))
return dk, dv
@triton.jit
def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
return offset
import torch._inductor.kernel.flex_attention
meta0 = {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True, 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08944271909999159, 'GQA_SHARED_HEADS': 1, 'HAS_FULL_BLOCKS': False, 'BLOCK_DMODEL': 126, 'BLOCK_DMODEL_ROUNDED': 128, 'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}
def call(args):
primals_1, primals_2, primals_3, full, full_default, convert_element_type, convert_element_type_1, getitem_2, getitem_3, tangents_1 = args
args.clear()
assert_size_stride(primals_1, (4, 8, 128, 126), (128000, 16128, 126, 1))
assert_size_stride(primals_2, (4, 8, 128, 126), (128000, 16128, 126, 1))
assert_size_stride(primals_3, (4, 8, 128, 126), (128000, 16128, 126, 1))
assert_size_stride(full, (1, 1, 1), (1, 1, 1))
assert_size_stride(full_default, (1, 1, 1, 1), (1, 1, 1, 1))
assert_size_stride(convert_element_type, (1, 1, 1), (1, 1, 1))
assert_size_stride(convert_element_type_1, (1, 1, 1, 1), (1, 1, 1, 1))
assert_size_stride(getitem_2, (4, 8, 128, 126), (128000, 16128, 126, 1))
assert_size_stride(getitem_3, (4, 8, 128), (1024, 128, 1))
assert_size_stride(tangents_1, (4, 8, 128, 126), (128000, 16128, 126, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf1 = empty_strided_cuda((4, 8, 128), (1024, 128, 1), torch.float32)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_0[grid(4096)](getitem_2, tangents_1, buf1, 4096, 126, XBLOCK=8, num_warps=8, num_stages=1)
del getitem_2
buf2 = empty_strided_cuda((4, 8, 128, 126), (128000, 16128, 126, 1), torch.float16)
buf3 = empty_strided_cuda((4, 8, 128, 126), (128000, 16128, 126, 1), torch.float16)
buf4 = empty_strided_cuda((0, ), (1, ), torch.float32)
buf5 = empty_strided_cuda((0, ), (1, ), torch.float32)
buf6 = empty_strided_cuda((0, ), (1, ), torch.float32)
buf7 = empty_strided_cuda((0, ), (1, ), torch.float32)
buf8 = empty_strided_cuda((4, 8, 128, 126), (128000, 16128, 126, 1), torch.float16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
triton_tem_fused_zeros_1[torch._inductor.kernel.flex_attention.flex_attention_backward_grid(4, 8, 128, 126, 8, 128, meta0)](primals_1, primals_2, primals_3, getitem_3, buf1, tangents_1, buf2, buf3, full, full_default, convert_element_type, convert_element_type_1, buf4, buf5, buf6, buf7, buf8, num_warps=4, num_stages=2)
del buf1
del buf4
del buf5
del buf6
del buf7
del convert_element_type
del convert_element_type_1
del full
del full_default
del getitem_3
del primals_1
del primals_2
del primals_3
del tangents_1
return (buf2, buf8, buf3, )
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
primals_1 = rand_strided((4, 8, 128, 126), (128000, 16128, 126, 1), device='cuda:0', dtype=torch.float16)
primals_2 = rand_strided((4, 8, 128, 126), (128000, 16128, 126, 1), device='cuda:0', dtype=torch.float16)
primals_3 = rand_strided((4, 8, 128, 126), (128000, 16128, 126, 1), device='cuda:0', dtype=torch.float16)
full = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32)
full_default = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32)
convert_element_type = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32)
convert_element_type_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32)
getitem_2 = rand_strided((4, 8, 128, 126), (128000, 16128, 126, 1), device='cuda:0', dtype=torch.float16)
getitem_3 = rand_strided((4, 8, 128), (1024, 128, 1), device='cuda:0', dtype=torch.float32)
tangents_1 = rand_strided((4, 8, 128, 126), (128000, 16128, 126, 1), device='cuda:0', dtype=torch.float16)
fn = lambda: call([primals_1, primals_2, primals_3, full, full_default, convert_element_type, convert_element_type_1, getitem_2, getitem_3, tangents_1])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main('None', benchmark_compiled_module)
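For reference, a user-level approximation of the failing configuration, reconstructed from the shapes and the fused score_mod in the generated code above (run with the linked PR applied; the helper name and exact score_mod/scale are assumptions, so the original test may differ):
import torch
from torch.nn.attention.flex_attention import flex_attention

def rel_bias(score, b, h, m, n):
    # Mirrors the fused score modification in the kernel: score + (m - n) * 2**-(h + 1)
    return score + (m - n) * torch.exp2(-(h + 1).to(torch.float32))

B, H, S, D = 4, 8, 128, 126  # even, non-power-of-2 head_dim is the case that segfaults
q, k, v = [torch.randn(B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True)
           for _ in range(3)]
out = torch.compile(flex_attention)(q, k, v, score_mod=rel_bias)
out.sum().backward()  # the crash happens while compiling the backward template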