triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/

"Invalid insertelement operands!" for bfloat16 #635

giorgio-arena opened this issue 2 years ago (status: open)

giorgio-arena commented 2 years ago

Hi, I'm trying to run the forward pass of the fused attention example with bfloat16 tensors instead of float16, and I'm getting many errors like this one:

Invalid insertelement operands!
  %4813 = insertelement <2 x half> undef, i16 %4612, i32 0
Invalid insertelement operands!
  %4814 = insertelement <2 x half> %4813, i16 %4613, i32 1
...
in function _fwd_kernel__Pbf16_Pbf16_Pbf16_fp32_Pfp32_Pfp32_Pfp32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_i32_i32_i32_i32_i32_i32_i32__11c1_15c1_19c1_23c1_27c128_28c64_29c128
LLVM ERROR: Broken function found, compilation aborted!
Aborted
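
For context (my reading of the IR, not a confirmed diagnosis): LLVM's verifier requires insertelement's scalar operand to match the vector's element type, but here a bfloat16 value, carried as raw i16 bits, is inserted into a <2 x half> vector, so the function is rejected. It looks like the fp16 packing path is reused for bf16 without changing the vector element type. The two formats are not interchangeable 16-bit types, as this small PyTorch sketch shows:

import torch

# bfloat16: 1 sign / 8 exponent / 7 mantissa bits
# float16:  1 sign / 5 exponent / 10 mantissa bits
one = torch.tensor([1.0])
print(one.to(torch.bfloat16).view(torch.int16))  # tensor([16256], dtype=torch.int16) == 0x3F80
print(one.to(torch.float16).view(torch.int16))   # tensor([15360], dtype=torch.int16) == 0x3C00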

Here's a reproducer; it's essentially the forward pass of the fused attention example with the dtype changed to bfloat16:

import torch
import triton
import triton.language as tl

@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
  start_m = tl.program_id(0)
  off_hz = tl.program_id(1)
  # initialize offsets
  offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
  offs_n = tl.arange(0, BLOCK_N)
  offs_d = tl.arange(0, BLOCK_DMODEL)
  off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
  off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
  off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
  # Initialize pointers to Q, K, V
  q_ptrs = Q + off_q
  k_ptrs = K + off_k
  v_ptrs = V + off_v
  # initialize pointer to m and l
  t_ptrs = TMP + off_hz * N_CTX + offs_m
  m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
  l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
  acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
  # load q: it will stay in SRAM throughout
  q = tl.load(q_ptrs)
  # loop over k, v and update accumulator
  for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
      start_n = tl.multiple_of(start_n, BLOCK_N)
      # -- compute qk ----
      k = tl.load(k_ptrs + start_n * stride_kn)
      qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
      qk += tl.dot(q, k, trans_b=True)
      qk *= sm_scale
      qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
      # -- compute m_ij, p, l_ij
      m_ij = tl.max(qk, 1)
      p = tl.exp(qk - m_ij[:, None])
      l_ij = tl.sum(p, 1)
      # -- update m_i and l_i
      m_i_new = tl.maximum(m_i, m_ij)
      alpha = tl.exp(m_i - m_i_new)
      beta = tl.exp(m_ij - m_i_new)
      l_i_new = alpha * l_i + beta * l_ij
      # -- update output accumulator --
      # scale p
      p_scale = beta / l_i_new
      p = p * p_scale[:, None]
      # scale acc
      acc_scale = l_i / l_i_new * alpha
      tl.store(t_ptrs, acc_scale)
      acc_scale = tl.load(t_ptrs)  # BUG: have to store and immediately load
      acc = acc * acc_scale[:, None]
      # update acc
      v = tl.load(v_ptrs + start_n * stride_vk)
      p = p.to(tl.bfloat16)  # changed from tl.float16 in the original example
      acc += tl.dot(p, v)
      # update m_i and l_i
      l_i = l_i_new
      m_i = m_i_new
  # rematerialize offsets to save registers
  start_m = tl.program_id(0)
  offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
  # write back l and m
  l_ptrs = L + off_hz * N_CTX + offs_m
  m_ptrs = M + off_hz * N_CTX + offs_m
  tl.store(l_ptrs, l_i)
  tl.store(m_ptrs, m_i)
  # initialize pointers to output
  offs_n = tl.arange(0, BLOCK_DMODEL)
  off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
  out_ptrs = Out + off_o
  tl.store(out_ptrs, acc)

class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8

        _fwd_kernel[grid](
          q, k, v, sm_scale,
          tmp, L, m,
          o,
          q.stride(0), q.stride(1), q.stride(2), q.stride(3),
          k.stride(0), k.stride(1), k.stride(2), k.stride(3),
          v.stride(0), v.stride(1), v.stride(2), v.stride(3),
          o.stride(0), o.stride(1), o.stride(2), o.stride(3),
          q.shape[0], q.shape[1], q.shape[2],
          BLOCK_M=BLOCK, BLOCK_N=BLOCK,
          BLOCK_DMODEL=Lk, num_warps=num_warps,
          num_stages=1,
        )

        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.BLOCK = BLOCK
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        return o

attention = _attention.apply

if __name__ == '__main__':
    Z, H, N_CTX, D_HEAD, dtype = 3, 2, 2048, 64, torch.bfloat16  # bfloat16 instead of the example's float16

    # Uninitialized inputs are fine here: compilation fails before any data is read.
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
    sm_scale = 0.3
    tri_out = attention(q, k, v, sm_scale)

Running on an RTX A5000 with driver version 470.129.06 and CUDA toolkit 11.0. I would attach the PTX, but I assume none is generated since LLVM compilation fails. Any help would be appreciated, thanks.
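
In case it helps anyone hitting this before a fix lands: a possible stopgap (a sketch I haven't validated numerically) is to revert the kernel's internal cast to tl.float16, i.e. the original example, and convert at the boundaries so the bf16 code path is never compiled:

import torch

# Hypothetical wrapper: run the float16 kernel on fp16 copies of the inputs
# and cast the result back; numerics will differ slightly from true bf16.
def attention_bf16_via_fp16(q, k, v, sm_scale):
    o = attention(q.to(torch.float16), k.to(torch.float16), v.to(torch.float16), sm_scale)
    return o.to(torch.bfloat16)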

daadaada commented 2 years ago

#636 should have resolved this.
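
For anyone verifying: once a build containing #636 is installed, a quick sanity check is to compare the bfloat16 kernel against a naive float32 causal-attention reference (a sketch; the tolerance is a rough guess for bf16 precision):

import torch

def reference_attention(q, k, v, sm_scale):
    # Naive causal attention computed in float32.
    p = torch.matmul(q.float(), k.float().transpose(2, 3)) * sm_scale
    mask = torch.tril(torch.ones(q.shape[2], q.shape[2], device=q.device)).bool()
    p = p.masked_fill(~mask, float("-inf"))
    p = torch.softmax(p, dim=-1)
    return torch.matmul(p, v.float())

q = torch.randn((3, 2, 2048, 64), dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
sm_scale = 0.3
tri = attention(q, k, v, sm_scale)
ref = reference_attention(q, k, v, sm_scale)
print(torch.allclose(tri.float(), ref, atol=2e-2, rtol=0))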