microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License

[Bug]: encountered operand produced by an unsupported operation "arith.extsi" #125

Open colawithsauce opened 6 months ago

colawithsauce commented 6 months ago

Triton Python code

#!/usr/bin/env python3

import torch

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel_with_block_pointers(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the matrix multiplication tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create block pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction and accumulate.
    # See above `Make a Block Pointer` section for details.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
                                    order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                    order=(1, 0))

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Load with boundary checks, no need to calculate the mask manually.
        # For better performance, you may remove some axis from the boundary
        # check, if you can guarantee that the access is always in-bound in
        # that axis.
        # See above `Load/Store a Block Pointer` section for details.
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the block pointer to the next K block.
        # See above `Advance a Block Pointer` section for details.
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
    c = accumulator.to(tl.float16)

    # ----------------------------------------------------------------
    # Write back the block of the output matrix C with boundary checks.
    # See above `Load/Store a Block Pointer` section for details.
    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
    tl.store(c_block_ptr, c, boundary_check=(0, 1))

# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
def matmul(a, b):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    compiled_kernel = matmul_kernel_with_block_pointers[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1))

    print(compiled_kernel.asm['ttir'])
    return c

if __name__ == '__main__':
    torch.manual_seed(0)
    a = torch.randn((2048, 2048), device='cuda', dtype=torch.float16)
    b = torch.randn((2048, 2048), device='cuda', dtype=torch.float16)
    matmul(a, b)
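
For context, the arith.extsi operations that the crash complains about appear to come from tl.make_block_ptr: in the Triton IR below, the block pointer's shape, strides, and offsets are carried as 64-bit integers, so every i32 index range is sign-extended (arith.extsi) before it feeds a tt.addptr. A hypothetical, much smaller kernel along the lines below should reproduce the same extsi-into-addptr pattern by casting the offsets to int64 explicitly; this is an untested sketch, not part of the original report, and the kernel name and sizes are made up.

import torch

import triton
import triton.language as tl

@triton.jit
def extsi_repro_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)       # tensor of i32 offsets
    offs64 = offs.to(tl.int64)       # should lower to arith.extsi on a tensor
    x = tl.load(x_ptr + offs64)      # tt.addptr with an i64 tensor offset
    tl.store(out_ptr + offs64, x)

if __name__ == '__main__':
    x = torch.randn(1024, device='cuda')
    out = torch.empty_like(x)
    extsi_repro_kernel[(1,)](x, out, BLOCK=1024)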

Triton IR

module {
  tt.func public @matmul_kernel_with_block_pointers_0d1d2d3d4d5d6d7c8d9c10d11c(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c32_i64 = arith.constant 32 : i64
    %cst = arith.constant dense<0> : tensor<1x128xi64>
    %cst_0 = arith.constant dense<0> : tensor<32x1xi64>
    %cst_1 = arith.constant dense<0> : tensor<1x32xi64>
    %cst_2 = arith.constant dense<0> : tensor<64x1xi64>
    %c0_i64 = arith.constant 0 : i64
    %c127_i32 = arith.constant 127 : i32
    %c63_i32 = arith.constant 63 : i32
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x128xf32>
    %c32_i32 = arith.constant 32 : i32
    %c128_i32 = arith.constant 128 : i32
    %c0_i32 = arith.constant 0 : i32
    %c64_i32 = arith.constant 64 : i32
    %c8_i32 = arith.constant 8 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.addi %arg3, %c63_i32 : i32
    %2 = arith.divsi %1, %c64_i32 : i32
    %3 = arith.addi %arg4, %c127_i32 : i32
    %4 = arith.divsi %3, %c128_i32 : i32
    %5 = arith.muli %4, %c8_i32 : i32
    %6 = arith.divsi %0, %5 : i32
    %7 = arith.muli %6, %c8_i32 : i32
    %8 = arith.subi %2, %7 : i32
    %9 = arith.minsi %8, %c8_i32 : i32
    %10 = arith.remsi %0, %9 : i32
    %11 = arith.addi %7, %10 : i32
    %12 = arith.remsi %0, %5 : i32
    %13 = arith.divsi %12, %9 : i32
    %14 = arith.muli %11, %c64_i32 : i32
    %15 = arith.extsi %arg3 : i32 to i64
    %16 = arith.extsi %arg5 : i32 to i64
    %17 = arith.extsi %arg6 : i32 to i64
    %18 = arith.extsi %14 : i32 to i64
    %19 = arith.muli %13, %c128_i32 : i32
    %20 = arith.extsi %arg4 : i32 to i64
    %21 = arith.extsi %arg7 : i32 to i64
    %22 = arith.extsi %19 : i32 to i64
    %23 = tt.splat %arg0 : !tt.ptr<f16, 1> -> tensor<64x32x!tt.ptr<f16, 1>>
    %24 = tt.splat %18 : i64 -> tensor<64xi64>
    %25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %26 = arith.extsi %25 : tensor<64xi32> to tensor<64xi64>
    %27 = arith.addi %24, %26 : tensor<64xi64>
    %28 = tt.expand_dims %27 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
    %29 = tt.splat %17 : i64 -> tensor<64x1xi64>
    %30 = arith.muli %28, %29 : tensor<64x1xi64>
    %31 = tt.broadcast %30 : tensor<64x1xi64> -> tensor<64x32xi64>
    %32 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %33 = arith.extsi %32 : tensor<32xi32> to tensor<32xi64>
    %34 = arith.cmpi sge, %28, %cst_2 : tensor<64x1xi64>
    %35 = tt.splat %15 : i64 -> tensor<64x1xi64>
    %36 = arith.cmpi slt, %28, %35 : tensor<64x1xi64>
    %37 = arith.andi %34, %36 : tensor<64x1xi1>
    %38 = tt.broadcast %37 : tensor<64x1xi1> -> tensor<64x32xi1>
    %39 = tt.splat %16 : i64 -> tensor<1x32xi64>
    %40 = tt.splat %arg1 : !tt.ptr<f16, 1> -> tensor<32x128x!tt.ptr<f16, 1>>
    %41 = tt.splat %21 : i64 -> tensor<32x1xi64>
    %42 = tt.splat %22 : i64 -> tensor<128xi64>
    %43 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
    %44 = arith.extsi %43 : tensor<128xi32> to tensor<128xi64>
    %45 = arith.addi %42, %44 : tensor<128xi64>
    %46 = tt.expand_dims %45 {axis = 0 : i32} : tensor<128xi64> -> tensor<1x128xi64>
    %47 = tt.broadcast %46 : tensor<1x128xi64> -> tensor<32x128xi64>
    %48 = tt.splat %16 : i64 -> tensor<32x1xi64>
    %49 = arith.cmpi sge, %46, %cst : tensor<1x128xi64>
    %50 = tt.splat %20 : i64 -> tensor<1x128xi64>
    %51 = arith.cmpi slt, %46, %50 : tensor<1x128xi64>
    %52 = arith.andi %49, %51 : tensor<1x128xi1>
    %53 = tt.broadcast %52 : tensor<1x128xi1> -> tensor<32x128xi1>
    %54:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst_3, %arg11 = %c0_i64, %arg12 = %c0_i64) -> (tensor<64x128xf32>, i64, i64)  : i32 {
      %85 = tt.splat %arg11 : i64 -> tensor<32xi64>
      %86 = arith.addi %85, %33 : tensor<32xi64>
      %87 = tt.expand_dims %86 {axis = 0 : i32} : tensor<32xi64> -> tensor<1x32xi64>
      %88 = tt.broadcast %87 : tensor<1x32xi64> -> tensor<64x32xi64>
      %89 = arith.addi %31, %88 : tensor<64x32xi64>
      %90 = tt.addptr %23, %89 : tensor<64x32x!tt.ptr<f16, 1>>, tensor<64x32xi64>
      %91 = arith.cmpi sge, %87, %cst_1 : tensor<1x32xi64>
      %92 = arith.cmpi slt, %87, %39 : tensor<1x32xi64>
      %93 = arith.andi %91, %92 : tensor<1x32xi1>
      %94 = tt.broadcast %93 : tensor<1x32xi1> -> tensor<64x32xi1>
      %95 = arith.andi %38, %94 : tensor<64x32xi1>
      %96 = tt.load %90, %95 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32x!tt.ptr<f16, 1>>
      %97 = tt.splat %arg12 : i64 -> tensor<32xi64>
      %98 = arith.addi %97, %33 : tensor<32xi64>
      %99 = tt.expand_dims %98 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64>
      %100 = arith.muli %99, %41 : tensor<32x1xi64>
      %101 = tt.broadcast %100 : tensor<32x1xi64> -> tensor<32x128xi64>
      %102 = arith.addi %101, %47 : tensor<32x128xi64>
      %103 = tt.addptr %40, %102 : tensor<32x128x!tt.ptr<f16, 1>>, tensor<32x128xi64>
      %104 = arith.cmpi sge, %99, %cst_0 : tensor<32x1xi64>
      %105 = arith.cmpi slt, %99, %48 : tensor<32x1xi64>
      %106 = arith.andi %104, %105 : tensor<32x1xi1>
      %107 = tt.broadcast %106 : tensor<32x1xi1> -> tensor<32x128xi1>
      %108 = arith.andi %107, %53 : tensor<32x128xi1>
      %109 = tt.load %103, %108 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128x!tt.ptr<f16, 1>>
      %110 = tt.dot %96, %109, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x32xf16> * tensor<32x128xf16> -> tensor<64x128xf32>
      %111 = arith.addi %arg11, %c32_i64 : i64
      %112 = arith.addi %arg12, %c32_i64 : i64
      scf.yield %110, %111, %112 : tensor<64x128xf32>, i64, i64
    }
    %55 = arith.truncf %54#0 : tensor<64x128xf32> to tensor<64x128xf16>
    %56 = arith.extsi %arg8 : i32 to i64
    %57 = tt.splat %arg2 : !tt.ptr<f16, 1> -> tensor<64x128x!tt.ptr<f16, 1>>
    %58 = tt.splat %18 : i64 -> tensor<64xi64>
    %59 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %60 = arith.extsi %59 : tensor<64xi32> to tensor<64xi64>
    %61 = arith.addi %58, %60 : tensor<64xi64>
    %62 = tt.expand_dims %61 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
    %63 = tt.splat %56 : i64 -> tensor<64x1xi64>
    %64 = arith.muli %62, %63 : tensor<64x1xi64>
    %65 = tt.broadcast %64 : tensor<64x1xi64> -> tensor<64x128xi64>
    %66 = tt.splat %22 : i64 -> tensor<128xi64>
    %67 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
    %68 = arith.extsi %67 : tensor<128xi32> to tensor<128xi64>
    %69 = arith.addi %66, %68 : tensor<128xi64>
    %70 = tt.expand_dims %69 {axis = 0 : i32} : tensor<128xi64> -> tensor<1x128xi64>
    %71 = tt.broadcast %70 : tensor<1x128xi64> -> tensor<64x128xi64>
    %72 = arith.addi %65, %71 : tensor<64x128xi64>
    %73 = tt.addptr %57, %72 : tensor<64x128x!tt.ptr<f16, 1>>, tensor<64x128xi64>
    %74 = arith.cmpi sge, %62, %cst_2 : tensor<64x1xi64>
    %75 = tt.splat %15 : i64 -> tensor<64x1xi64>
    %76 = arith.cmpi slt, %62, %75 : tensor<64x1xi64>
    %77 = arith.andi %74, %76 : tensor<64x1xi1>
    %78 = tt.broadcast %77 : tensor<64x1xi1> -> tensor<64x128xi1>
    %79 = arith.cmpi sge, %70, %cst : tensor<1x128xi64>
    %80 = tt.splat %20 : i64 -> tensor<1x128xi64>
    %81 = arith.cmpi slt, %70, %80 : tensor<1x128xi64>
    %82 = arith.andi %79, %81 : tensor<1x128xi1>
    %83 = tt.broadcast %82 : tensor<1x128xi1> -> tensor<64x128xi1>
    %84 = arith.andi %78, %83 : tensor<64x128xi1>
    tt.store %73, %55, %84 {cache = 1 : i32, evict = 1 : i32} : tensor<64x128x!tt.ptr<f16, 1>>
    tt.return
  }
}

Crash log

[16:20:18] colawithsauce@SPQR /home/colawithsauce/Projects/Triton
> triton-shared-opt --triton-to-linalg ~/playground/test.mlir
%43 = "arith.extsi"(%42) {MetaUse} : (tensor<64xi32>) -> tensor<64xi64>
encountered addptr operand produced by an unsupported operation
UNREACHABLE executed at /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:701!
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: triton-shared-opt --triton-to-linalg /home/colawithsauce/playground/test.mlir
 #0 0x00005f9befc0b087 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x242a087)
 #1 0x00005f9befc08bae llvm::sys::RunSignalHandlers() (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2427bae)
 #2 0x00005f9befc0b73f SignalHandler(int) Signals.cpp:0:0
 #3 0x000077438ec54eb0 __restore_rt (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x3deb0)
 #4 0x000077438eca407c __pthread_kill_implementation (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x8d07c)
 #5 0x000077438ec54e06 gsignal (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x3de06)
 #6 0x000077438ec3d8f5 abort (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x268f5)
 #7 0x00005f9befbccfc1 (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x23ebfc1)
 #8 0x00005f9bedc26dc0 mlir::triton::PtrAnalysis::visitOperandMul(mlir::arith::MulIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:369:0
 #9 0x00005f9bedc26780 llvm::SmallVectorBase<unsigned int>::size() const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:91:32
#10 0x00005f9bedc26780 mlir::triton::PtrState::getRank() const /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:61:3
#11 0x00005f9bedc26780 mlir::triton::PtrAnalysis::visitOperandAdd(mlir::arith::AddIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:358:17
#12 0x00005f9bedc26af6 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:0:5
#13 0x00005f9bedc27928 mlir::triton::PtrAnalysis::visitOperandExpandDims(mlir::triton::ExpandDimsOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:461:20
#14 0x00005f9bedc26ca5 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:0:5
#15 0x00005f9bedc26e74 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::getFirstEl() const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:143:46
#16 0x00005f9bedc26e74 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::SmallVectorTemplateCommon(unsigned long) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:148:49
#17 0x00005f9bedc26e74 llvm::SmallVectorTemplateBase<mlir::OpFoldResult, true>::SmallVectorTemplateBase(unsigned long) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:500:42
#18 0x00005f9bedc26e74 llvm::SmallVectorImpl<mlir::OpFoldResult>::SmallVectorImpl(unsigned int) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:601:9
#19 0x00005f9bedc26e74 llvm::SmallVector<mlir::OpFoldResult, 6u>::SmallVector() /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:1211:19
#20 0x00005f9bedc26e74 mlir::triton::PtrState::PtrState() /home/colawithsauce/Projects/Triton/triton_shared/include/triton-shared/Analysis/PtrAnalysis.h:46:7
#21 0x00005f9bedc26e74 mlir::triton::PtrAnalysis::visitOperandMul(mlir::arith::MulIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:373:12
#22 0x00005f9bedc26b77 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:0:5
#23 0x00005f9bedc27d0b mlir::triton::PtrAnalysis::visitOperandBroadcast(mlir::triton::BroadcastOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:497:24
#24 0x00005f9bedc26c27 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:0:5
#25 0x00005f9bedc266f4 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::getFirstEl() const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:143:46
#26 0x00005f9bedc266f4 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::SmallVectorTemplateCommon(unsigned long) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:148:49
#27 0x00005f9bedc266f4 llvm::SmallVectorTemplateBase<mlir::OpFoldResult, true>::SmallVectorTemplateBase(unsigned long) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:500:42
#28 0x00005f9bedc266f4 llvm::SmallVectorImpl<mlir::OpFoldResult>::SmallVectorImpl(unsigned int) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:601:9
#29 0x00005f9bedc266f4 llvm::SmallVector<mlir::OpFoldResult, 6u>::SmallVector() /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:1211:19
#30 0x00005f9bedc266f4 mlir::triton::PtrState::PtrState() /home/colawithsauce/Projects/Triton/triton_shared/include/triton-shared/Analysis/PtrAnalysis.h:46:7
#31 0x00005f9bedc266f4 mlir::triton::PtrAnalysis::visitOperandAdd(mlir::arith::AddIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:355:12
#32 0x00005f9bedc26af6 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:0:5
#33 0x00005f9bedc28e6e mlir::Value::operator bool() const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/mlir/IR/Value.h:120:43
#34 0x00005f9bedc28e6e mlir::triton::PtrAnalysis::visitOperandAddptr(mlir::triton::AddPtrOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:588:3
#35 0x00005f9bedc299d8 mlir::triton::PtrAnalysis::rewriteAddptrOp(mlir::triton::AddPtrOp, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>>&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:749:26
#36 0x00005f9bedcd6767 llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>>::~SmallDenseMap() /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/DenseMap.h:960:11
#37 0x00005f9bedcd6767 (anonymous namespace)::LegacyAddPtrConverter::matchAndRewrite(mlir::triton::AddPtrOp, mlir::triton::AddPtrOpAdaptor, mlir::ConversionPatternRewriter&) const /home/colawithsauce/Projects/Triton/triton_shared/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:242:3
#38 0x00005f9bedc645af mlir::OpConversionPattern<mlir::triton::AddPtrOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/mlir/Transforms/DialectConversion.h:538:3
#39 0x00005f9bef7e4d50 mlir::ConversionPattern::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2003d50)
#40 0x00005f9bef81e194 mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>)::$_2::operator()() const PatternApplicator.cpp:0:0
#41 0x00005f9bef81acaf mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2039caf)
#42 0x00005f9bef7f2315 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) DialectConversion.cpp:0:0
#43 0x00005f9bef7e7f44 (anonymous namespace)::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>, llvm::function_ref<void (mlir::Diagnostic&)>) DialectConversion.cpp:0:0
#44 0x00005f9bef7eb1e0 mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, llvm::DenseSet<mlir::Operation*, llvm::DenseMapInfo<mlir::Operation*, void>>*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x200a1e0)
#45 0x00005f9bedcec085 (anonymous namespace)::TritonToLinalgPass::runOnOperation() /home/colawithsauce/Projects/Triton/triton_shared/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp:184:16
#46 0x00005f9bef0cc6a6 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18eb6a6)
#47 0x00005f9bef0cce41 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18ebe41)
#48 0x00005f9bef0cf27b mlir::PassManager::run(mlir::Operation*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18ee27b)
#49 0x00005f9bef0c959f performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) MlirOptMain.cpp:0:0
#50 0x00005f9bef0c87bd mlir::LogicalResult llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_2>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) MlirOptMain.cpp:0:0
#51 0x00005f9befb99f89 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, bool, bool) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x23b8f89)
#52 0x00005f9bef0c3bfa mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18e2bfa)
#53 0x00005f9bef0c3d96 mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18e2d96)
#54 0x00005f9bef0c4146 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18e3146)
#55 0x00005f9bedd79bcb main /home/colawithsauce/Projects/Triton/triton_shared/tools/triton-shared-opt/triton-shared-opt.cpp:16:33
#56 0x000077438ec3f0ce __libc_start_call_main (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x280ce)
#57 0x000077438ec3f189 __libc_start_main@GLIBC_2.2.5 (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x28189)
#58 0x00005f9bedaadc75 _start (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2ccc75)
fish: Job 1, 'triton-shared-opt --triton-to-l…' terminated by signal SIGABRT (Abort)

Additional information

colawithsauce commented 6 months ago

I used triton-shared-opt --triton-to-linalg-experimental instead; here is the new error output:

[16:48:00] colawithsauce@SPQR /home/colawithsauce/Projects/Triton/triton_shared
> triton-opt ~/playground/test.mlir| triton-shared-opt --triton-to-linalg-experimental
PtrAnalysis: encountered addptr operand produced by an unsupported operation
%27 = arith.extsi %26 : tensor<64xi32> to tensor<64xi64>
<stdin>:77:13: remark: PtrAnalysis: Failed to rewrite AddPtrOp
      %90 = tt.addptr %23, %89 : tensor<64x32x!tt.ptr<f16, 1>>, tensor<64x32xi64>
            ^
<stdin>:77:13: note: see current operation: %92 = tt.addptr %24, %91 : tensor<64x32x!tt.ptr<f16, 1>>, tensor<64x32xi64>
<stdin>:83:13: remark: PtrAnalysis: pointer is not replace with tts.make_tptr so loadOp cannot be rewritten
      %96 = tt.load %90, %95 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
            ^
<stdin>:83:13: note: see current operation: %98 = tt.load %92, %97 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
<stdin>:83:13: remark: PtrAnalysis: Failed to rewrite LoadOp
      %96 = tt.load %90, %95 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
            ^
<stdin>:83:13: note: see current operation: %98 = tt.load %92, %97 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
PtrAnalysis: encountered addptr operand produced by an unsupported operation
%34 = arith.extsi %33 : tensor<32xi32> to tensor<32xi64>
<stdin>:90:14: remark: PtrAnalysis: Failed to rewrite AddPtrOp
      %103 = tt.addptr %40, %102 : tensor<32x128x!tt.ptr<f16, 1>>, tensor<32x128xi64>
             ^
<stdin>:90:14: note: see current operation: %106 = tt.addptr %41, %104 : tensor<32x128x!tt.ptr<f16, 1>>, tensor<32x128xi64>
<stdin>:96:14: remark: PtrAnalysis: pointer is not replace with tts.make_tptr so loadOp cannot be rewritten
      %109 = tt.load %103, %108 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16>
             ^
<stdin>:96:14: note: see current operation: %112 = tt.load %106, %111 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16>
<stdin>:96:14: remark: PtrAnalysis: Failed to rewrite LoadOp
      %109 = tt.load %103, %108 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16>
             ^
<stdin>:96:14: note: see current operation: %112 = tt.load %106, %111 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16>
PtrAnalysis: encountered addptr operand produced by an unsupported operation
%62 = arith.extsi %61 : tensor<64xi32> to tensor<64xi64>
<stdin>:120:11: remark: PtrAnalysis: Failed to rewrite AddPtrOp
    %73 = tt.addptr %57, %72 : tensor<64x128x!tt.ptr<f16, 1>>, tensor<64x128xi64>
          ^
<stdin>:120:11: note: see current operation: %75 = tt.addptr %59, %74 : tensor<64x128x!tt.ptr<f16, 1>>, tensor<64x128xi64>
<stdin>:132:5: remark: PtrAnalysis: pointer is not replace with tts.make_tptr so storeOp cannot be rewritten
    tt.store %73, %55, %84 {cache = 1 : i32, evict = 1 : i32} : tensor<64x128xf16>
    ^
<stdin>:132:5: note: see current operation: tt.store %75, %57, %86 {cache = 1 : i32, evict = 1 : i32} : tensor<64x128xf16>
<stdin>:132:5: remark: PtrAnalysis: Failed to rewrite StoreOp
    tt.store %73, %55, %84 {cache = 1 : i32, evict = 1 : i32} : tensor<64x128xf16>
    ^
<stdin>:132:5: note: see current operation: tt.store %75, %57, %86 {cache = 1 : i32, evict = 1 : i32} : tensor<64x128xf16>
LLVM ERROR: Failed to infer result type(s).
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: triton-shared-opt --triton-to-linalg-experimental
 #0 0x00006282cf875087 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x242a087)
 #1 0x00006282cf872bae llvm::sys::RunSignalHandlers() (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2427bae)
 #2 0x00006282cf87573f SignalHandler(int) Signals.cpp:0:0
 #3 0x000077b037e54eb0 __restore_rt (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x3deb0)
 #4 0x000077b037ea407c __pthread_kill_implementation (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x8d07c)
 #5 0x000077b037e54e06 gsignal (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x3de06)
 #6 0x000077b037e3d8f5 abort (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x268f5)
 #7 0x00006282cf836d00 llvm::report_fatal_error(llvm::Twine const&, bool) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x23ebd00)
 #8 0x00006282cf836b18 (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x23ebb18)
 #9 0x00006282ce3f05b1 (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0xfa55b1)
#10 0x00006282cd97a18f mlir::memref::ExtractStridedMetadataOp mlir::OpBuilder::create<mlir::memref::ExtractStridedMetadataOp, mlir::Value&>(mlir::Location, mlir::Value&) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/mlir/IR/Builders.h:509:16
#11 0x00006282cd978648 (anonymous namespace)::ScalarAddptrConverter::matchAndRewrite(mlir::triton::AddPtrOp, mlir::triton::AddPtrOpAdaptor, mlir::ConversionPatternRewriter&) const /home/colawithsauce/Projects/Triton/triton_shared/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp:766:18
#12 0x00006282cd8ce5af mlir::OpConversionPattern<mlir::triton::AddPtrOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/mlir/Transforms/DialectConversion.h:538:3
#13 0x00006282cf44ed50 mlir::ConversionPattern::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2003d50)
#14 0x00006282cf488194 mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>)::$_2::operator()() const PatternApplicator.cpp:0:0
#15 0x00006282cf484caf mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2039caf)
#16 0x00006282cf45c315 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) DialectConversion.cpp:0:0
#17 0x00006282cf451f44 (anonymous namespace)::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>, llvm::function_ref<void (mlir::Diagnostic&)>) DialectConversion.cpp:0:0
#18 0x00006282cf4551e0 mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, llvm::DenseSet<mlir::Operation*, llvm::DenseMapInfo<mlir::Operation*, void>>*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x200a1e0)
#19 0x00006282cd97b2d3 (anonymous namespace)::StructuredToMemrefPass::runOnOperation() /home/colawithsauce/Projects/Triton/triton_shared/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp:141:16
#20 0x00006282ced366a6 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18eb6a6)
#21 0x00006282ced36e41 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18ebe41)
#22 0x00006282ced3b4e8 mlir::LogicalResult llvm::function_ref<mlir::LogicalResult (mlir::OpPassManager&, mlir::Operation*)>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_6>(long, mlir::OpPassManager&, mlir::Operation*) Pass.cpp:0:0
#23 0x00006282cd961389 mlir::LogicalResult::failed() const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/mlir/Support/LogicalResult.h:44:33
#24 0x00006282cd961389 mlir::failed(mlir::LogicalResult) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/mlir/Support/LogicalResult.h:72:58
#25 0x00006282cd961389 (anonymous namespace)::TritonToLinalgExperimentalPass::runOnOperation() /home/colawithsauce/Projects/Triton/triton_shared/lib/Conversion/TritonToLinalgExperimental/TritonToLinalgExperimentalPass.cpp:55:9
#26 0x00006282ced366a6 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18eb6a6)
#27 0x00006282ced36e41 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18ebe41)
#28 0x00006282ced3927b mlir::PassManager::run(mlir::Operation*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18ee27b)
#29 0x00006282ced3359f performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) MlirOptMain.cpp:0:0
#30 0x00006282ced327bd mlir::LogicalResult llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_2>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) MlirOptMain.cpp:0:0
#31 0x00006282cf803f89 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, bool, bool) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x23b8f89)
#32 0x00006282ced2dbfa mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18e2bfa)
#33 0x00006282ced2dd96 mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18e2d96)
#34 0x00006282ced2e146 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18e3146)
#35 0x00006282cd9e3bcb main /home/colawithsauce/Projects/Triton/triton_shared/tools/triton-shared-opt/triton-shared-opt.cpp:16:33
#36 0x000077b037e3f0ce __libc_start_call_main (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x280ce)
#37 0x000077b037e3f189 __libc_start_main@GLIBC_2.2.5 (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x28189)
#38 0x00006282cd717c75 _start (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2ccc75)
fish: Process 424253, 'triton-shared-opt' from job 1, 'triton-opt ~/playground/test.ml…' terminated by signal SIGABRT (Abort)
nhat-nguyen commented 6 months ago

Thank you for the report covering both the legacy and experimental passes! This looks like a matmul kernel -- perhaps from the official Triton tutorial? I will take a closer look. It has been a while since we last updated the tutorial kernel tests, but if this is the official one, it's definitely important that we fix it.

colawithsauce commented 6 months ago

@nhat-nguyen yes, it is indeed from the Triton tutorial. I only made some changes to the kernel's caller and left the kernel itself unmodified.

colawithsauce commented 6 months ago

@nhat-nguyen, this code is from Triton's tutorial 08-experimental-block-pointer. However, when I went to look this tutorial up in Triton's online documentation today, I got a 404 Not Found; I don't know when it was removed. Here is the Wayback Machine link to it.

Searching for "block pointer" in Triton's issues, I found this comment from one of its contributors: https://github.com/openai/triton/issues/1946#issuecomment-1636500797

I guess block_ptr is still a work in progress and its use is not currently recommended.
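
For anyone who needs to run the same matmul through triton-shared in the meantime, below is a minimal sketch in the plain pointer-arithmetic style of Triton's non-block-pointer matmul tutorial. It is illustrative only and not part of the original report: the 2D grid, tile sizes, and kernel name are arbitrary choices, and it has not been verified against the current triton-shared passes. The point is just that it avoids tl.make_block_ptr and the resulting i64 extsi offsets.

import torch

import triton
import triton.language as tl

@triton.jit
def matmul_kernel_no_block_ptr(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # One program per (BLOCK_M, BLOCK_N) tile of C, on a 2D grid.
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    # 32-bit offsets built from tl.arange; the % M / % N keeps the A/B loads in bounds.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Mask the K tail; advance the pointers by one K block each iteration.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c = acc.to(tl.float16)
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def matmul_no_block_ptr(a, b):
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
    matmul_kernel_no_block_ptr[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=64, BLOCK_N=64, BLOCK_K=32)
    return c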

Thanks for your attention!