microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License

Missing Reduction: addi #12

Closed: manbearian closed this 9 months ago

manbearian commented 9 months ago

Created from #7.

A reduction sequence that is not yet supported: integer sum (arith.addi).
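For readers unfamiliar with the op, the pattern that fails to lower in the error output below is a "tt.reduce" along axis 1 whose body combines two i64 values with "arith.addi". Here is a minimal Python sketch of what that reduction is expected to compute, assuming the tensor<1x2xi64> -> tensor<1xi64> shapes from the error output. The names `addi` and `reduce_axis1` are hypothetical helpers for illustration only, not part of triton-shared or Triton.

```python
from functools import reduce

MASK64 = (1 << 64) - 1  # bit mask for 64-bit wraparound


def addi(lhs: int, rhs: int) -> int:
    """Model arith.addi on i64: two's-complement addition that wraps on overflow."""
    total = (lhs + rhs) & MASK64
    # Reinterpret the unsigned 64-bit result as a signed value.
    return total - (1 << 64) if total >= (1 << 63) else total


def reduce_axis1(rows, combine):
    """Fold each row of a 2-D nested list with `combine`, mimicking a reduce with axis = 1."""
    return [reduce(combine, row) for row in rows]


# A 1x2 input reduces to a 1-element result, matching the shapes in the error output.
print(reduce_axis1([[3, 4]], addi))  # [7]
```

The combiner is passed in as a function because the reduce body in the IR is just a region containing one binary op; swapping `addi` for a max or float add would model the cases the error message says are already supported.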

repros.zip

triton-shared-opt -triton-to-linalg 30.mlir
triton-shared-opt -triton-to-linalg 86.mlir

Error output:

+++/home/ianb/test/ttirs_linalg_failed/30.mlir
/home/ianb/test/ttirs_linalg_failed/30.mlir:24:11: error: Only support lowering reduction with body containing 1 max(i/f) or addf.
    %13 = "tt.reduce"(%12) <{axis = 1 : i32}> ({
          ^
/home/ianb/test/ttirs_linalg_failed/30.mlir:24:11: note: see current operation: 
%76 = "tt.reduce"(%75) <{axis = 1 : i32}> ({
^bb0(%arg12: i64, %arg13: i64):
  %84 = "arith.addi"(%arg12, %arg13) : (i64, i64) -> i64
  "tt.reduce.return"(%84) : (i64) -> ()
}) : (tensor<1x2xi64>) -> tensor<1xi64>
/home/ianb/test/ttirs_linalg_failed/30.mlir:24:11: error: failed to legalize operation 'tt.reduce'
    %13 = "tt.reduce"(%12) <{axis = 1 : i32}> ({
          ^
/home/ianb/test/ttirs_linalg_failed/30.mlir:24:11: note: see current operation: 
%76 = "tt.reduce"(%75) <{axis = 1 : i32}> ({
^bb0(%arg12: i64, %arg13: i64):
  %84 = "arith.addi"(%arg12, %arg13) : (i64, i64) -> i64
  "tt.reduce.return"(%84) : (i64) -> ()
}) : (tensor<1x2xi64>) -> tensor<1xi64>

makslevental commented 9 months ago

I think the repros here intersect with some slightly more complex patterns. E.g., even with the natural addition (see https://github.com/microsoft/triton-shared/pull/20), 86.mlir still fails due to:

86.mlir:33:11: error: failed to legalize operation 'tt.splat' marked as erased
    %18 = tt.splat %17 : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>
          ^
86.mlir:33:11: note: see current operation: %93 = "tt.splat"(%92) {MetaUse} : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>
86.mlir:34:5: note: found live user of result #0: "memref.tensor_store"(%88, %93) : (tensor<1x1xf32>, tensor<1x1x!tt.ptr<f32, 1>>) -> ()
    tt.store %18, %16 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xf32>
    ^

Happy to keep pushing to make the repros pass in toto, but I would appreciate a hint, since in other tests "tt.splat"(%92) : (!tt.ptr<f32, 1>) seems to lower fine.