triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Extra transposes cause 8x speedup for sequence-parallel Flash Attention backward #1806

Open szmigacz opened 1 year ago

szmigacz commented 1 year ago

Rewriting the tl.dot that computes dq in the sequence-parallel Flash Attention backward kernel from dq = tl.dot(ds, k) to dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds))) results in a 7.8x speedup.
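The two expressions are mathematically equivalent via the identity (A·B)^T = B^T·A^T. A quick PyTorch sanity check with hypothetical 128x128 tiles (my own illustration, not part of the repro) shows this:

import torch

# Hypothetical tiles standing in for ds (BLOCK_M x BLOCK_N) and k (BLOCK_N x BLOCK_DMODEL).
ds = torch.rand(128, 128)
k = torch.rand(128, 128)

ref = ds @ k                   # dq = tl.dot(ds, k)
alt = (k.t() @ ds.t()).t()     # dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds)))
print(torch.allclose(ref, alt, atol=1e-5))  # True: same result, different formulation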

Minimal Repro:

import time

import torch
import triton
import triton.language as tl

@triton.jit
def _bwd_kernel(
    Q, K, V,
    DO,
    DQ, DK, DV,
    L, M,
    D,
    softmax_scale,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_ok,
    stride_dqz, stride_dqh, stride_dqm, stride_dqk,
    stride_dkz, stride_dkh, stride_dkn, stride_dkk,
    stride_dvz, stride_dvh, stride_dvk, stride_dvn,
    stride_lz, stride_lh, stride_lm,
    seq_len,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TRANSPOSED: tl.constexpr,
):
    off_z = tl.program_id(0)
    off_h = tl.program_id(1)
    start_n = tl.program_id(2)
    off_hz = off_z * tl.num_programs(1) + off_h

    # offset pointers for batch/head
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    DK += off_z * stride_dkz + off_h * stride_dkh
    DV += off_z * stride_dvz + off_h * stride_dvh
    Q += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_oz + off_h * stride_oh
    DQ += off_z * stride_dqz + off_h * stride_dqh

    # initialize row/col offsets
    offs_qm = tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_DMODEL)

    # initialize pointers to value-like data
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
    v_ptrs = V + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)
    do_ptrs = DO + (offs_qm[:, None] * stride_om + offs_k[None, :] * stride_ok)
    dq_ptrs = DQ + (
        offs_qm[:, None] * stride_dqm + offs_k[None, :] * stride_dqk
    )

    D_ptrs = D + off_hz * stride_lh
    m_ptrs = M + off_hz * stride_lh

    dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)

    k = tl.load(k_ptrs)
    v = tl.load(v_ptrs)

    num_block = tl.cdiv(seq_len, BLOCK_M)
    for start_m in range(0, num_block * BLOCK_M, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m_curr = start_m + offs_m

        q = tl.load(q_ptrs)
        qk = tl.dot(q, tl.trans(k))
        qk *= softmax_scale
        m = tl.load(m_ptrs + offs_m_curr)
        p = tl.exp(qk - m[:, None])
        do = tl.load(do_ptrs)
        dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
        Di = tl.load(D_ptrs + offs_m_curr)
        dp = tl.dot(do, tl.trans(v))
        ds = (p * (dp - Di[:, None]) * softmax_scale).to(Q.dtype.element_ty)
        dk += tl.dot(tl.trans(ds), q)

        if not TRANSPOSED:
            # straightforward formulation: dq = ds @ k
            dq = tl.dot(ds, k)
        else:
            # equivalent formulation using (A @ B)^T = B^T @ A^T
            dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds)))

        tl.atomic_add(dq_ptrs, dq)

        dq_ptrs += BLOCK_M * stride_dqm
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_om

    # write back the accumulated dk and dv for this BLOCK_N tile
    dv_ptrs = DV + (
        offs_n[:, None] * stride_dvk + offs_k[None, :] * stride_dvn
    )
    dk_ptrs = DK + (
        offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk
    )
    tl.store(dv_ptrs, dv)
    tl.store(dk_ptrs, dk)

def run(transposed):
    device = torch.device('cuda')
    dtype = torch.float16
    batch = 2
    num_heads = 16
    seq_len = 2048
    d_head = 128
    softmax_scale = 1.1
    iterations = 20
    warmup = 5

    q = torch.rand(batch, num_heads, seq_len, d_head, device=device, dtype=dtype)
    k = torch.rand(batch, num_heads, seq_len, d_head, device=device, dtype=dtype)
    v = torch.rand(batch, num_heads, seq_len, d_head, device=device, dtype=dtype)

    do = torch.rand_like(q)
    dq = torch.zeros_like(q, dtype=torch.float32)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    l = torch.ones(batch, num_heads, seq_len, device=device, dtype=torch.float32)
    m = torch.ones(batch, num_heads, seq_len, device=device, dtype=torch.float32)
    delta = torch.ones_like(l)

    grid = lambda META: (
        batch,
        num_heads,
        triton.cdiv(seq_len, META['BLOCK_N']),
    )

    for i in range(iterations):
        if i == warmup:
            # start timing only after the warmup iterations
            torch.cuda.synchronize()
            start = time.perf_counter()

        bwd = _bwd_kernel[grid](
            q, k, v,
            do, 
            dq, dk, dv,
            l, m,
            delta,
            softmax_scale,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            do.stride(0), do.stride(1), do.stride(2), do.stride(3),
            dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
            dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
            dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
            l.stride(0), l.stride(1), l.stride(2),
            seq_len,
            BLOCK_DMODEL=d_head,
            TRANSPOSED=transposed,
            BLOCK_M=128,
            BLOCK_N=128,
            num_warps=8,
            num_stages=1,
        )
    torch.cuda.synchronize()
    stop = time.perf_counter()
    elapsed = 1000 * (stop - start) / (iterations - warmup)
    print(f'Transposed: {transposed}, elapsed: {elapsed}')

run(False)
run(True)

The code above executes the backward pass first with TRANSPOSED set to False, then with TRANSPOSED set to True, and reports the per-iteration elapsed time in milliseconds:

Transposed: False, elapsed: 48.10845733154565
Transposed: True, elapsed: 6.173196334081392

The code with the extra transposes is 7.8x faster on an A100.
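For scale, a rough back-of-envelope estimate of achieved matmul throughput (my own arithmetic, assuming the five tl.dot calls per inner-loop iteration account for essentially all of the work):

# Rough FLOP estimate for one kernel launch in the repro above (not measured).
batch, heads, seq_len, block = 2, 16, 2048, 128
tiles = batch * heads * (seq_len // block) * (seq_len // block)  # outer x inner tile pairs
flops_per_tile = 5 * 2 * block**3                                # five 128x128x128 tl.dot calls
total_tflop = tiles * flops_per_tile / 1e12                      # ~0.17 TFLOP per launch

print(total_tflop / 48.1e-3)   # ~3.6 TFLOP/s for TRANSPOSED=False
print(total_tflop / 6.17e-3)   # ~27.8 TFLOP/s for TRANSPOSED=True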

Triton: latest main (4be1c94b1f0d1f07f003ee5d38af943070110688)
GPU: NVIDIA A100-SXM4-80GB
nvcc: V11.7.99
PyTorch: 2.0.1 (pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel container from PyTorch Docker Hub)

jt-zhang commented 6 months ago

I am curious about the cause of this speedup; could someone explain why the transposed version is so much faster?