Open szmigacz opened 1 year ago
Rewriting tl.dot to compute dq in sequence-parallel Flash Attention Backward from: dq = tl.dot(ds, k) into: dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds))) causes 7.8x speedup
dq
dq = tl.dot(ds, k)
dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds)))
Minimal Repro:
import time  # was missing in the original repro: time.perf_counter() raised NameError

import torch
import triton
import triton.language as tl


@triton.jit
def _bwd_kernel(
    Q, K, V, DO, DQ, DK, DV,
    L, M, D,
    softmax_scale,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_ok,
    stride_dqz, stride_dqh, stride_dqm, stride_dqk,
    stride_dkz, stride_dkh, stride_dkn, stride_dkk,
    stride_dvz, stride_dvh, stride_dvk, stride_dvn,
    stride_lz, stride_lh, stride_lm,
    seq_len,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TRANSPOSED: tl.constexpr,
):
    """Sequence-parallel Flash Attention backward kernel (minimal repro).

    Each program instance handles one (batch, head, key-block) triple:
    it loads its K/V tile once, loops over all query blocks accumulating
    dK/dV locally, and scatters dQ contributions via tl.atomic_add.
    The TRANSPOSED constexpr selects between two mathematically identical
    formulations of the dq dot product — the point of this repro is that
    they compile to very different-performing code.
    """
    off_z = tl.program_id(0)
    off_h = tl.program_id(1)
    start_n = tl.program_id(2)
    off_hz = off_z * tl.num_programs(1) + off_h
    # Offset base pointers to this (batch, head) slice.
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    DK += off_z * stride_dkz + off_h * stride_dkh
    DV += off_z * stride_dvz + off_h * stride_dvh
    Q += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_oz + off_h * stride_oh
    DQ += off_z * stride_dqz + off_h * stride_dqh
    # Row/col offsets: offs_n is fixed for this program (its key block);
    # offs_qm/offs_m walk over query blocks inside the loop below.
    offs_qm = tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_DMODEL)
    # Pointers into the value-like tensors.
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
    v_ptrs = V + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)
    do_ptrs = DO + (offs_qm[:, None] * stride_om + offs_k[None, :] * stride_ok)
    dq_ptrs = DQ + (
        offs_qm[:, None] * stride_dqm + offs_k[None, :] * stride_dqk
    )
    # D and M are per-row vectors laid out as (batch*heads, seq_len).
    D_ptrs = D + off_hz * stride_lh
    m_ptrs = M + off_hz * stride_lh
    # fp32 accumulators for this program's dK/dV tile.
    dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
    # K/V tile is reused across every query block — load once.
    k = tl.load(k_ptrs)
    v = tl.load(v_ptrs)
    num_block = tl.cdiv(seq_len, BLOCK_M)
    for start_m in range(0, num_block * BLOCK_M, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m_curr = start_m + offs_m
        q = tl.load(q_ptrs)
        # Recompute attention probabilities from the saved row maxima M.
        qk = tl.dot(q, tl.trans(k))
        qk *= softmax_scale
        m = tl.load(m_ptrs + offs_m_curr)
        p = tl.exp(qk - m[:, None])
        do = tl.load(do_ptrs)
        dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
        Di = tl.load(D_ptrs + offs_m_curr)
        dp = tl.dot(do, tl.trans(v))
        ds = (p * (dp - Di[:, None]) * softmax_scale).to(Q.dtype.element_ty)
        dk += tl.dot(tl.trans(ds), q)
        if not TRANSPOSED:
            dq = tl.dot(ds, k)
        else:
            # Mathematically identical to tl.dot(ds, k) — (A·B)ᵀ = Bᵀ·Aᵀ —
            # but reported 7.8x faster on A100; this branch is the subject
            # of the issue and must stay exactly as written.
            dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds)))
        # dQ is shared across key-block programs, hence the atomic add.
        tl.atomic_add(dq_ptrs, dq)
        dq_ptrs += BLOCK_M * stride_dqm
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_om
    # Write back this program's private dK/dV tile.
    dv_ptrs = DV + (
        offs_n[:, None] * stride_dvk + offs_k[None, :] * stride_dvn
    )
    dk_ptrs = DK + (
        offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk
    )
    tl.store(dv_ptrs, dv)
    tl.store(dk_ptrs, dk)


def run(transposed):
    """Benchmark the backward kernel and print mean iteration time (ms).

    Args:
        transposed: forwarded to the kernel's TRANSPOSED constexpr; selects
            which formulation of the dq dot product is compiled.
    """
    device = torch.device('cuda')
    dtype = torch.float16
    batch = 2
    num_heads = 16
    seq_len = 2048
    d_head = 128
    softmax_scale = 1.1
    iterations = 20
    warmup = 5
    q = torch.rand(batch, num_heads, seq_len, d_head, device=device, dtype=dtype)
    k = torch.rand(batch, num_heads, seq_len, d_head, device=device, dtype=dtype)
    v = torch.rand(batch, num_heads, seq_len, d_head, device=device, dtype=dtype)
    do = torch.rand_like(q)
    # dq is accumulated with atomics in fp32.
    dq = torch.zeros_like(q, dtype=torch.float32)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    l = torch.ones(batch, num_heads, seq_len, device=device, dtype=torch.float32)
    m = torch.ones(batch, num_heads, seq_len, device=device, dtype=torch.float32)
    delta = torch.ones_like(l)
    # One program per (batch, head, key-block).
    grid = lambda META: (
        batch,
        num_heads,
        triton.cdiv(seq_len, META['BLOCK_N']),
    )
    for i in range(iterations):
        if i == warmup:
            # Exclude compile/warmup iterations from the timed region.
            torch.cuda.synchronize()
            start = time.perf_counter()
        bwd = _bwd_kernel[grid](
            q, k, v, do, dq, dk, dv,
            l, m, delta,
            softmax_scale,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            do.stride(0), do.stride(1), do.stride(2), do.stride(3),
            dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
            dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
            dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
            l.stride(0), l.stride(1), l.stride(2),
            seq_len,
            BLOCK_DMODEL=d_head,
            TRANSPOSED=transposed,
            BLOCK_M=128,
            BLOCK_N=128,
            num_warps=8,
            num_stages=1,
        )
    torch.cuda.synchronize()
    stop = time.perf_counter()
    elapsed = 1000 * (stop - start) / (iterations - warmup)
    print(f'Transposed: {transposed}, elapsed: {elapsed}')


run(False)
run(True)
The code above executes backward pass first with TRANSPOSED set to False, then with TRANSPOSED set to True and reports elapsed time:
Transposed: False, elapsed: 48.10845733154565 Transposed: True, elapsed: 6.173196334081392
Code with extra transposes is 7.8x faster on A100.
Triton: latest main (4be1c94b1f0d1f07f003ee5d38af943070110688) GPU: NVIDIA A100-SXM4-80GB nvcc: V11.7.99 pytorch: 2.0.1 (pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel container from pytorch dockerhub)
I am curious about the cause of this speedup; could someone explain why it happens?
Rewriting tl.dot to compute
dq
in sequence-parallel Flash Attention Backward from: dq = tl.dot(ds, k)
into: dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds)))
causes a 7.8x speedup. Minimal Repro:
The code above executes backward pass first with TRANSPOSED set to False, then with TRANSPOSED set to True and reports elapsed time:
Code with extra transposes is 7.8x faster on A100.
Triton: latest main (4be1c94b1f0d1f07f003ee5d38af943070110688) GPU: NVIDIA A100-SXM4-80GB nvcc: V11.7.99 pytorch: 2.0.1 (pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel container from pytorch dockerhub)