microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License

Missing Reduction: maxf-variant #11

Open manbearian opened 9 months ago


Created from #7.

This is a reduction sequence that is not yet supported. I'm not sure what the original Triton code was, but it looks like some variant of maxf.

repros.zip

triton-shared-opt -triton-to-linalg 15.mlir
triton-shared-opt -triton-to-linalg 18.mlir
triton-shared-opt -triton-to-linalg 20.mlir
triton-shared-opt -triton-to-linalg 22.mlir
triton-shared-opt -triton-to-linalg 29.mlir
triton-shared-opt -triton-to-linalg 31.mlir
triton-shared-opt -triton-to-linalg 39.mlir
triton-shared-opt -triton-to-linalg 46.mlir
triton-shared-opt -triton-to-linalg 47.mlir
triton-shared-opt -triton-to-linalg 58.mlir
triton-shared-opt -triton-to-linalg 70.mlir
triton-shared-opt -triton-to-linalg 75.mlir
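To rerun all twelve repros in one go, a small Python driver like the one below may help. It is a sketch that assumes the .mlir files are extracted into the working directory, that triton-shared-opt is on PATH, and that the tool exits with a non-zero status when lowering fails (as MLIR opt-style drivers usually do). REPRO_IDS, build_commands, and run_all are illustrative names, not part of triton-shared.

```python
import shutil
import subprocess
import sys

# Repro file numbers taken from the command list in this issue.
REPRO_IDS = [15, 18, 20, 22, 29, 31, 39, 46, 47, 58, 70, 75]


def build_commands(ids, tool="triton-shared-opt"):
    """Return one argv list per repro, mirroring the commands listed above."""
    return [[tool, "-triton-to-linalg", f"{i}.mlir"] for i in ids]


def run_all(commands):
    """Run each command and return the last argv entry of every failing run."""
    failed = []
    for argv in commands:
        result = subprocess.run(argv, capture_output=True, text=True)
        if result.returncode != 0:
            failed.append(argv[-1])
            # Forward the diagnostics so the offending tt.reduce body is visible.
            print(result.stderr, file=sys.stderr)
    return failed


if __name__ == "__main__":
    if shutil.which("triton-shared-opt") is None:
        print("triton-shared-opt not found on PATH; nothing to run")
    else:
        failures = run_all(build_commands(REPRO_IDS))
        print(f"{len(failures)} of {len(REPRO_IDS)} repros failed: {failures}")
```

Collecting failures instead of stopping at the first one makes it easy to check whether a fix to the reduction lowering clears every repro at once.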

Error output:

+++/home/ianb/test/ttirs_linalg_failed/15.mlir
/home/ianb/test/ttirs_linalg_failed/15.mlir:30:11: error: Only support lowering reduction with body containing 1 max(i/f) or addf.
    %21 = "tt.reduce"(%20) <{axis = 1 : i32}> ({
          ^
/home/ianb/test/ttirs_linalg_failed/15.mlir:30:11: note: see current operation: 
%65 = "tt.reduce"(%64) <{axis = 1 : i32}> ({
^bb0(%arg11: f32, %arg12: f32):
  %80 = "arith.cmpf"(%arg11, %arg12) <{predicate = 2 : i64}> : (f32, f32) -> i1
  %81 = "arith.cmpf"(%arg11, %arg11) <{predicate = 13 : i64}> : (f32, f32) -> i1
  %82 = "arith.ori"(%80, %81) : (i1, i1) -> i1
  %83 = "arith.select"(%82, %arg11, %arg12) : (i1, f32, f32) -> f32
  "tt.reduce.return"(%83) : (f32) -> ()
}) : (tensor<16x128xf32>) -> tensor<16xf32>
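To see why this body reads as a maxf variant, it helps to decode the numeric predicates. In MLIR's arith CmpFPredicate enum, 2 is ogt (ordered greater-than, false if either operand is NaN) and 13 is une (unordered or not-equal), so cmpf une(x, x) is true exactly when x is NaN. The body therefore selects a when a > b or a is NaN, and b otherwise. That is a NaN-propagating max. The Python model below is a sketch of those semantics; combine and reduce_row are hypothetical names. One plausible (unconfirmed) source is a Triton combine function of the form tl.where((a > b) | (a != a), a, b), such as the maximum helper in PyTorch TorchInductor's triton_helpers.

```python
import functools
import math

# arith.cmpf predicate codes seen in the dump (MLIR CmpFPredicate enum).
PRED_OGT = 2   # ordered greater-than: false if either operand is NaN
PRED_UNE = 13  # unordered or not-equal: (x une x) is true exactly when x is NaN


def combine(a: float, b: float) -> float:
    """Python model of the tt.reduce body above."""
    gt = a > b                  # %80 = arith.cmpf ogt, %arg11, %arg12
    a_is_nan = a != a           # %81 = arith.cmpf une, %arg11, %arg11
    keep_a = gt or a_is_nan     # %82 = arith.ori %80, %81
    return a if keep_a else b   # %83 = arith.select %82, %arg11, %arg12


def reduce_row(row):
    """Sequential left fold over one row, like reducing along axis 1.

    Note: on a +0.0/-0.0 tie, a > b is false, so the second operand wins.
    """
    return functools.reduce(combine, row)
```

A NaN in either operand wins, so reduce_row([1.0, float("nan"), 3.0]) is NaN. By contrast, Python's built-in max([1.0, float("nan"), 3.0]) returns 3.0, because every comparison against the NaN is false and the NaN is silently skipped.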
/home/ianb/test/ttirs_linalg_failed/15.mlir:30:11: error: failed to legalize operation 'tt.reduce'
    %21 = "tt.reduce"(%20) <{axis = 1 : i32}> ({
          ^
/home/ianb/test/ttirs_linalg_failed/15.mlir:30:11: note: see current operation: 
%65 = "tt.reduce"(%64) <{axis = 1 : i32}> ({
^bb0(%arg11: f32, %arg12: f32):
  %80 = "arith.cmpf"(%arg11, %arg12) <{predicate = 2 : i64}> : (f32, f32) -> i1
  %81 = "arith.cmpf"(%arg11, %arg11) <{predicate = 13 : i64}> : (f32, f32) -> i1
  %82 = "arith.ori"(%80, %81) : (i1, i1) -> i1
  %83 = "arith.select"(%82, %arg11, %arg12) : (i1, f32, f32) -> f32
  "tt.reduce.return"(%83) : (f32) -> ()
}) : (tensor<16x128xf32>) -> tensor<16xf32>
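The first diagnostic says the conversion only accepts a reduce body containing a single max(i/f) or addf. One way support could be added is to recognize this exact cmpf/cmpf/ori/select sequence as a NaN-propagating max. The Python sketch below models that matching step on a simplified op list. Op and is_nan_propagating_max are hypothetical stand-ins, not triton-shared's C++ pattern API. Whether a matched body can then be lowered to an existing max op depends on that op having the same NaN and signed-zero behavior, which this sketch does not check.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Op:
    """Simplified stand-in for one MLIR operation in the reduce body."""
    name: str
    result: Optional[str]
    operands: tuple
    predicate: Optional[int] = None


def is_nan_propagating_max(body, lhs, rhs):
    """Return True if `body` is the ogt/une/ori/select/return sequence.

    Checks op kinds, cmpf predicates (2 = ogt, 13 = une), and the data flow
    between ops, and requires the yielded value to be the select result.
    """
    if len(body) != 5:
        return False
    gt, isnan, either, sel, ret = body
    return (
        gt.name == "arith.cmpf" and gt.predicate == 2
        and gt.operands == (lhs, rhs)
        and isnan.name == "arith.cmpf" and isnan.predicate == 13
        and isnan.operands == (lhs, lhs)
        and either.name == "arith.ori"
        and set(either.operands) == {gt.result, isnan.result}
        and sel.name == "arith.select"
        and sel.operands == (either.result, lhs, rhs)
        and ret.name == "tt.reduce.return"
        and ret.operands == (sel.result,)
    )


# The body from the error output, transcribed into the simplified form.
BODY = [
    Op("arith.cmpf", "%80", ("%arg11", "%arg12"), 2),
    Op("arith.cmpf", "%81", ("%arg11", "%arg11"), 13),
    Op("arith.ori", "%82", ("%80", "%81")),
    Op("arith.select", "%83", ("%82", "%arg11", "%arg12")),
    Op("tt.reduce.return", None, ("%83",)),
]
```

Matching on data flow rather than on op names alone matters here: swapping the select operands or comparing the wrong block argument for NaN would produce a different (non-max) reduction that still uses the same four op kinds.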