triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/

Assertion `0 && "getThreadsPerWarp not implemented"' failed #1193

Open latkins opened 1 year ago

latkins commented 1 year ago

I am getting the following error, which I am unsure how to debug:

python: /projects/triton/lib/Dialect/TritonGPU/IR/Dialect.cpp:85: llvm::SmallVector<unsigned int> mlir::triton::gpu::getThreadsPerWarp(const mlir::Attribute&): Assertion `0 && "getThreadsPerWarp not implemented"' failed.
fish: Job 1, 'python error.py' terminated by signal SIGABRT (Abort)

While trying to produce a minimal example (code below), I removed line 91 (acc += tl.sum(p_jik[:, :, :, None] * v_ik_data[None, :, :, :], 2)); this caused the Python process to hang at 100% CPU usage for over an hour before crashing.
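
For reference, the removed line is the contraction over the K axis. Here is a plain-PyTorch sketch of the shapes it is meant to produce, using the block sizes from the launch below; this is only an illustration, not part of the repro:

import torch

# Block sizes taken from the launch configuration below (hypothetical standalone shape check).
J, I, K, D = 16, 8, 32, 4
p_jik = torch.randn(J, I, K)
v_ik_data = torch.randn(I, K, D)
# [J, I, K, 1] * [1, I, K, D] broadcasts to [J, I, K, D]; summing over K gives [J, I, D].
acc_update = (p_jik[:, :, :, None] * v_ik_data[None, :, :, :]).sum(2)
assert acc_update.shape == (J, I, D)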

Triton is installed from commit 6413c7b9debf9e82b9f2df4dc1688a66427b8064.

Code to reproduce:

import torch
import triton
import triton.language as tl

@triton.jit
def _fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    out_ptr,
    i_max,
    j_max,
    k_max,
    d_max,
    q_j_stride,
    q_i_stride,
    q_d_stride,
    k_i_stride,
    k_k_stride,
    k_d_stride,
    v_i_stride,
    v_k_stride,
    v_d_stride,
    o_j_stride,
    o_i_stride,
    o_d_stride,
    # Block sizes
    BLOCK_I: tl.constexpr,
    BLOCK_J: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    i_start = tl.program_id(0) * BLOCK_I

    j_start = tl.program_id(1) * BLOCK_J

    i_range_offs = tl.arange(0, BLOCK_I)
    j_range_offs = tl.arange(0, BLOCK_J)
    k_range_offs = tl.arange(0, BLOCK_K)
    d_range_offs = tl.arange(0, BLOCK_D)

    q_offs = (
        (j_start + j_range_offs[:, None, None]) * q_j_stride
        + (i_start + i_range_offs[None, :, None]) * q_i_stride
        + (d_range_offs[None, None, :] * q_d_stride)
    )

    # q_data: [BLOCK_J, BLOCK_I, BLOCK_D]
    q_ji_data = tl.load(q_ptr + q_offs)

    acc = tl.zeros([BLOCK_J, BLOCK_I, BLOCK_D], dtype=tl.float32)

    l_ji = tl.zeros([BLOCK_J, BLOCK_I], dtype=tl.float32)
    m_ji = tl.zeros([BLOCK_J, BLOCK_I], dtype=tl.float32) - float("inf")

    for k_start in range(0, k_max, BLOCK_K):
        k_offs = (
            (i_start + i_range_offs[:, None, None]) * k_i_stride
            + (k_start + k_range_offs[None, :, None]) * k_k_stride
            + (d_range_offs[None, None, :] * k_d_stride)
        )

        # k_data: [BLOCK_I, BLOCK_K, BLOCK_D]
        k_ik_data = tl.load(k_ptr + k_offs)

        # s_jik: [BLOCK_J, BLOCK_I, BLOCK_K]
        s_jik = tl.sum(q_ji_data[:, :, None, :] * k_ik_data[None, :, :, :], axis=3)

        # Numerically stable (streaming) softmax over the K axis.
        m_ji_tilde = tl.max(s_jik, axis=2)
        p_jik = tl.exp(s_jik - m_ji_tilde)

        l_ji_tilde = tl.sum(p_jik, 2)
        m_ji_new = tl.maximum(m_ji_tilde, m_ji)

        alpha = tl.exp(m_ji - m_ji_new)
        beta = tl.exp(m_ji_tilde - m_ji_new)
        l_ji_new = (alpha * l_ji) + (beta * l_ji_tilde)

        l_ji_new_rcp = 1.0 / l_ji_new[:, :, None]

        v_offs = (
            ((i_start + i_range_offs[:, None, None]) * v_i_stride)
            + ((k_start + k_range_offs[None, :, None]) * v_k_stride)
            + (d_range_offs[None, None, :] * v_d_stride)
        )

        v_ik_data = tl.load(v_ptr + v_offs)

        acc *= alpha * l_ji_new_rcp
        p_jik *= beta * l_ji_new_rcp
        acc += tl.sum(p_jik[:, :, :, None] * v_ik_data[None, :, :, :], 2)

        m_ji = m_ji_new
        l_ji = l_ji_new

    o_offs = (
        (j_start + j_range_offs[:, None, None]) * o_j_stride
        + (i_start + i_range_offs[None, :, None]) * o_i_stride
        + (d_range_offs[None, None, :] * o_d_stride)
    )

    tl.store(out_ptr + o_offs, acc)

def fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
):
    assert q.shape == k.shape

    N, M, D = q.shape

    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()

    def grid(META):
        return (
            triton.cdiv(N, META["BLOCK_I"]),
            triton.cdiv(M, META["BLOCK_J"]),
        )

    out = torch.zeros((M, N, D), dtype=torch.float32, device="cuda")

    result = _fwd_kernel[grid](
        q,
        k,
        v,
        out,
        N,
        M,
        M,  # N == M == K
        D,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *out.stride(),
        BLOCK_I=8,
        BLOCK_J=16,
        BLOCK_K=32,
        BLOCK_D=D,
    )

    return out.transpose(0, 1)

if __name__ == "__main__":
    N = 32
    D = 4

    dtype = torch.float16
    device = "cuda"

    q = torch.empty((N, N, D), dtype=dtype, device=device).normal_(mean=0, std=0.5)
    k = torch.empty((N, N, D), dtype=dtype, device=device).normal_(mean=0, std=0.5)
    v = torch.empty((N, N, D), dtype=dtype, device=device).normal_(mean=0, std=0.5)

    fwd(q, k, v)
Jokeren commented 1 year ago

N-D tensors (N >= 3) are not well supported yet. Will investigate
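
For triage, a much smaller kernel along these lines may exercise the same rank-3 reduction path; it is an untested, hypothetical sketch rather than a confirmed reproducer:

import torch
import triton
import triton.language as tl

@triton.jit
def _rank3_sum_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # Broadcast a 1-D block up to rank 3, then reduce one axis away,
    # mirroring the rank-3 reductions in the kernel above.
    cube = x[:, None, None] * x[None, :, None] * x[None, None, :]
    red = tl.sum(cube, axis=2)  # [BLOCK, BLOCK]
    out_offs = offs[:, None] * BLOCK + offs[None, :]
    tl.store(out_ptr + out_offs, red)

if __name__ == "__main__":
    BLOCK = 16
    x = torch.randn(BLOCK, device="cuda", dtype=torch.float32)
    out = torch.empty((BLOCK, BLOCK), device="cuda", dtype=torch.float32)
    _rank3_sum_kernel[(1,)](x, out, BLOCK=BLOCK)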