Cambricon / triton-linalg

Development repository for the Triton-Linalg conversion
Apache License 2.0

error: failed to legalize operation 'tt.load' that was explicitly marked illegal #12

Closed: lordrebel closed this issue 2 months ago

lordrebel commented 2 months ago

Hi!
I am using the Triton Python tutorials to generate TTIR so that I can compare triton-shared and triton-linalg. When I use the TTIR generated from triton/python/tutorials/03-matrix-multiplication.py, I get an error:

>>../build/cmake.linux-x86_64-cpython-3.10/third_party/triton_linalg/bin/triton-linalg-opt --convert-triton-to-linalg matmul.ttir.ir -o matmul_triton_linalg.mlir
/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py:304:20: error: failed to legalize operation 'tt.load' that was explicitly marked illegal
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
                   ^
/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py:304:20: note: see current operation: %184 = "tt.load"(<<UNKNOWN SSA VALUE>>, %183, %13) <{boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1>}> : (tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi1>, tensor<32x32xf16>) -> tensor<32x32xf16>
triton-linalg-opt: /home/runner/work/triton/triton/llvm-project/mlir/include/mlir/IR/UseDefLists.h:198: mlir::IRObjectWithUseList<mlir::OpOperand>::~IRObjectWithUseList() [OperandType = mlir::OpOperand]: Assertion `use_empty() && "Cannot destroy a value that still has uses!"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: ../build/cmake.linux-x86_64-cpython-3.10/third_party/triton_linalg/bin/triton-linalg-opt --convert-triton-to-linalg matmul.ttir.ir -o matmul_triton_linalg.mlir
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):

The TTIR I used is:

#loc = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":249:0)
module {
  tt.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":249:0), %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":249:0), %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":249:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":249:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":249:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":249:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":249:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":249:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":249:0)) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32> loc(#loc1)
    %c63_i32 = arith.constant 63 : i32 loc(#loc1)
    %c31_i32 = arith.constant 31 : i32 loc(#loc1)
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x64xf16> loc(#loc1)
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf16> loc(#loc1)
    %c1_i32 = arith.constant 1 : i32 loc(#loc1)
    %c0_i32 = arith.constant 0 : i32 loc(#loc1)
    %cst_2 = arith.constant dense<32> : tensor<32x32xi32> loc(#loc1)
    %c64_i32 = arith.constant 64 : i32 loc(#loc1)
    %c32_i32 = arith.constant 32 : i32 loc(#loc1)
    %c8_i32 = arith.constant 8 : i32 loc(#loc1)
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = arith.addi %arg3, %c31_i32 : i32 loc(#loc58)
    %2 = arith.divsi %1, %c32_i32 : i32 loc(#loc59)
    %3 = arith.addi %arg4, %c63_i32 : i32 loc(#loc60)
    %4 = arith.divsi %3, %c64_i32 : i32 loc(#loc61)
    %5 = arith.muli %4, %c8_i32 : i32 loc(#loc7)
    %6 = arith.divsi %0, %5 : i32 loc(#loc8)
    %7 = arith.muli %6, %c8_i32 : i32 loc(#loc9)
    %8 = arith.subi %2, %7 : i32 loc(#loc10)
    %9 = arith.minsi %8, %c8_i32 : i32 loc(#loc11)
    %10 = arith.remsi %0, %5 : i32 loc(#loc12)
    %11 = arith.remsi %10, %9 : i32 loc(#loc13)
    %12 = arith.addi %7, %11 : i32 loc(#loc14)
    %13 = arith.divsi %10, %9 : i32 loc(#loc15)
    %14 = arith.muli %12, %c32_i32 : i32 loc(#loc16)
    %15 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> loc(#loc17)
    %16 = tt.splat %14 : i32 -> tensor<32xi32> loc(#loc18)
    %17 = arith.addi %16, %15 : tensor<32xi32> loc(#loc18)
    %18 = tt.splat %arg3 : i32 -> tensor<32xi32> loc(#loc19)
    %19 = arith.remsi %17, %18 : tensor<32xi32> loc(#loc19)
    %20 = arith.muli %13, %c64_i32 : i32 loc(#loc20)
    %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc21)
    %22 = tt.splat %20 : i32 -> tensor<64xi32> loc(#loc22)
    %23 = arith.addi %22, %21 : tensor<64xi32> loc(#loc22)
    %24 = tt.splat %arg4 : i32 -> tensor<64xi32> loc(#loc23)
    %25 = arith.remsi %23, %24 : tensor<64xi32> loc(#loc23)
    %26 = tt.expand_dims %19 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> loc(#loc24)
    %27 = tt.splat %arg6 : i32 -> tensor<32x1xi32> loc(#loc25)
    %28 = arith.muli %26, %27 : tensor<32x1xi32> loc(#loc25)
    %29 = tt.expand_dims %15 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> loc(#loc26)
    %30 = tt.broadcast %28 : tensor<32x1xi32> -> tensor<32x32xi32> loc(#loc27)
    %31 = tt.broadcast %29 : tensor<1x32xi32> -> tensor<32x32xi32> loc(#loc27)
    %32 = arith.addi %30, %31 : tensor<32x32xi32> loc(#loc27)
    %33 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>> loc(#loc28)
    %34 = tt.addptr %33, %32 : tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi32> loc(#loc28)
    %35 = tt.expand_dims %15 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> loc(#loc29)
    %36 = tt.splat %arg7 : i32 -> tensor<32x1xi32> loc(#loc30)
    %37 = arith.muli %35, %36 : tensor<32x1xi32> loc(#loc30)
    %38 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc31)
    %39 = tt.broadcast %37 : tensor<32x1xi32> -> tensor<32x64xi32> loc(#loc32)
    %40 = tt.broadcast %38 : tensor<1x64xi32> -> tensor<32x64xi32> loc(#loc32)
    %41 = arith.addi %39, %40 : tensor<32x64xi32> loc(#loc32)
    %42 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>> loc(#loc33)
    %43 = tt.addptr %42, %41 : tensor<32x64x!tt.ptr<f16>>, tensor<32x64xi32> loc(#loc33)
    %44 = arith.addi %arg5, %c31_i32 : i32 loc(#loc62)
    %45 = arith.divsi %44, %c32_i32 : i32 loc(#loc63)
    %46 = arith.muli %arg7, %c32_i32 : i32 loc(#loc35)
    %47 = tt.splat %46 : i32 -> tensor<32x64xi32> loc(#loc36)
    %48:3 = scf.for %arg9 = %c0_i32 to %45 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %34, %arg12 = %43) -> (tensor<32x64xf32>, tensor<32x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>)  : i32 {
      %66 = arith.muli %arg9, %c32_i32 : i32 loc(#loc38)
      %67 = arith.subi %arg5, %66 : i32 loc(#loc39)
      %68 = tt.splat %67 : i32 -> tensor<1x32xi32> loc(#loc40)
      %69 = arith.cmpi slt, %29, %68 : tensor<1x32xi32> loc(#loc40)
      %70 = tt.broadcast %69 : tensor<1x32xi1> -> tensor<32x32xi1> loc(#loc41)
      %71 = tt.load %arg11, %70, %cst_1 : tensor<32x32x!tt.ptr<f16>> loc(#loc41)
      %72 = tt.splat %67 : i32 -> tensor<32x1xi32> loc(#loc42)
      %73 = arith.cmpi slt, %35, %72 : tensor<32x1xi32> loc(#loc42)
      %74 = tt.broadcast %73 : tensor<32x1xi1> -> tensor<32x64xi1> loc(#loc43)
      %75 = tt.load %arg12, %74, %cst_0 : tensor<32x64x!tt.ptr<f16>> loc(#loc43)
      %76 = tt.dot %71, %75, %arg10, inputPrecision = tf32 : tensor<32x32xf16> * tensor<32x64xf16> -> tensor<32x64xf32> loc(#loc44)
      %77 = tt.addptr %arg11, %cst_2 : tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi32> loc(#loc45)
      %78 = tt.addptr %arg12, %47 : tensor<32x64x!tt.ptr<f16>>, tensor<32x64xi32> loc(#loc36)
      scf.yield %76, %77, %78 : tensor<32x64xf32>, tensor<32x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>> loc(#loc46)
    } loc(#loc37)
    %49 = arith.truncf %48#0 : tensor<32x64xf32> to tensor<32x64xf16> loc(#loc47)
    %50 = tt.expand_dims %17 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> loc(#loc48)
    %51 = tt.splat %arg8 : i32 -> tensor<32x1xi32> loc(#loc49)
    %52 = arith.muli %51, %50 : tensor<32x1xi32> loc(#loc49)
    %53 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<32x1x!tt.ptr<f16>> loc(#loc50)
    %54 = tt.addptr %53, %52 : tensor<32x1x!tt.ptr<f16>>, tensor<32x1xi32> loc(#loc50)
    %55 = tt.expand_dims %23 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc51)
    %56 = tt.broadcast %54 : tensor<32x1x!tt.ptr<f16>> -> tensor<32x64x!tt.ptr<f16>> loc(#loc52)
    %57 = tt.broadcast %55 : tensor<1x64xi32> -> tensor<32x64xi32> loc(#loc52)
    %58 = tt.addptr %56, %57 : tensor<32x64x!tt.ptr<f16>>, tensor<32x64xi32> loc(#loc52)
    %59 = tt.splat %arg3 : i32 -> tensor<32x1xi32> loc(#loc53)
    %60 = arith.cmpi slt, %50, %59 : tensor<32x1xi32> loc(#loc53)
    %61 = tt.splat %arg4 : i32 -> tensor<1x64xi32> loc(#loc54)
    %62 = arith.cmpi slt, %55, %61 : tensor<1x64xi32> loc(#loc54)
    %63 = tt.broadcast %60 : tensor<32x1xi1> -> tensor<32x64xi1> loc(#loc55)
    %64 = tt.broadcast %62 : tensor<1x64xi1> -> tensor<32x64xi1> loc(#loc55)
    %65 = arith.andi %63, %64 : tensor<32x64xi1> loc(#loc55)
    tt.store %58, %49, %65 : tensor<32x64x!tt.ptr<f16>> loc(#loc56)
    tt.return loc(#loc57)
  } loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":272:24)
#loc3 = loc("/workspace/projs/triton_shared/triton/python/triton/language/standard.py":44:22)
#loc4 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":273:27)
#loc5 = loc("/workspace/projs/triton_shared/triton/python/triton/language/standard.py":44:28)
#loc6 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":274:27)
#loc7 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":275:38)
#loc8 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":276:22)
#loc9 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":277:29)
#loc10 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":278:35)
#loc11 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":278:48)
#loc12 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":279:34)
#loc13 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":279:54)
#loc14 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":279:27)
#loc15 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":280:40)
#loc16 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":289:23)
#loc17 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":289:51)
#loc18 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":289:38)
#loc19 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":289:68)
#loc20 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":290:23)
#loc21 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":290:51)
#loc22 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":290:38)
#loc23 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":290:68)
#loc24 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":292:30)
#loc25 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":292:41)
#loc26 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":292:60)
#loc27 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":292:53)
#loc28 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":292:22)
#loc29 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":293:29)
#loc30 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":293:40)
#loc31 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":293:60)
#loc32 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":293:52)
#loc33 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":293:22)
#loc34 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":301:33)
#loc35 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":310:33)
#loc36 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":310:18)
#loc37 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":301:22)
#loc38 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":304:59)
#loc39 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":304:55)
#loc40 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":304:51)
#loc41 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":304:20)
#loc42 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":305:51)
#loc43 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":305:20)
#loc44 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":307:35)
#loc45 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":309:18)
#loc46 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":310:8)
#loc47 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":315:23)
#loc48 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":321:41)
#loc49 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":321:33)
#loc50 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":321:21)
#loc51 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":321:72)
#loc52 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":321:52)
#loc53 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":322:33)
#loc54 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":322:58)
#loc55 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":322:39)
#loc56 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":323:21)
#loc57 = loc("/workspace/projs/triton_shared/triton/python/tutorials/03-matrix-multiplication.py":323:4)
#loc58 = loc(callsite(#loc3 at #loc4))
#loc59 = loc(callsite(#loc5 at #loc4))
#loc60 = loc(callsite(#loc3 at #loc6))
#loc61 = loc(callsite(#loc5 at #loc6))
#loc62 = loc(callsite(#loc3 at #loc34))
#loc63 = loc(callsite(#loc5 at #loc34))

It was generated with Triton at commit 263fb70281f40cdb44768ae0cbf016181860c409.

hesse-x commented 2 months ago

If you want to convert TTIR directly to Linalg IR, use the triton-to-linalg pipeline. --convert-triton-to-linalg is a standalone pass, but a complete conversion requires running all of the passes in this repository in a specific order. In particular, --convert-triton-to-linalg relies on some preprocessing passes having already run; invoking it on its own skips them, which is why the error occurs. Therefore, run triton-linalg-opt --triton-to-linalg tt.mlir to perform the full conversion.

If you want to see the full sequence of passes that the pipeline invokes, add --mlir-print-ir-after-all. It shows which passes are executed and the IR produced after each one.
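The --mlir-print-ir-after-all dump can be long. The sketch below extracts just the pass names from it so the pipeline's order is easy to read. It assumes two things you should check against your own build. First, the dump is written to stderr. Second, each section starts with MLIR's usual header line of the form `// -----// IR Dump After <PassName> (<pass-argument>) //----- //`. The `list_passes` function is a hypothetical helper, not part of triton-linalg.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of triton-linalg): print, in execution order,
# the pass names found in an MLIR --mlir-print-ir-after-all dump read from stdin.
# Assumes each dump section begins with a header line such as:
#   // -----// IR Dump After CSE (cse) //----- //
list_passes() {
  # -n with the trailing 'p' flag prints only the header lines that matched;
  # the substitution strips the fixed prefix/suffix and keeps the pass name.
  sed -n 's|^// -----// IR Dump After \(.*\) //----- //$|\1|p'
}
```

Possible usage (the output IR goes to the -o file, and the dump on stderr is piped into the helper): `triton-linalg-opt --triton-to-linalg --mlir-print-ir-after-all matmul.ttir.ir -o matmul_triton_linalg.mlir 2>&1 | list_passes`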

lordrebel commented 2 months ago

I will try.

lordrebel commented 2 months ago

thx