triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License
13.12k stars 1.6k forks source link

tl.dot with 3D shapes compilation error #4867

Open jubueche opened 1 week ago

jubueche commented 1 week ago

This code

import triton
import triton.language as tl
from torch import Tensor

@triton.jit
def get_three_d_weights(
    currents,  # [B_M, B_N]
    weight_block,  # [B_K, B_N]
    BLOCK_SIZE_K: tl.constexpr,
):
    poly_arange = tl.arange(0, BLOCK_SIZE_K)[:, None]
    # [B_K,1] x [B_M,1,B_N] -> [B_M, B_ ]
    approx_v = poly_arange * currents[:, None, :]

    # this will broadcast to [BLOCK_SIZE_M*next_pow_2(n_inp_bits),BLOCK_SIZE_K,BLOCK_SIZE_N]
    ir_dropped_weights = weight_block[None, :, :] * approx_v
    return ir_dropped_weights

@triton.jit
def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, currents_ptr,  #
                             M, N, K,  #
                             stride_am, stride_ak,  #
                             stride_bk, stride_bn,  #
                             stride_cm, stride_cn,  #
                             stride_currentsm, stride_currentsn,  #
                             BLOCK_SIZE_M: tl.constexpr,  #
                             BLOCK_SIZE_N: tl.constexpr,  #
                             BLOCK_SIZE_K: tl.constexpr,  #
                             GROUP_SIZE_M: tl.constexpr,  #
                             NUM_SMS: tl.constexpr,  #
                             ):
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    tiles_per_SM = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_SM += 1

    tile_id = start_pid - NUM_SMS
    ki = -1

    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    pid_m = 0
    pid_n = 0
    offs_am = tl.arange(0, BLOCK_SIZE_M)
    offs_bn = tl.arange(0, BLOCK_SIZE_N)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    currents = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16)

    for _ in range(0, k_tiles * tiles_per_SM):
        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
        if ki == 0:
            tile_id += NUM_SMS
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            start_m = pid_m * BLOCK_SIZE_M
            start_n = pid_n * BLOCK_SIZE_N
            offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
            offs_am = tl.where(offs_am < M, offs_am, 0)
            offs_bn = tl.where(offs_bn < N, offs_bn, 0)
            offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
            offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

            currents_ptrs = currents_ptr + (offs_am[:, None] * stride_currentsm + offs_bn[None, :] * stride_currentsn)
            currents = tl.load(currents_ptrs, mask=(offs_am[:, None] < M) & (offs_bn[None, :] < N), other=0.0)

        offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

        a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0)

        modified_weights = get_three_d_weights(currents, b, BLOCK_SIZE_K)

        a_broadcasted = tl.broadcast_to(a[:, None], (BLOCK_SIZE_M, 16, BLOCK_SIZE_K))
        three_d_dot = tl.dot(a_broadcasted, modified_weights)
        three_d_dot = tl.sum(three_d_dot, axis=1) / 16.

        accumulator += three_d_dot

        if ki == k_tiles - 1:
            offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
            c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
            if (c_ptr.dtype.element_ty == tl.float8e4nv):
                c = accumulator.to(tl.float8e4nv)
            else:
                c = accumulator.to(tl.float16)
            tl.store(c_ptrs, c, mask=c_mask)
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# fmt: on

def matmul_persistent(a: Tensor, b: Tensor, currents: Tensor):
    configs = {
        torch.float8_e4m3fn: {
            "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4,
            "num_warps": 8
        },
        torch.float16: {
            "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 16, "GROUP_SIZE_M": 8, "num_stages": 3,
            "num_warps": 8
        }
    }
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.shape[0] == currents.shape[0], "Wrong current shape"
    assert currents.shape[1] == b.shape[1], "Wrong current shape"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    assert a.dtype == currents.dtype, "Incompatible dtypes"
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    M, K = a.shape
    K, N = b.shape
    dtype = a.dtype
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
    matmul_kernel_persistent[grid](
        a, b, c, currents,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        currents.stride(0), currents.stride(1),  #
        BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"],  #
        BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"],  #
        BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"],  #
        GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"],  #
        NUM_SMS=NUM_SMS,  #
        num_stages=configs[dtype]["num_stages"],  #
        num_warps=configs[dtype]["num_warps"],  #
    )
    return c

if __name__ == "__main__":
    import torch

    matmul_persistent(
        torch.randn((10, 4), device="cuda", dtype=torch.float16),
        torch.randn((4, 5), device="cuda", dtype=torch.float16),
        torch.randn((10, 5), device="cuda", dtype=torch.float16),
    )

throws the following compilation error:

python: /root/.triton/llvm/llvm-ed4e505c-centos-x64/include/llvm/ADT/SmallVector.h:304: T& llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::operator[](llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type) [with T = unsigned int; <template-parameter-1-2> = void; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::reference = unsigned int&; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type = long unsigned int]: Assertion `idx < size()' failed.

I don't know how to debug this (a tutorial on this or pointers would be nice). I am using triton v3.0.0. installed via pip from the webpage.

Edenzzzz commented 1 week ago
image
jubueche commented 1 week ago

This takes me to the page detailing debugging the frontend. This code runs when I run it with TRITON_INTERPRET=1. So I don't think your "suggestion" works here. I think it's a bug so I am opening an issue here.