triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

BUG in flash attention kernel #2626

Open zhanglei1172 opened 10 months ago

zhanglei1172 commented 10 months ago
import math

import pytest
import torch
import triton
import triton.language as tl

# from triton.runtime.interpreter import TensorHandle

@triton.jit
def _fwd_kernel2(
    Q,
    K,
    V,
    # POS_EMB2,
    sm_scale,
    Out,
    # stride_e0,
    # stride_e1,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vh,
    stride_vk,
    stride_vn,
    stride_oh,
    stride_om,
    stride_on,
    N_CTX,
    N_CTX_2,
    HIDDEN_DIM,
    # emb_len,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    qvk_offset = off_hz * stride_qh
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, HIDDEN_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(HIDDEN_DIM, N_CTX_2),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX_2, HIDDEN_DIM),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    # l_i = tl.where(offs_m < N_CTX, l_i, 1)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # credits to: Adam P. Goucher (https://github.com/apgoucher):
    # scale sm_scale by 1/log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr, padding_option="zero", boundary_check=(0, 1))
    q = (q * qk_scale).to(K.dtype.element_ty)
    lo = 0
    hi = N_CTX_2
    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        k = tl.load(K_block_ptr, padding_option="zero", boundary_check=(0, 1))
        v = tl.load(V_block_ptr, padding_option="zero", boundary_check=(0, 1))
        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # if IS_CAUSAL:
        #     qk = tl.where(
        #         offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")
        #     )
        qk += tl.dot(q, k, allow_tf32=True)
        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # write back l and m
    acc = acc / l_i[:, None]
    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, HIDDEN_DIM),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    tl.store(O_block_ptr, acc.to(K.dtype.element_ty), boundary_check=(0, 1))

def forward(q, k, v, sm_scale, sequence_parallel=False):
    # only support for Ampere now
    capability = torch.cuda.get_device_capability()
    if capability[0] < 8:
        raise RuntimeError(
            "Flash attention currently only supported for compute capability >= 80"
        )
    BLOCK_M = 128
    BLOCK_N = 64
    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    BLOCK_HEADDIM = max(triton.next_power_of_2(Lk), 16)
    # assert Lk in {16, 32, 64, 128}
    o = torch.empty_like(q)
    grid = (triton.cdiv(q.shape[0], BLOCK_M), q.shape[1], 1)

    num_warps = 4 if Lk <= 64 else 8
    _fwd_kernel2[grid](
        q,
        k,
        v,
        sm_scale,
        o,
        q.stride(1),
        q.stride(0),
        q.stride(2),
        k.stride(1),
        k.stride(0),
        k.stride(2),
        v.stride(1),
        v.stride(0),
        v.stride(2),
        o.stride(1),
        o.stride(0),
        o.stride(2),
        q.shape[0],
        k.shape[0],
        Lk,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_DMODEL=BLOCK_HEADDIM,
        num_warps=num_warps,
        num_stages=4,
    )

    return o

# @pytest.mark.parametrize("batch, seqlen_q, nheads, d,", [(1, 2, 1024, 64)])
# @pytest.mark.parametrize("causal", [True])
@torch.no_grad()
def test_op(seqlen, nheads, d, dtype=torch.float16, q_ctx=None):
    if q_ctx == None:
        q_ctx = seqlen
    device = "cuda"
    assert d <= 128, "FlashAttention only support head dimensions up to 128"
    torch.manual_seed(20)
    q = torch.empty((q_ctx, nheads, d), dtype=dtype, device="cuda").normal_(
        mean=0.0, std=0.5
    )
    k = torch.empty((seqlen, nheads, d), dtype=dtype, device="cuda").normal_(
        mean=0.0, std=0.5
    )
    v = torch.empty((seqlen, nheads, d), dtype=dtype, device="cuda").normal_(
        mean=0.0, std=0.5
    )

    sm_scale = 0.5

    tri_out = forward(
        q.to(device),
        k.to(device),
        v.to(device),
        # pos_emb.to(device),
        sm_scale=sm_scale,
    ).to(dtype)
    # reference implementation

    dots = torch.matmul(q.transpose(0,1), k.transpose(0,1).transpose(-1, -2)) * sm_scale
    attn = torch.softmax(
        dots.float(), axis=-1
    ).half()
    ref_out = torch.matmul(attn, v.transpose(0,1).half()).detach().to(dtype).to(device).transpose(0,1)
    # triton implementation

    # compare
    # assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
    print("max diff: ", (ref_out - tri_out).abs().max().item())

if __name__ == "__main__":
    test_op(4096, 8, 32, torch.float16)
    test_op(4096, 8, 32, torch.float16, q_ctx=100)
    test_op(4096, 800, 32, torch.float16, q_ctx=10)
max diff:  3.814697265625e-05
max diff:  3.0517578125e-05
max diff:  3.0517578125e-05

This version gives correct output, but when I change the QKV layout from (seqlen, nheads, d) to (nheads, seqlen, d), I get wrong results.
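For reference, the two layouts are related by permuting the first two dimensions; a minimal sketch (not part of either script):

# Hedged illustration: relation between the (seqlen, nheads, d) layout above
# and the (nheads, seqlen, d) layout used in the second script below.
import torch

seqlen, nheads, d = 4096, 8, 32
q_snd = torch.randn(seqlen, nheads, d, dtype=torch.float16)   # first layout
q_nsd = q_snd.permute(1, 0, 2).contiguous()                   # second layout
assert q_nsd.shape == (nheads, seqlen, d)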

import math

import pytest
import torch
import triton
import triton.language as tl

# from triton.runtime.interpreter import TensorHandle

@triton.jit
def _fwd_kernel2(
    Q,
    K,
    V,
    # POS_EMB2,
    sm_scale,
    Out,
    # stride_e0,
    # stride_e1,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vh,
    stride_vk,
    stride_vn,
    stride_oh,
    stride_om,
    stride_on,
    N_CTX,
    N_CTX_2,
    HIDDEN_DIM,
    # emb_len,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    qvk_offset = off_hz * stride_qh
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, HIDDEN_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(HIDDEN_DIM, N_CTX_2),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX_2, HIDDEN_DIM),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    # l_i = tl.where(offs_m < N_CTX, l_i, 1)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # credits to: Adam P. Goucher (https://github.com/apgoucher):
    # scale sm_scale by 1/log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr, padding_option="zero", boundary_check=(0, 1))
    q = (q * qk_scale).to(K.dtype.element_ty)
    lo = 0
    hi = N_CTX_2
    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        k = tl.load(K_block_ptr, padding_option="zero", boundary_check=(0, 1))
        v = tl.load(V_block_ptr, padding_option="zero", boundary_check=(0, 1))
        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # if IS_CAUSAL:
        #     qk = tl.where(
        #         offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")
        #     )
        qk += tl.dot(q, k, allow_tf32=True)
        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # write back l and m
    acc = acc / l_i[:, None]
    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, HIDDEN_DIM),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    tl.store(O_block_ptr, acc.to(K.dtype.element_ty), boundary_check=(0, 1))

def forward(q, k, v, sm_scale, sequence_parallel=False):
    # only support for Ampere now
    capability = torch.cuda.get_device_capability()
    if capability[0] < 8:
        raise RuntimeError(
            "Flash attention currently only supported for compute capability >= 80"
        )
    BLOCK_M = 128
    BLOCK_N = 64
    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    BLOCK_HEADDIM = max(triton.next_power_of_2(Lk), 16)
    # assert Lk in {16, 32, 64, 128}
    o = torch.empty_like(q)
    grid = (triton.cdiv(q.shape[1], BLOCK_M), q.shape[0], 1)

    num_warps = 4 if Lk <= 64 else 8
    _fwd_kernel2[grid](
        q,
        k,
        v,
        sm_scale,
        o,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        q.shape[1],
        k.shape[1],
        Lk,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_DMODEL=BLOCK_HEADDIM,
        num_warps=num_warps,
        num_stages=4,
    )

    return o

# @pytest.mark.parametrize("batch, seqlen_q, nheads, d,", [(1, 2, 1024, 64)])
# @pytest.mark.parametrize("causal", [True])
@torch.no_grad()
def test_op(nheads, seqlen_q, d, dtype=torch.float16, q_ctx=None):
    if q_ctx == None:
        q_ctx = seqlen_q
    device = "cuda"
    assert d <= 128, "FlashAttention only support head dimensions up to 128"
    torch.manual_seed(20)
    q = torch.empty((nheads, q_ctx, d), dtype=dtype, device="cuda").normal_(
        mean=0.0, std=0.5
    )
    k = torch.empty((nheads, seqlen_q, d), dtype=dtype, device="cuda").normal_(
        mean=0.0, std=0.5
    )
    v = torch.empty((nheads, seqlen_q, d), dtype=dtype, device="cuda").normal_(
        mean=0.0, std=0.5
    )

    sm_scale = 0.5

    tri_out = forward(
        q.to(device),
        k.to(device),
        v.to(device),
        # pos_emb.to(device),
        sm_scale=sm_scale,
    ).to(dtype)
    # reference implementation

    dots = torch.matmul(q, k.transpose(-1, -2)) * sm_scale
    attn = torch.softmax(
        dots.float(), axis=-1
    ).half()
    ref_out = torch.matmul(attn, v.half()).detach().to(dtype).to(device)
    # triton implementation

    # compare
    # assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
    print("max diff: ", (ref_out - tri_out).abs().max().item())

if __name__ == "__main__":
    test_op(8, 4096, 32, torch.float16)
    test_op(8, 4096, 32, torch.float16, q_ctx=100)
    test_op(800, 4096, 32, torch.float16, q_ctx=10)
    test_op(8, 4096, 32, torch.float16, q_ctx=4095)
max diff:  4.57763671875e-05
max diff:  0.060211181640625
max diff:  0.07171630859375
max diff:  0.0134429931640625

It is also very strange that the error differs so much between q_ctx = 4095 and q_ctx = 4096 (0.0134429931640625 vs. 4.57763671875e-05).
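One way to narrow this down (a hypothetical diagnostic helper, not part of the report) is to check which query rows disagree; note that 4095 is not a multiple of BLOCK_M = 128, so its last Q tile is only partially filled and relies on the boundary handling:

# Hypothetical diagnostic (assumes outputs shaped (nheads, q_ctx, d) as in the
# second test_op): find the query row where the reference and the Triton
# output disagree the most.
import torch

def worst_row(ref_out: torch.Tensor, tri_out: torch.Tensor, block_m: int = 128):
    # Per-row maximum absolute error, shape (nheads, q_ctx).
    row_err = (ref_out - tri_out).abs().amax(dim=-1)
    head, row = divmod(row_err.argmax().item(), row_err.shape[1])
    # Rows with index >= (q_ctx // block_m) * block_m (if any) fall in the
    # final, partially filled BLOCK_M tile handled by boundary_check/padding.
    print("worst row", row, "in head", head, "err", row_err.max().item())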

yiakwy-xpu-ml-framework-team commented 6 months ago

Hi @zhanglei1172, the standard input format is (batches, nheads, seqlen, d).
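For instance, a 3D (nheads, seqlen, d) tensor can be lifted into that layout by adding a batch dimension; a minimal sketch, assuming plain PyTorch tensors (not code from this issue):

# Hedged example: adding a leading batch dimension so that a
# (nheads, seqlen, d) tensor matches the (batches, nheads, seqlen, d) layout.
import torch

nheads, seqlen, d = 8, 4096, 32
q3 = torch.randn(nheads, seqlen, d, dtype=torch.float16)
q4 = q3.unsqueeze(0)          # (1, nheads, seqlen, d)
assert q4.shape == (1, nheads, seqlen, d)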

Note that N_CTX_2 (batches x hiddens x sequence_length) is introduced in the PR to support Hopper TMA. I think the kernel you used is mixing the two:

Look at the CUDA blocks created:

grid = (triton.cdiv(q.shape[1], BLOCK_M), q.shape[0], 1)

You have created 1 x nheads x (sequence_length / BLOCK_M) blocks.

# Q is like a two-dimensional matrix, with rows **(1 x nheads x sequence_length/BLOCK_M)** and columns **emb_d**, s.t. Q = [q1, q2, ...]^T

# K is like another two-dimensional matrix (we don't need to transpose it) with rows **(1 x nheads x sequence_length/BLOCK_N)** and columns **emb_d**, s.t. K = [k1, k2, ...]^T

To calculate causal attention, for each q_i we load:

- iter 0: k1, v1
- iter 1: k2, v2
- ...
- iter hi: k_hi, v_hi (BLOCK_N, emb_d)
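Concretely, the launch described above can be sketched on the host side like this (a rough illustration using the shapes from the second script, with the BLOCK_M/BLOCK_N values hard-coded in forward, not the kernel itself):

# Rough host-side sketch of the launch geometry (hedged illustration).
# Each program instance handles one (BLOCK_M x d) tile of Q for one head and
# sweeps over all (BLOCK_N x d) tiles of K/V in its inner loop.
import triton

nheads, seqlen_q, seqlen_kv = 8, 4096, 4096
BLOCK_M, BLOCK_N = 128, 64
grid = (triton.cdiv(seqlen_q, BLOCK_M), nheads, 1)
print("programs launched:", grid[0] * grid[1])         # nheads * ceil(seqlen_q / BLOCK_M)
print("K/V tiles per program:", triton.cdiv(seqlen_kv, BLOCK_N))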

Can you check the source code version you used and explain N_CTX_2 in your case?

zhanglei1172 commented 6 months ago

Hi, I use triton==2.1.0. N_CTX_2 means the sequence_length of K/V, and N_CTX means the sequence_length of Q.
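In other words, the intended computation is ordinary cross attention with N_CTX != N_CTX_2; a minimal PyTorch sketch with made-up shapes (mirroring the reference block in test_op, not taken from the kernel):

# Hedged reference for cross attention: Q has N_CTX rows, K/V have N_CTX_2 rows.
import torch

nheads, N_CTX, N_CTX_2, d = 8, 100, 4096, 32
sm_scale = 0.5
q = torch.randn(nheads, N_CTX, d, dtype=torch.float16, device="cuda")
k = torch.randn(nheads, N_CTX_2, d, dtype=torch.float16, device="cuda")
v = torch.randn(nheads, N_CTX_2, d, dtype=torch.float16, device="cuda")
attn = torch.softmax((q @ k.transpose(-1, -2)).float() * sm_scale, dim=-1).half()
out = attn @ v        # (nheads, N_CTX, d)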

I checked the tutorials: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html and https://github.com/openai/triton/blob/main/python/triton/ops/flash_attention.py#L42. Both of them use transpose.

And I tried changing my second version above:

    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(N_CTX_2, HIDDEN_DIM),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )

    qk += tl.dot(q, tl.trans(k), allow_tf32=True)

Results:

max diff:  0.039794921875
max diff:  0.0565185546875
max diff:  0.06787109375
max diff:  0.04058837890625
yiakwy-xpu-ml-framework-team commented 6 months ago

@zhanglei1172 The code you referred to is the updated version with patch PR #2336. The definition of N_CTX_2 is here:

https://github.com/openai/triton/blob/72cba380aa64336c90cd98fc9c74cc5a5b205e05/python/triton/ops/flash_attention.py#L390

q.shape[0] * q.shape[1] * q.shape[2]

It is not the sequence_length of K or V; more precisely, it treats Q, K, and V as two-dimensional matrices of shapes (-1, d), (d, -1), and (d, -1), where axis -1 spans q.shape[0] * q.shape[1] * q.shape[2] values along the sequence-length direction.
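In other words, roughly the following flattened view (a hedged illustration with made-up shapes, not the PR code):

# Hedged illustration of the flattened 2D view described above:
# Q viewed as (batch * nheads * seqlen, d), K/V viewed as (d, batch * nheads * seqlen).
import torch

batch, nheads, seqlen, d = 1, 8, 4096, 32
q = torch.randn(batch, nheads, seqlen, d)
k = torch.randn(batch, nheads, seqlen, d)
flattened_len = q.shape[0] * q.shape[1] * q.shape[2]   # the quantity quoted above
q2d = q.reshape(flattened_len, d)                      # (-1, d)
k2d = k.reshape(flattened_len, d).transpose(0, 1)      # (d, -1)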

So I think what is used here is a mix of Flash Attention v2 (online softmax with memory-efficient attention, with Q loaded first in the outer loop) and the attention variant that supports Hopper TMA.

zhanglei1172 commented 6 months ago

@yiakwy-xpu-ml-framework-team You mentioned that N_CTX_2 (batches x hiddens x sequence_length) was introduced in the PR to support Hopper TMA, but the Triton version I am using (and referring to) does not contain that PR. Instead, I introduced N_CTX_2 myself (I did not refer to the code in that Hopper TMA PR) to handle cross attention, where Q and K/V have different sequence lengths. The code I used is intended to differ from the original attention in two main ways: the input shape is permuted, and the sequence lengths of Q and K/V are different.

So ultimately I want to know whether the current Triton supports this kind of permutation or cross attention, and if it does, how I can modify the original code to get correct results.