triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License
13.1k stars 1.6k forks source link

Triton flash attention for inference with len(q)=1, IndexError: map::at #1204

Open linxihui opened 1 year ago

linxihui commented 1 year ago

I modified the flash attention code to handle the case where q has a sequence length of 1 while k and v are longer. This is critical for text generation at run time, where one token is generated at a time.

tl.dot does not support matrix-vector products, and a matrix-vector product cannot use Tensor Cores anyway. I therefore replaced it with broadcasting + tl.sum, but then ran into the following error. It isn't very informative to me. Can anyone help?

Traceback (most recent call last):
  File "/workspace/triton/python/tutorials/fused_attn_inference.py", line 163, in <module>
    test_op(3, 2, 2048, 64, 1)
  File "/workspace/triton/python/tutorials/fused_attn_inference.py", line 158, in test_op
    tri_out = attention(q, k, v, sm_scale)
  File "/workspace/triton/python/tutorials/fused_attn_inference.py", line 125, in attention
    _fwd_kernel[grid](
  File "/workspace/triton/python/triton/code_gen.py", line 999, in __call__
    return self.kernel(*wargs, **kwargs, grid=self.grid)
  File "/workspace/triton/python/triton/code_gen.py", line 988, in __call__
    return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names,
  File "/workspace/triton/python/triton/code_gen.py", line 956, in add_to_cache
    return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages,
  File "/workspace/triton/python/triton/code_gen.py", line 1285, in _warmup
    binary = self._compile(**compile)
  File "/workspace/triton/python/triton/code_gen.py", line 1320, in _compile
    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, extern_libs)
IndexError: map::at

Here is my code:


import pytest
import torch

import triton
import triton.language as tl

# Causal attention forward kernel (online-softmax, flash-attention style) in
# which the query block may be shorter than the key/value sequence: query row i
# may attend to key positions j <= i + PREVIOUS_LEN (see the tl.where mask).
#
# Launch grid: axis 0 selects a block of BLOCK_M consecutive query rows, axis 1
# selects one flattened (batch, head) pair, off_hz = off_z * H + off_h.
#
# Both products (q @ k^T and p @ v) are computed with broadcast-multiply +
# tl.sum instead of tl.dot.
#
# Outputs:
#   Out -- the softmax-normalized weighted sum of V rows for each query row.
#   L   -- final running softmax denominator l_i per query row.
#   M   -- final running row maximum m_i per query row.
#
# NOTE(review): TMP, L and M are addressed as off_hz * N_CTX + offs_m, i.e.
# with a row stride of N_CTX (the key length, not the query length), so each
# buffer must hold at least Z * H * N_CTX float32 elements.
# NOTE(review): none of the Q/K/V loads or Out/L/M stores are masked. This
# presumably requires the query length to be a multiple of BLOCK_M and the key
# length to be a multiple of BLOCK_N -- TODO confirm or add masks.
@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    TMP, L, M,  # NOTE: TMP is a scratchpad buffer used only to work around a compiler bug (see the store/load in the loop)
    Out,
    stride_qz, stride_qh, stride_qm, stride_qd,
    stride_kz, stride_kh, stride_kn, stride_kd,
    stride_vz, stride_vh, stride_vn, stride_vd,
    stride_oz, stride_oh, stride_om, stride_od,
    Z, H, N_CTX,
    PREVIOUS_LEN,  # aka past_len: query row i may see keys 0 .. i + PREVIOUS_LEN
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # which query block / which (batch, head) this program instance handles
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_h = off_hz % H
    off_z = off_hz // H
    # move the base pointers to the start of this (batch, head) slice
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
    off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
    off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # pointer into the TMP scratchpad (row stride N_CTX, see header note)
    t_ptrs = TMP + off_hz * N_CTX + offs_m
    # running row max (m_i), running softmax denominator (l_i) and the
    # normalized output accumulator, all kept in float32
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)

    # Visit only the key tiles that can contain a position <= the last query
    # row of this block + PREVIOUS_LEN; later keys are fully masked anyway.
    for start_n in range(0, (start_m + 1) * BLOCK_M + PREVIOUS_LEN, BLOCK_N):
        # compiler hint: start_n is always a multiple of BLOCK_N
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs + start_n * stride_kn)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # qk += tl.dot(q, k, trans_b=True)
        # [BLOCK_M, 1, D] * [1, BLOCK_N, D] summed over D == q @ k^T
        qk += tl.sum(q[:, None, :]*k[None, :, :], 2)  # replace the above tl.dot
        qk *= sm_scale
        # causal mask shifted by PREVIOUS_LEN: key j is visible to query i iff i + PREVIOUS_LEN >= j
        qk += tl.where(offs_m[:, None] + PREVIOUS_LEN >= (start_n + offs_n[None, :]), 0, float("-inf"))
        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        # combine the previous statistics with this tile's, rescaling both to the new max
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p so that its rows are normalized by the new denominator
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # rescale the previously normalized acc to the new max / denominator
        acc_scale = l_i / l_i_new * alpha
        tl.store(t_ptrs, acc_scale)
        acc_scale = tl.load(t_ptrs)  # BUG: have to store and immediately load
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs + start_n * stride_vn)
        p = p.to(v.dtype)
        # acc += tl.dot(p, v)
        # [BLOCK_M, BLOCK_N, 1] * [1, BLOCK_N, D] summed over BLOCK_N == p @ v
        acc += tl.sum(p[:, :, None] * v[None, :, :], 1)  # replace the above tl.dot
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # offs_m from above is reused here rather than recomputed:
    # start_m = tl.program_id(0)
    # offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m (row stride N_CTX, see header note)
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_i)
    tl.store(m_ptrs, m_i)
    # initialize pointers to output
    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)

def attention(q, k, v, sm_scale):
    """Run causal attention where ``q`` may be shorter than ``k``/``v``.

    Query row ``i`` attends to key positions ``j <= i + (k_len - q_len)``.
    This lets a short query block (e.g. a single decoding step) attend to
    every earlier key.

    Args:
        q: query tensor indexed as ``(Z, H, q_len, D)``.
        k: key tensor indexed as ``(Z, H, k_len, D)`` with ``k_len >= q_len``.
        v: value tensor indexed as ``(Z, H, k_len, D)``.
        sm_scale: factor applied to ``q @ k^T`` before the softmax.

    Returns:
        A tensor created with ``torch.empty_like(q)`` and filled with the
        attention output.

    Raises:
        AssertionError: if the last dims of ``q`` and ``k`` differ, or if
            ``k`` is shorter than ``q``.

    Note:
        The kernel loads K/V in unmasked tiles of ``BLOCK_N`` rows.
        ``k_len`` should therefore be a multiple of ``BLOCK_N`` (128).
    """
    BLOCK_M = 1
    BLOCK_N = 128

    Lq, Lk = q.shape[-1], k.shape[-1]
    assert Lq == Lk
    q_len, k_len = q.shape[2], k.shape[2]
    # PREVIOUS_LEN (= k_len - q_len) must be non-negative for the causal offset.
    assert k_len >= q_len, "k/v must be at least as long as q"

    o = torch.empty_like(q)
    num_rows = q.shape[0] * q.shape[1]  # one row per (batch, head) pair
    grid = (triton.cdiv(q_len, BLOCK_M), num_rows)
    # The kernel addresses TMP / L / M as ``off_hz * N_CTX + offs_m`` with
    # N_CTX = k_len (see the launch below). Each (batch, head) row must
    # therefore be k_len long. Sizing the rows by q_len made every row after
    # the first write past the end of the buffer whenever q_len < k_len.
    tmp = torch.empty((num_rows, k_len), device=q.device, dtype=torch.float32)
    L = torch.empty_like(tmp)
    m = torch.empty_like(tmp)

    BLOCK_DMODEL = q.shape[-1]
    _fwd_kernel[grid](
        q, k, v, sm_scale,
        tmp, L, m,
        o,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        q.shape[0], q.shape[1], k_len,
        k_len - q_len,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        BLOCK_DMODEL=BLOCK_DMODEL,
        num_warps=4,
        num_stages=1,
    )

    return o

@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, Q_LEN', [(3, 2, 2048, 64, 1)])
def test_op(Z, H, N_CTX, D_HEAD, Q_LEN, dtype=torch.float16):
    """Compare the Triton kernel with a PyTorch causal-attention reference.

    The ``Q_LEN`` query rows correspond to the last ``Q_LEN`` positions of an
    ``N_CTX``-long key/value sequence.
    """
    torch.manual_seed(20)

    def _normal(seq_len):
        # q, k and v share the same shape pattern and distribution
        return torch.empty((Z, H, seq_len, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)

    # drawn in the order q, k, v so the RNG stream is consumed identically
    q = _normal(Q_LEN)
    k = _normal(N_CTX)
    v = _normal(N_CTX)
    sm_scale = 0.3

    # reference implementation: hide key j from the query at absolute row r
    # whenever j > r (strict upper triangle, last Q_LEN rows only)
    hidden = torch.ones((N_CTX, N_CTX), device="cuda").triu(diagonal=1)[-Q_LEN:].bool()
    scores = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    scores = scores.masked_fill_(hidden, float("-inf"))
    weights = torch.softmax(scores.float(), dim=-1).half()
    expected = torch.matmul(weights, v)

    # triton implementation, then compare
    actual = attention(q, k, v, sm_scale)
    triton.testing.assert_almost_equal(expected, actual)

if __name__ == "__main__":
    # Run the single parametrized case when executed as a script. The guard
    # stops a plain import (e.g. pytest collection) from launching the
    # CUDA kernel as a module-level side effect.
    test_op(3, 2, 2048, 64, 1)

BTW, I am using an A100 with CUDA 11.6.

resorcap commented 1 year ago

I encountered the same problem. Has it been solved?

manman-ren commented 11 months ago

I can't repro this issue with Triton 2.1.0 (OSS hash from Nov 15).