pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License
358 stars 14 forks source link

Two errors: (1) NameError: ModularIndexing is not defined & (2) LoweringException: AttributeError: 'View' object has no attribute 'get_stride' #45

Open tobiasvanderwerff opened 11 hours ago

tobiasvanderwerff commented 11 hours ago

The following code raises an error when the module is run under torch.compile:

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, N, D = 100, 12, 128, 64
dtype = torch.bfloat16
device = torch.device("cuda")

class Attention(torch.nn.Module):
    """Minimal module that runs flex_attention with an additive bias score_mod.

    The bias tensor is created once in ``__init__``; the score_mod closure
    is rebuilt from it on every ``forward`` call.
    """

    def __init__(self):
        super().__init__()
        # Assigned as a plain tensor attribute (not wrapped in nn.Parameter and
        # not passed to register_buffer). It is created directly on `device`
        # with `dtype`, with shape (B, N, N, H).
        self.bias = torch.randn(B, N, N, H, device=device, dtype=dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Apply flex_attention to q, k, v using the bias-adding score_mod.

        The reproduction script passes q, k and v with shape (B, H, N, D).
        Returns whatever ``flex_attention`` returns for those inputs.
        """
        # Build the closure from the stored bias on each call.
        score_mod = generate_score_mod(self.bias)
        o = flex_attention(q, k, v, score_mod=score_mod)
        return o

def generate_score_mod(bias):
    """Return a score_mod closure that adds a bias entry to each attention score.

    ``bias`` is doubled and viewed as (B, H, N, N) once, outside the closure;
    the returned ``score_mod`` indexes that captured tensor by
    (batch, head, q_idx, k_idx) and adds the selected value to ``score``.
    """
    # NOTE(review): the caller passes a (B, N, N, H) tensor. `.view` only
    # reinterprets the existing layout and does not permute axes, so this is
    # not a transpose into (B, H, N, N) -- confirm whether `.permute` was
    # intended. The result of `2 * bias` followed by `.view` is what the
    # closure below captures.
    bias = (2 * bias).view(B, H, N, N)

    def score_mod(score, batch, head, q_idx, k_idx):
        # Look up the bias for this (batch, head, query, key) position.
        attn_bias = bias[batch, head, q_idx, k_idx]
        return score + attn_bias

    return score_mod

if __name__ == "__main__":
    # Reproduction driver: build the module, compile it, and run one forward
    # pass. The error is raised by the compiled call at the bottom.
    m = Attention().cuda().eval().to(dtype)
    # fullgraph=False is passed explicitly here.
    m = torch.compile(m, mode='default', fullgraph=False)
    # m = torch.compile(m, mode='max-autotune', fullgraph=False)  # this also fails

    # Random q/k/v inputs, all with shape (B, H, N, D), on `device` in `dtype`.
    q = torch.randn(B, H, N, D, device=device, dtype=dtype)
    k = torch.randn(B, H, N, D, device=device, dtype=dtype)
    v = torch.randn(B, H, N, D, device=device, dtype=dtype)

    # The call to the compiled module; the result is discarded.
    m(q, k, v)

The error depends on the torch.compile mode I'm using.

If using torch.compile(..., mode='default', ...), I get the following error:

E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] Triton compilation failed: triton_tem_fused_2
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] def triton_(arg_Q, arg_K, arg_V, arg_LSE, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr8, out_ptr0):
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     ROWS_GUARANTEED_SAFE : tl.constexpr = False
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     PRESCALE_QK : tl.constexpr = False
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     OUTPUT_LOGSUMEXP : tl.constexpr = False
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     FLOAT32_PRECISION : tl.constexpr = 'ieee'
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     IS_DIVISIBLE : tl.constexpr = True
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     SM_SCALE : tl.constexpr = 0.125
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     GQA_SHARED_HEADS : tl.constexpr = 1
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     HAS_FULL_BLOCKS : tl.constexpr = False
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     QK_HEAD_DIM : tl.constexpr = 64
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     V_HEAD_DIM : tl.constexpr = 64
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     BLOCK_M : tl.constexpr = 128
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     BLOCK_N : tl.constexpr = 64
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     SPARSE_Q_BLOCK_SIZE : tl.constexpr = 1073741824
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     SPARSE_KV_BLOCK_SIZE : tl.constexpr = 1073741824
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     Q = arg_Q
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     K = arg_K
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     V = arg_V
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     LSE = arg_LSE
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     KV_NUM_BLKS = arg_KV_NUM_BLKS
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     KV_IDX = arg_KV_IDX
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     FULL_KV_IDX = arg_FULL_KV_IDX
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # Sub notation for this kernel:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     #
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # Q: Query, K: Key, V: Value
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # M: Number of queries, N: Number of keys/values, D: Model dimension
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # QK_HEAD_DIM: The dimension of the query and key embeddings
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # V_HEAD_DIM: The dimension of the value embeddings
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     #
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     #
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     #
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # (Modifiable) Performance tuning options
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # BLOCK_M: The thread block size across the seqlen dim of Q.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # The below are kernel options that can be applied for certain score_mods,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # or involve a numerics vs. perf tradeoff
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # about 20% more numerical error, but slightly faster.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # is not masked out? If so, we can skip an extra safety check
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # Define strides of inputs
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     stride_qz, stride_qh, stride_qm, stride_qk = 98304, 8192, 64, 1
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     stride_kz, stride_kh, stride_kn, stride_kk = 98304, 8192, 64, 1
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     stride_vz, stride_vh, stride_vn, stride_vk = 98304, 8192, 64, 1
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     ZQ = 100
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     HQ = 12
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     Q_LEN = 128
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     ZKV = 100
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     KV_LEN = 128
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     MATMUL_PRECISION = Q.dtype.element_ty
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     q_start = tl.program_id(0)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     off_zq = tl.program_id(1) // HQ
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     off_hq = tl.program_id(1) % HQ
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     off_zkv = off_zq % ZKV
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     off_hkv = off_hq // GQA_SHARED_HEADS
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     off_g = off_hq % GQA_SHARED_HEADS
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     q_offset = off_zq * stride_qz + off_hq * stride_qh
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     k_offset = off_zkv * stride_kz + off_hkv * stride_kh
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     v_offset = off_zkv * stride_vz + off_hkv * stride_vh
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     Q = Q + q_offset
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     K = K + k_offset
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     V = V + v_offset
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     SPARSE_Z = 1
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     SPARSE_HQ = 1
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     sparse_idx_z = off_zq % SPARSE_Z
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     sparse_idx_hq = off_hq % SPARSE_HQ
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     stride_kv_num_blks_h = 1
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     stride_kv_idx_h = 1
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     stride_kv_idx_m = 1
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # initialize pointer to m and l
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # KV_IDX and KV_NUM_BLKS are always contiguous.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m  # noqa: B950
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     Q_block_ptr = tl.make_block_ptr(
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         base=Q,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         shape=(Q_LEN, QK_HEAD_DIM),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         strides=(stride_qm, stride_qk),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         offsets=(q_start * BLOCK_M, 0),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         block_shape=(BLOCK_M, QK_HEAD_DIM),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         order=(1, 0)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     )
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # load q: it stays in SRAM throughout the inner loop.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     if IS_DIVISIBLE:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         q = tl.load(Q_block_ptr)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     else:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         # boundary check is not free, so we only do it when necessary.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option = "zero")
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # We don't know anything "special" about these blocks, so we need to apply
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # both score_mod and mask_mod to it
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     kv_indices = KV_IDX + sparse_kv_idx_offset
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     K_block_ptr = tl.make_block_ptr(
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         base=K,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         shape=(QK_HEAD_DIM, KV_LEN),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         strides=(stride_kk, stride_kn),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         offsets=(0, kv_start),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         block_shape=(QK_HEAD_DIM, BLOCK_N),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         order=(0, 1)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     )
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     V_block_ptr = tl.make_block_ptr(
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         base=V,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         shape=(KV_LEN, V_HEAD_DIM),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         strides=(stride_vn, stride_vk),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         offsets=(kv_start, 0),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         block_shape=(BLOCK_N, V_HEAD_DIM),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         order=(1, 0)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     )
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     offs_n = kv_start + tl.arange(0, BLOCK_N)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     acc, l_i, m_i = forward_inner(
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         arg_Q, arg_K, arg_V, arg_LSE, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr8, out_ptr0,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         acc, l_i, m_i,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         off_zq, off_hq, offs_m[:, None], offs_n[None, :],
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         kv_indices, kv_num_blocks,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         0, block_n_end,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         MATMUL_PRECISION,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         IS_FULL_BLOCKS=False,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     )
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # We know these blocks are guaranteed to be "full", so we don't need to
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # apply mask_mod to them - only score_mod
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     if HAS_FULL_BLOCKS:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         K_block_ptr = tl.make_block_ptr(
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             base=K,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             shape=(QK_HEAD_DIM, KV_LEN),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             strides=(stride_kk, stride_kn),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             offsets=(0, kv_start),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             block_shape=(QK_HEAD_DIM, BLOCK_N),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             order=(0, 1)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         )
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         V_block_ptr = tl.make_block_ptr(
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             base=V,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             shape=(KV_LEN, V_HEAD_DIM),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             strides=(stride_vn, stride_vk),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             offsets=(kv_start, 0),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             block_shape=(BLOCK_N, V_HEAD_DIM),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             order=(1, 0)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         )
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         offs_n = kv_start + tl.arange(0, BLOCK_N)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         acc, l_i, m_i = forward_inner(
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             arg_Q, arg_K, arg_V, arg_LSE, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr8, out_ptr0,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             acc, l_i, m_i,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             off_zq, off_hq, offs_m[:, None], offs_n[None, :],
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             kv_indices, kv_num_blocks,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             0, block_n_end,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             MATMUL_PRECISION,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             IS_FULL_BLOCKS=True,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         )
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # [Note] Handle fully masked out rows:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     l_i = tl.where(l_i == 0.0, 1, l_i)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     acc = acc / l_i[:, None]
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     idx_zq = tl.program_id(1) // HQ
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     idx_hq = tl.program_id(1) % HQ
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     idx_m = offs_m[:, None]
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     idx_d = tl.arange(0, V_HEAD_DIM)[None, :]
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     mask = idx_m < Q_LEN
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # TODO generalize and add proper mask support
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     xindex = idx_d + (64*idx_m) + (8192*idx_hq) + (98304*idx_zq)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     tl.store(out_ptr0 + (tl.broadcast_to(xindex, acc.shape)), acc, mask)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # TODO dont want to write this if we dont require grad
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     if OUTPUT_LOGSUMEXP:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         off_hz = tl.program_id(1)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         l_ptrs = LSE + off_hz * Q_LEN + offs_m
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         lse = m_i + tl.math.log2(l_i)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         if IS_DIVISIBLE:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             tl.store(l_ptrs, lse)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         else:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] metadata: {'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*fp32', 4: '*i32', 5: '*i32', 6: '*fp32', 7: '*fp32', 8: '*bf16', 9: '*bf16'}, 'device': 0, 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())], 'device_type': 'cuda', 'num_warps': 4, 'num_stages': 3, 'debug': True, 'cc': 80}
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] triton.compiler.errors.CompilationError: at 49:29:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     if CHECK_BLOCK_BOUNDARY:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         # which is larger than the actual number of elements. To avoid access memory out of bound,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         # we need to mask out the elements that are out of Q_LEN & KV_LEN.
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         m = offs_m % Q_LEN
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         n = offs_n % KV_LEN
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     else:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         m = offs_m
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         n = offs_n
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     tmp0 = 2.0
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     tmp1 = tl.load(in_ptr8 + ModularIndexing(128*(m) + (n) + 16384*(off_h), 1, 12) + 12*ModularIndexing(128*(m) + (n) + 16384*(off_h), 12, 128) + 1536*ModularIndexing(128*(m) + (n) + 16384*(off_h), 1536, 128) + 196608*ModularIndexing(128*(m) + (n) + 16384*(off_h) + 196608*(off_z), 196608, 100)) * tmp0
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]                              ^
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] NameError('ModularIndexing is not defined')
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] The above exception was the direct cause of the following exception:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] triton.compiler.errors.CompilationError: at 42:28:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     RCP_LN2: tl.constexpr = 1.44269504
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     if PRESCALE_QK:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     # loop over k, v and update accumulator until block_n_end
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     for start_n in range(block_n_start, block_n_end):
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         if IS_DIVISIBLE:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]             acc, l_i, m_i = forward_block_mn(
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]                             ^
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] The above exception was the direct cause of the following exception:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] Traceback (most recent call last):
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]   File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 444, in _precompile_config
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     binary = triton.compile(*compile_args, **compile_kwargs)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]   File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/triton/compiler/compiler.py", line 276, in compile
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     module = src.make_ir(options, codegen_fns, context)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]   File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/triton/compiler/compiler.py", line 113, in make_ir
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] triton.compiler.errors.CompilationError: at 154:20:
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     )
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     V_block_ptr = tl.make_block_ptr(
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         base=V,
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         shape=(KV_LEN, V_HEAD_DIM),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         strides=(stride_vn, stride_vk),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         offsets=(kv_start, 0),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         block_shape=(BLOCK_N, V_HEAD_DIM),
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]         order=(1, 0)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     )
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     offs_n = kv_start + tl.arange(0, BLOCK_N)
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0] 
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]     acc, l_i, m_i = forward_inner(
E0923 10:03:55.772000 15143 miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py:446] [0/0]                     ^
Traceback (most recent call last):
  File "/home/azureuser/a.py", line 36, in <module>
    m(q, k, v)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1292, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1087, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 530, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 933, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 675, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 708, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 220, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 643, in transform
    tracer.run()
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2776, in run
    super().run()
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 979, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 891, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2967, in RETURN_VALUE
    self._return(inst)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2952, in _return
    self.output.compile_subgraph(
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/__init__.py", line 2235, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1533, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1359, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1430, in _fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 479, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 665, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
                     ^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1425, in load
    compiled_graph = compile_fx_fn(
                     ^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 574, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 882, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1948, in compile_to_fn
    return self.compile_to_module().call
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1874, in compile_to_module
    return self._compile_to_module()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1902, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 2949, in load_by_key_path
    mod = _reload_python_module(key, path)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_azureuser/ab/cabknymdngk3mgmkmh4dqxr6zjqfhqqb3uqpssjql5l3uxbazccb.py", line 106, in <module>
    triton_tem_fused_2 = async_compile.triton('triton_', '''
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/async_compile.py", line 213, in triton
    kernel.precompile()
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 244, in precompile
    compiled_binary, launcher = self._precompile_config(
                                ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 444, in _precompile_config
    binary = triton.compile(*compile_args, **compile_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/triton/compiler/compiler.py", line 276, in compile
    module = src.make_ir(options, codegen_fns, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/triton/compiler/compiler.py", line 113, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
CompilationError: at 154:20:
    )
    V_block_ptr = tl.make_block_ptr(
        base=V,
        shape=(KV_LEN, V_HEAD_DIM),
        strides=(stride_vn, stride_vk),
        offsets=(kv_start, 0),
        block_shape=(BLOCK_N, V_HEAD_DIM),
        order=(1, 0)
    )
    offs_n = kv_start + tl.arange(0, BLOCK_N)

    acc, l_i, m_i = forward_inner(
                    ^

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Notably, the error goes away if I move the following line from generate_score_mod into __init__ instead:

    bias = (2 * bias).view(B, H, N, N)

Relevant specs:

tobiasvanderwerff commented 11 hours ago

If using torch.compile(..., mode='max-autotune', ...), I get a different error (it is also avoided by the workaround above, i.e. moving the bias computation into __init__):

Traceback (most recent call last):
  File "/home/azureuser/a.py", line 36, in <module>
    m(q, k, v)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1292, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1087, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 530, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 933, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 675, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 708, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 220, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 643, in transform
    tracer.run()
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2776, in run
    super().run()
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 979, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 891, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2967, in RETURN_VALUE
    self._return(inst)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2952, in _return
    self.output.compile_subgraph(
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/__init__.py", line 2235, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1272, in compile_fx
    return compile_fx(
           ^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1533, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1359, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1430, in _fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 479, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 665, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
                     ^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1425, in load
    compiled_graph = compile_fx_fn(
                     ^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 574, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 863, in fx_codegen_and_compile
    graph.run(*example_inputs)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 780, in run
    return super().run(*args)
           ^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1357, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1023, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1020, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/lowering.py", line 363, in wrapped
    out = decomp_fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/kernel/flex_attention.py", line 913, in flex_attention
    autotune_select_algorithm(
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1729, in autotune_select_algorithm
    return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1224, in __call__
    inputs_key = create_inputs_key(input_nodes)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1138, in create_inputs_key
    return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes])
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1138, in <listcomp>
    return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes])
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1698, in key_of
    node.get_stride(),
    ^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/ir.py", line 6276, in __getattr__
    fn = getattr(self.data, name)
         ^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'View' object has no attribute 'get_stride'
  target: flex_attention
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.bfloat16, size=[100, 12, 128, 64], stride=[98304, 8192, 64, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg2_1', layout=FixedLayout('cuda', torch.bfloat16, size=[100, 12, 128, 64], stride=[98304, 8192, 64, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='arg3_1', layout=FixedLayout('cuda', torch.bfloat16, size=[100, 12, 128, 64], stride=[98304, 8192, 64, 1]))
  ))
  args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
  args[4]: (TensorBox(StorageBox(
    ComputedBuffer(name='buf2', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
      'cuda',
      torch.int32,
      def inner_fn(index):
          _, _, _ = index
          tmp0 = ops.constant(1, torch.int32)
          return tmp0
      ,
      ranges=[1, 1, 1],
      origin_node=full,
      origins=OrderedSet([full])
    ))
  )), TensorBox(StorageBox(
    ComputedBuffer(name='buf3', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
      'cuda',
      torch.int32,
      def inner_fn(index):
          _, _, _, _ = index
          tmp0 = ops.constant(0, torch.int32)
          return tmp0
      ,
      ranges=[1, 1, 1, 1],
      origin_node=full_default,
      origins=OrderedSet([full_default])
    ))
  )), None, None, TensorBox(StorageBox(
    ComputedBuffer(name='buf4', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
      'cuda',
      torch.int32,
      def inner_fn(index):
          _, _, _ = index
          tmp0 = ops.load(buf0, 0)
          tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int32)
          tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
          return tmp2
      ,
      ranges=[1, 1, 1],
      origin_node=convert_element_type,
      origins=OrderedSet([sum_1, convert_element_type])
    ))
  )), TensorBox(StorageBox(
    ComputedBuffer(name='buf5', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
      'cuda',
      torch.int32,
      def inner_fn(index):
          _, _, _, _ = index
          tmp0 = ops.index_expr(0, dtype=torch.int16)
          tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)
          tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
          return tmp2
      ,
      ranges=[1, 1, 1, 1],
      origin_node=convert_element_type_1,
      origins=OrderedSet([convert_element_type_1, sort])
    ))
  )), None, None, 1073741824, 1073741824, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
  args[5]: 0.125
  args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': False}
  args[7]: (TensorBox(
    View(
      StorageBox(
        Pointwise(
          'cuda',
          torch.bfloat16,
          def inner_fn(index):
              i0, i1, i2, i3 = index
              tmp0 = ops.load(arg0_1, i3 + 12 * i2 + 1536 * i1 + 196608 * i0)
              tmp1 = ops.constant(2, torch.bfloat16)
              tmp2 = tmp0 * tmp1
              return tmp2
          ,
          ranges=[100, 128, 128, 12],
          origin_node=mul,
          origins=OrderedSet([mul])
        )
      ),
      size=[100, 12, 128, 128],
      reindex=lambda i0, i1, i2, i3: [ModularIndexing(196608*i0 + 16384*i1 + 128*i2 + i3, 196608, 100), ModularIndexing(16384*i1 + 128*i2 + i3, 1536, 128), ModularIndexing(16384*i1 + 128*i2 + i3, 12, 128), ModularIndexing(16384*i1 + 128*i2 + i3, 1, 12)],
      origins=OrderedSet([view, mul])
    )
  ),)
  args[8]: ()

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True