microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License

[Bug]: matmul functional incorrectness w/ CPU backend #184

Open ajwaitz opened 1 month ago

ajwaitz commented 1 month ago

Triton python code

import torch

import importlib

import triton
importlib.reload(triton)

# import triton.triton_shared

from triton.backends.triton_shared.driver import CPUDriver
import triton.language as tl

triton.runtime.driver.set_active(CPUDriver())

# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#       meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#       provided configs
# @triton.autotune(
#     configs=[
#         triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
#                       num_warps=8),
#         triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
#                       num_warps=4),
#         triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
#                       num_warps=4),
#         triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
#                       num_warps=4),
#         triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
#                       num_warps=4),
#         triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
#                       num_warps=4),
#         triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
#                       num_warps=2),
#         triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
#                       num_warps=2),
#     ],
#     key=['M', 'N', 'K'],
# )
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
        GROUP_SIZE_M: tl.constexpr,  #
        ACTIVATION: tl.constexpr  #
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matrix-multiplication tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n 
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % GROUP_SIZE_M)
    # pid_m = first_pid_m + (pid % group_size_m) 
    pid_n = (pid % num_pid_in_group) // GROUP_SIZE_M
    # pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetic` section of the Triton matrix-multiplication tutorial for details.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` is converted to the output dtype (fp32 in this example) after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float32)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

# We can fuse `leaky_relu` by passing "leaky_relu" as the `ACTIVATION` meta-parameter of `matmul_kernel`.
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)

def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation,  #
        # sweep block sizes
        # 1, 8, 16, 32, 64
        # question: does the error we see happen in lowering or CPU backend?
        BLOCK_SIZE_M=32,
        BLOCK_SIZE_N=32,
        BLOCK_SIZE_K=32,
        GROUP_SIZE_M=8
    )
    return c

def test_matmul():
    torch.manual_seed(0)
    rows1 = 64 * 8
    cols1 = 64 * 8
    rows2 = 64 * 8
    cols2 = 64 * 8
    a = torch.randn((rows1, cols1), device='cpu', dtype=torch.float32)
    b = torch.randn((rows2, cols2), device='cpu', dtype=torch.float32)
    # a = torch.full((rows1, cols1), 1, device='cpu', dtype=torch.float32)
    # b = torch.full((rows2, cols2), 1, device='cpu', dtype=torch.float32)
    triton_output = matmul(a, b)
    torch_output = torch.matmul(a, b)
    print(f"triton_output={triton_output}")
    print(f"torch_output={torch_output}")
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")

if __name__ == "__main__":
    test_matmul()
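Because `c` is allocated with `torch.empty`, any output tile that no program writes keeps whatever happened to be in memory, so one cheap check is to replay the kernel's grouped program-id mapping on the host. The sketch below is plain Python (no Triton needed); `pid_to_tile` and `tile_coverage` are hypothetical helper names, and the two branches mirror the active `GROUP_SIZE_M` lines and the commented-out `group_size_m` lines of `matmul_kernel` above.

```python
# Host-side replay of the grouped program-id -> (pid_m, pid_n) mapping used by
# matmul_kernel. Plain Python, no Triton required. Helper names are hypothetical.


def cdiv(a, b):
    """Ceiling division for non-negative integers (like tl.cdiv)."""
    return (a + b - 1) // b


def pid_to_tile(pid, num_pid_m, num_pid_n, GROUP_SIZE_M, use_group_size_m):
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    if use_group_size_m:
        # The commented-out lines in the kernel: use the runtime group size.
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
    else:
        # The active lines in the kernel: use the constexpr GROUP_SIZE_M.
        pid_m = first_pid_m + (pid % GROUP_SIZE_M)
        pid_n = (pid % num_pid_in_group) // GROUP_SIZE_M
    return pid_m, pid_n


def tile_coverage(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, use_group_size_m):
    """Return (missing tiles, out-of-range tiles, number of duplicate writes)."""
    num_pid_m = cdiv(M, BLOCK_SIZE_M)
    num_pid_n = cdiv(N, BLOCK_SIZE_N)
    expected = {(m, n) for m in range(num_pid_m) for n in range(num_pid_n)}
    # The launch grid in `matmul` has num_pid_m * num_pid_n programs.
    seen = [pid_to_tile(pid, num_pid_m, num_pid_n, GROUP_SIZE_M, use_group_size_m)
            for pid in range(num_pid_m * num_pid_n)]
    missing = expected - set(seen)
    out_of_range = set(seen) - expected
    duplicates = len(seen) - len(set(seen))
    return missing, out_of_range, duplicates


if __name__ == "__main__":
    # Configuration from the issue (M = N = 512, 32x32 tiles, GROUP_SIZE_M = 8).
    for variant in (True, False):
        print(variant, tile_coverage(512, 512, 32, 32, 8, variant))
    # Shape whose number of tile rows (5) is not a multiple of GROUP_SIZE_M.
    print(tile_coverage(160, 160, 32, 32, 8, use_group_size_m=False))
```

For the issue's configuration both variants cover every tile exactly once (`num_pid_m` is 16, a multiple of 8), so the mapping alone does not seem to explain the zeros there. The `GROUP_SIZE_M` variant does leave tiles unwritten when `num_pid_m` is not a multiple of `GROUP_SIZE_M`, for example M = N = 160, which is worth keeping in mind while sweeping shapes.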

Triton IR

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c64 = arith.constant 64 : index
    %cst = arith.constant 0.000000e+00 : f32
    %c32 = arith.constant 32 : index
    %0 = tensor.empty() : tensor<32x64xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x64xf32>) -> tensor<32x64xf32>
    %2 = arith.addi %arg3, %c31_i32 : i32
    %3 = arith.divsi %2, %c32_i32 : i32
    %4 = arith.addi %arg4, %c63_i32 : i32
    %5 = arith.divsi %4, %c64_i32 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.divsi %arg12, %6 : i32
    %8 = arith.muli %7, %c8_i32 : i32
    %9 = arith.subi %3, %8 : i32
    %10 = arith.minsi %9, %c8_i32 : i32
    %11 = arith.remsi %arg12, %10 : i32
    %12 = arith.addi %8, %11 : i32
    %13 = arith.remsi %arg12, %6 : i32
    %14 = arith.divsi %13, %10 : i32
    %15 = arith.muli %12, %c32_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.muli %14, %c64_i32 : i32
    %18 = arith.index_cast %17 : i32 to index
    %19 = arith.index_cast %arg3 : i32 to index
    %20 = arith.index_cast %arg6 : i32 to index
    %21 = arith.muli %16, %20 : index
    %22 = arith.muli %19, %20 : index
    %23 = arith.index_cast %arg7 : i32 to index
    %24 = arith.index_cast %arg4 : i32 to index
    %25 = arith.addi %arg5, %c63_i32 : i32
    %26 = arith.divsi %25, %c64_i32 : i32
    %27 = arith.muli %arg7, %c64_i32 : i32
    %28 = arith.index_cast %27 : i32 to index
    %29:3 = scf.for %arg15 = %c0_i32 to %26 step %c1_i32 iter_args(%arg16 = %1, %arg17 = %21, %arg18 = %c0) -> (tensor<32x64xf32>, index, index)  : i32 {
      %41 = arith.addi %arg18, %18 : index
      %42 = arith.remsi %41, %24 : index
      %43 = arith.subi %41, %42 : index
      %44 = arith.addi %42, %c64 : index
      %45 = arith.minsi %44, %24 : index
      %46 = arith.subi %45, %42 : index
      %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c64, %46], strides: [%23, %c1] : memref<*xf32> to memref<64x?xf32, strided<[?, ?], offset: ?>>
      %47 = arith.subi %c64, %46 : index
      %reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c64, %47], strides: [%23, %c1] : memref<*xf32> to memref<64x?xf32, strided<[?, ?], offset: ?>>
      %48 = arith.remsi %arg17, %20 : index
      %49 = arith.addi %22, %48 : index
      %50 = arith.subi %49, %arg17 : index
      %51 = arith.divsi %50, %20 : index
      %reinterpret_cast_2 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%51, %c64], strides: [%20, %c1] : memref<*xf32> to memref<?x64xf32, strided<[?, ?], offset: ?>>
      %52 = arith.subi %c32, %51 : index
      %reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%48], sizes: [%52, %c64], strides: [%20, %c1] : memref<*xf32> to memref<?x64xf32, strided<[?, ?], offset: ?>>
      %53 = arith.muli %arg15, %c64_i32 : i32
      %54 = arith.subi %arg5, %53 : i32
      %55 = arith.index_cast %54 : i32 to index
      %56 = arith.minsi %55, %c64 : index
      %alloc = memref.alloc() : memref<32x64xf32>
      %57 = arith.cmpi slt, %56, %c64 : index
      scf.if %57 {
        linalg.fill ins(%cst : f32) outs(%alloc : memref<32x64xf32>)
      }
      %58 = arith.minsi %51, %c32 : index
      %59 = arith.subi %c32, %58 : index
      %subview_4 = memref.subview %reinterpret_cast_2[0, 0] [%58, %56] [1, 1] : memref<?x64xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_5 = memref.subview %reinterpret_cast_3[0, 0] [%59, %56] [1, 1] : memref<?x64xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_6 = memref.subview %alloc[0, 0] [%58, %56] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1]>>
      %subview_7 = memref.subview %alloc[%58, 0] [%59, %56] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      memref.copy %subview_4, %subview_6 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
      memref.copy %subview_5, %subview_7 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      %60 = bufferization.to_tensor %alloc restrict writable : memref<32x64xf32>
      %alloc_8 = memref.alloc() : memref<64x64xf32>
      scf.if %57 {
        linalg.fill ins(%cst : f32) outs(%alloc_8 : memref<64x64xf32>)
      }
      %61 = arith.minsi %46, %c64 : index
      %62 = arith.subi %c64, %61 : index
      %subview_9 = memref.subview %reinterpret_cast_0[0, 0] [%56, %61] [1, 1] : memref<64x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_10 = memref.subview %reinterpret_cast_1[0, 0] [%56, %62] [1, 1] : memref<64x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_11 = memref.subview %alloc_8[0, 0] [%56, %61] [1, 1] : memref<64x64xf32> to memref<?x?xf32, strided<[64, 1]>>
      %subview_12 = memref.subview %alloc_8[0, %61] [%56, %62] [1, 1] : memref<64x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      memref.copy %subview_9, %subview_11 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
      memref.copy %subview_10, %subview_12 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      %63 = bufferization.to_tensor %alloc_8 restrict writable : memref<64x64xf32>
      %64 = linalg.matmul ins(%60, %63 : tensor<32x64xf32>, tensor<64x64xf32>) outs(%1 : tensor<32x64xf32>) -> tensor<32x64xf32>
      %65 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%64, %arg16 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%64 : tensor<32x64xf32>) {
      ^bb0(%in: f32, %in_13: f32, %out: f32):
        %68 = arith.addf %in, %in_13 : f32
        linalg.yield %68 : f32
      } -> tensor<32x64xf32>
      %66 = arith.addi %arg17, %c64 : index
      %67 = arith.addi %arg18, %28 : index
      scf.yield %65, %66, %67 : tensor<32x64xf32>, index, index
    }
    %30 = arith.index_cast %arg8 : i32 to index
    %31 = arith.muli %16, %30 : index
    %32 = arith.addi %31, %18 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%32], sizes: [32, 64], strides: [%30, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
    %33 = arith.addi %16, %c32 : index
    %34 = arith.minsi %33, %19 : index
    %35 = arith.subi %34, %16 : index
    %36 = arith.addi %18, %c64 : index
    %37 = arith.minsi %36, %24 : index
    %38 = arith.subi %37, %18 : index
    %39 = arith.minsi %35, %c32 : index
    %40 = arith.minsi %38, %c64 : index
    %extracted_slice = tensor.extract_slice %29#0[0, 0] [%39, %40] [1, 1] : tensor<32x64xf32> to tensor<?x?xf32>
    %subview = memref.subview %reinterpret_cast[0, 0] [%39, %40] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
    bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor<?x?xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>) -> ()
    return
  }
}
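Reading the IR, the `% M` and `% N` in `offs_am`/`offs_bn` appear to be lowered without a gather: each operand gets two `memref.reinterpret_cast`s and two `memref.copy`s, one for the part of the block before the index wraps around and one for the part after it (for example, `%46` and `%47` look like the column counts of the two pieces of the B block). The NumPy sketch below shows the same idea in one dimension, assuming the block is no longer than the dimension; it can serve as a reference when checking the chunk sizes the IR computes. `wrapped_gather` and `two_chunk_copy` are illustrative names, not triton-shared functions.

```python
import numpy as np


def wrapped_gather(x, start, block):
    """Reference semantics: the elements at (start + arange(block)) % n."""
    n = x.shape[0]
    return x[(start + np.arange(block)) % n]


def two_chunk_copy(x, start, block):
    """Copy a wrapped block as at most two contiguous chunks (head, then tail)."""
    n = x.shape[0]
    if block > n:
        # A block longer than the dimension wraps more than once;
        # two contiguous chunks cannot represent that.
        raise ValueError("block longer than dimension")
    start %= n
    head = min(block, n - start)          # elements before the end of the dimension
    out = np.empty(block, dtype=x.dtype)
    out[:head] = x[start:start + head]    # first chunk, starting at `start`
    out[head:] = x[:block - head]         # second chunk restarts at index 0
    return out


if __name__ == "__main__":
    x = np.arange(10)
    # A block of 6 starting at 7 wraps after 3 elements: head = 3, tail = 3.
    print(wrapped_gather(x, 7, 6), two_chunk_copy(x, 7, 6))
```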

Crash log

No crash. Output incorrect.

Additional information

When the matmul kernel is run with the CPU driver and the block parameters above, the output is incorrect compared with the PyTorch result: a significant part of the output matrix is all zeros, while the rest contains the correct matmul result.

This kernel comes from the python/examples directory. With the default block parameters the output is correct; after changing them, however, it is easy to find parameters that produce the incorrect output.

The given IR is triton-shared IR.
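To narrow down which block-size combinations fail and where the zeros land, a small harness can compare each output tile against `torch.matmul`. This is a sketch under assumptions: `sweep`, `mismatched_tiles` and the `run(a, b, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)` callback signature are hypothetical, and the issue's `matmul` would first need to accept the block sizes as arguments instead of hard-coding them. The default sizes (16, 32, 64) are a subset of those listed in the comment inside `matmul`.

```python
import itertools

import torch


def mismatched_tiles(out, ref, block_m, block_n, atol=1e-2):
    """Return (tile_row, tile_col) for every output tile that differs from ref."""
    M, N = ref.shape
    bad = []
    for i in range(0, M, block_m):
        for j in range(0, N, block_n):
            if not torch.allclose(out[i:i + block_m, j:j + block_n],
                                  ref[i:i + block_m, j:j + block_n],
                                  atol=atol, rtol=0):
                bad.append((i // block_m, j // block_n))
    return bad


def sweep(run, sizes=(16, 32, 64), M=512, N=512, K=512, atol=1e-2):
    """Call run(a, b, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K) for every
    combination of block sizes; map each failing combination to its bad tiles."""
    torch.manual_seed(0)
    a = torch.randn((M, K), dtype=torch.float32)
    b = torch.randn((K, N), dtype=torch.float32)
    ref = torch.matmul(a, b)
    failures = {}
    for bm, bn, bk in itertools.product(sizes, repeat=3):
        out = run(a, b, bm, bn, bk)
        bad = mismatched_tiles(out, ref, bm, bn, atol)
        if bad:
            failures[(bm, bn, bk)] = bad
    return failures


if __name__ == "__main__":
    # Stand-in callback: torch itself, so no configuration should fail.
    print(sweep(lambda a, b, bm, bn, bk: a @ b, M=128, N=128, K=64))
```

Reporting tile coordinates rather than a single pass/fail makes it easier to see whether the zeros follow a pattern (whole tile rows, the last K block, tiles where `% M` wraps), which may help answer whether the problem is in the lowering or in the CPU backend.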

nhat-nguyen commented 1 month ago

Thanks @ajwaitz for reporting the bug! We will try to get to this. @parsifal-47, would you be interested in taking a quick look at this? Thanks!

parsifal-47 commented 1 month ago

@nhat-nguyen yes, would be happy to take a look, thank you!

parsifal-47 commented 1 month ago

@ajwaitz sorry for the delay. I tried modifying test_matmul.py with your block sizes and it passed. After that, I copy-pasted the "Triton python code" from the beginning of this issue and it also works: I see the "✅ Triton and Torch match" message and no zeros in the printed output. What am I missing?