triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Unexpected mma -> mma layout conversion #1420

Open bchetioui opened 1 year ago

bchetioui commented 1 year ago

I am trying to reimplement Praxis's dot product attention with a lazy broadcast prefix.

My code is the following:

import numpy as np
import torch
import triton
import triton.language as tl

def strides_length(shape):
  size = np.prod(shape)
  for s in shape:
    size = size // s
    yield int(size)

@triton.jit
def load_one_of(a, b, tile, threshold):
  return tl.load(tl.where(tile < threshold, a, b))

@triton.jit
def fused_attention_with_lazy_broadcast_prefix_kernel(
  # inputs
  Q, K_prefix, K_suffix, V_prefix, V_suffix, mask,
  # dimensions
  batch_size: tl.constexpr, seq_len: tl.constexpr, num_heads: tl.constexpr,
  head_dim: tl.constexpr, prefix_size: tl.constexpr, suffix_size: tl.constexpr,
  # outputs
  L, M, encoded,
  # block information
  block_q: tl.constexpr, block_k: tl.constexpr, block_d: tl.constexpr
):
  """Reimplementation of Praxis's lazy broadcast prefix attention.

  See https://github.com/google/praxis/blob/main/praxis/layers/attentions.py#L2068.
  """
  # note: block_d is assumed to be equal to head_dim

  head_stride = head_dim
  block_stride = num_heads * head_stride
  batch_stride = seq_len * block_stride

  block_index = tl.program_id(0)
  head_index = tl.program_id(1)
  batch_index = tl.program_id(2)

  # Load q tile
  q_tile = batch_index * batch_stride \
         + (block_index * block_q + tl.arange(0, block_q))[:, None] * block_stride \
         + head_index * head_stride \
         + tl.arange(0, block_d)[None, :]

  q = tl.load(Q + q_tile)

  # Prepare k and v tiles
  initial_kv_tile = batch_index * batch_stride \
                  + tl.arange(0, block_k)[:, None] * block_stride \
                  + head_index * head_stride \
                  + tl.arange(0, block_d)[None, :]
  kv_tile = initial_kv_tile

  nb_kv_prefix_elements = batch_size * prefix_size * num_heads * head_dim

  # Initialize local SRAM buffers
  m_i = tl.zeros([block_q], dtype=tl.float32) - float('inf')
  l_i = tl.zeros([block_q], dtype=tl.float32)
  acc = tl.zeros([block_q, block_d], dtype=tl.float32)

  for i in range(0, seq_len // block_k):
    start_k = tl.multiple_of(i * block_k, block_k)
    # -- compute qk --
    k = load_one_of(K_prefix + kv_tile, K_suffix + kv_tile, kv_tile,
                    nb_kv_prefix_elements)
    qk = tl.dot(q, tl.trans(k))
    # -- apply attention mask --
    # mask_tile =
    # -- apply row softmax --
    m_ij = tl.max(qk, axis=1)
    probs_ij_softmax_exponents = qk - m_ij[:, None]
    probs_ij = tl.exp(probs_ij_softmax_exponents)
    l_ij = tl.sum(probs_ij, axis=1)

    # -- construct new m_i and l_i --
    m_i_new = tl.maximum(m_i, m_ij)
    alpha, beta = tl.exp(m_i - m_i_new), tl.exp(m_ij - m_i_new)
    l_i_new = alpha * l_i + beta * l_ij

    # -- update accumulator --
    probs_scale = beta / l_i_new
    probs_ij = probs_ij * probs_scale[:, None]

    acc_scale = l_i / l_i_new * alpha
    acc *= acc_scale[:, None]
    v = load_one_of(V_prefix + kv_tile, V_suffix + kv_tile, kv_tile,
                    nb_kv_prefix_elements)
    acc += tl.dot(probs_ij, v)

    # -- update m_i and l_i --
    m_i, l_i = m_i_new, l_i_new

    # -- update kv tile --
    kv_tile += initial_kv_tile

  # Write back L, M, encoded
  lm_head_stride = seq_len
  lm_batch_stride = num_heads * lm_head_stride
  lm_tile = batch_index * lm_batch_stride \
          + head_index * lm_head_stride \
          + (block_index * block_q + tl.arange(0, block_q))

  tl.store(L + lm_tile, l_i)
  tl.store(M + lm_tile, m_i)
  tl.store(encoded + q_tile, acc)

def shape_matches(shape, target_shape):
  if len(shape) != len(target_shape): return False
  for component, target in zip(shape, target_shape):
    if target != -1 and component != target: return False

  return True

def fused_attention_with_lazy_prefix_broadcast(q, k_prefix, k_suffix,
                                               v_prefix, v_suffix, mask,
                                               block_q: int = 64,
                                               block_k: int = 64):
  """Implementation of FlashAttention with lazy prefix broadcast.

  The implementation is based on Pallas's MHA implementation, with handling for
  lazy prefix broadcast.

  See https://github.com/jax-ml/jax-triton/blob/mlir/jax_triton/pallas/ops/attention.py#L93
  and https://github.com/google/praxis/blob/main/praxis/layers/attentions.py#L2068.
  """
  batch_size, seq_len, num_heads, head_dim = q.shape
  assert shape_matches(k_prefix.shape, (batch_size, -1, num_heads, head_dim))
  assert shape_matches(k_suffix.shape, (batch_size, -1, num_heads, head_dim))
  assert shape_matches(v_prefix.shape, k_prefix.shape)
  assert shape_matches(v_suffix.shape, k_suffix.shape)

  block_q = min(block_q, seq_len)
  block_k = min(block_k, seq_len)
  grid = (triton.cdiv(seq_len, block_q), num_heads, batch_size)

  encoded = torch.empty_like(q)

  l = torch.empty((batch_size, num_heads, seq_len), device=q.device,
                  dtype=torch.float32)
  m = torch.empty((batch_size, num_heads, seq_len), device=q.device,
                  dtype=torch.float32)

  metaparams = dict(
      block_q=block_q,
      block_d=head_dim,
      block_k=block_k,
      num_warps=4 if head_dim <= 64 else 8,
      num_stages=2)

  fused_attention_with_lazy_broadcast_prefix_kernel[grid](
      q, k_prefix, k_suffix, v_prefix, v_suffix, mask,
      *q.shape, k_prefix.shape[1], k_suffix.shape[1],
      l, m, encoded,
      **metaparams)

  return encoded, l, m

Unfortunately, I run into the following error when attempting to compile on an A100 with Triton at HEAD, apparently more or less independently of the input shapes:

triton/lib/Analysis/Allocation.cpp:41: std::pair<llvm::SmallVector<unsigned int>, llvm::SmallVector<unsigned int> > mlir::triton::getCvtOrder(const mlir::Attribute&, const mlir::Attribute&): Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mma layout conversion"' failed.

I read in issue #1298 that this may happen when the optimizer doesn't do its job well. Is there a known workaround that might apply in this case? If not, I'd be happy to help debug this if someone could point me in the right direction :)

Jokeren commented 1 year ago

A workaround is to store the temporary values to global memory and then reload them from there.
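
For reference, a minimal self-contained sketch of that store-and-reload pattern (not code from this issue; the toy kernel, the Scratch buffer, and the shapes are made up for illustration): two chained tl.dot calls, with the intermediate written to a per-program slice of global memory and read back before the second dot.

import torch
import triton
import triton.language as tl

@triton.jit
def chained_dot_roundtrip_kernel(A, B, C, Out, Scratch,
                                 block_m: tl.constexpr,
                                 block_n: tl.constexpr,
                                 block_k: tl.constexpr):
  offs_m = tl.arange(0, block_m)
  offs_n = tl.arange(0, block_n)
  offs_k = tl.arange(0, block_k)

  a = tl.load(A + offs_m[:, None] * block_k + offs_k[None, :])  # (block_m, block_k)
  b = tl.load(B + offs_k[:, None] * block_n + offs_n[None, :])  # (block_k, block_n)
  ab = tl.dot(a, b)  # first dot

  # Workaround: write the intermediate to a per-program slice of global memory
  # and read it back, so the second dot does not consume the first dot's
  # result directly.
  pid = tl.program_id(0)
  scratch_offs = pid * block_m * block_n \
               + offs_m[:, None] * block_n + offs_n[None, :]
  tl.store(Scratch + scratch_offs, ab)
  ab = tl.load(Scratch + scratch_offs)

  c = tl.load(C + offs_n[:, None] * block_k + offs_k[None, :])  # (block_n, block_k)
  out = tl.dot(ab, c)  # second dot
  tl.store(Out + offs_m[:, None] * block_k + offs_k[None, :], out)

block_m = block_n = block_k = 64
a = torch.randn((block_m, block_k), device='cuda', dtype=torch.float32)
b = torch.randn((block_k, block_n), device='cuda', dtype=torch.float32)
c = torch.randn((block_n, block_k), device='cuda', dtype=torch.float32)
out = torch.empty((block_m, block_k), device='cuda', dtype=torch.float32)
scratch = torch.empty((1, block_m, block_n), device='cuda', dtype=torch.float32)
chained_dot_roundtrip_kernel[(1,)](a, b, c, out, scratch,
                                   block_m=block_m, block_n=block_n, block_k=block_k)

Applied to the kernel at the top of the thread, the same idea would mean passing one extra scratch pointer argument and round-tripping probs_ij (or qk) through it just before the second tl.dot.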

bchetioui commented 1 year ago

Great, thanks! That unblocks me for the time being.

delibae commented 1 year ago

@bchetioui Can you show me the final code?

bchetioui commented 1 year ago

Sorry @delibae, I saw your message but couldn't reply at the time, and then forgot about it. Unfortunately, I do not have code that I can share for this at the moment. Do you still need it?

cyh-ustc commented 1 month ago

A workaround is to store the temporary values to global memory and then reload them from there.

I tried to use tl.store together with tl.load to bypass this issue. It compiles and runs, but the round-trip doesn't seem to work properly: the stride information is not handled correctly, so some columns end up missing.
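
For what it's worth, here is roughly how I would expect the host side of that workaround to look for the kernel at the top of the thread, inside fused_attention_with_lazy_prefix_broadcast where these names are in scope; probs_scratch and num_programs are hypothetical names, not from the original code. The key point is that the scratch buffer gets its own contiguous layout and is indexed with that layout inside the kernel, not with the strides of K/V; a mismatch there could explain the missing columns.

# Hypothetical scratch buffer for round-tripping one (block_q, block_k) tile
# (e.g. probs_ij) per program instance. Inside the kernel it would be indexed
# as pid * block_q * block_k + row * block_k + col, i.e. using the scratch
# buffer's own contiguous layout rather than the strides of the K/V tensors.
num_programs = triton.cdiv(seq_len, block_q) * num_heads * batch_size
probs_scratch = torch.empty((num_programs, block_q, block_k),
                            device=q.device, dtype=torch.float32)

The buffer would then be passed to the kernel as one extra pointer argument.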