microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License

[Bug]: tl.reduce fp16 compatibility issues #139

Status: Open. mitekoth opened this issue 4 weeks ago.

mitekoth commented 4 weeks ago

Triton Python code

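import triton
import triton.language as tl

# Scenario 1: divide each element by the fp16 sum.
# (The IR below was generated with BLOCK_SIZE = 32 and fp16 pointers.)
@triton.jit
def fn(x_ptr, output_ptr, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    denom = tl.sum(x, axis=0)  # lowered to tt.reduce with an f16 combine region
    y = x / denom              # in the IR: extf to f32, divf in f32, truncf back to f16
    tl.store(output_ptr + offsets, y)

# Scenario 2 (a separate kernel): subtract the fp16 sum from each element.
@triton.jit
def fn(x_ptr, output_ptr, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    denom = tl.sum(x, axis=0)
    y = x - denom              # in the IR: subf stays in f16 throughout
    tl.store(output_ptr + offsets, y)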

Triton IR

#Scenario 1
module {
  tt.func public @fn(%arg0: !tt.ptr<f16> , %arg1: !tt.ptr<f16> ) attributes {noinline = false} {
    %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> 
    %1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x!tt.ptr<f16>> 
    %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr<f16>>, tensor<32xi32> 
    %3 = tt.load %2 : tensor<32x!tt.ptr<f16>> 
    %4 = "tt.reduce"(%3) <{axis = 0 : i32}> ({
    ^bb0(%arg2: f16 , %arg3: f16 ):
      %12 = arith.addf %arg2, %arg3 : f16 
      tt.reduce.return %12 : f16 
    }) : (tensor<32xf16>) -> f16 
    %5 = arith.extf %3 : tensor<32xf16> to tensor<32xf32> 
    %6 = arith.extf %4 : f16 to f32 
    %7 = tt.splat %6 : f32 -> tensor<32xf32> 
    %8 = arith.divf %5, %7 : tensor<32xf32> 
    %9 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x!tt.ptr<f16>> 
    %10 = tt.addptr %9, %0 : tensor<32x!tt.ptr<f16>>, tensor<32xi32> 
    %11 = arith.truncf %8 : tensor<32xf32> to tensor<32xf16> 
    tt.store %10, %11 : tensor<32x!tt.ptr<f16>> 
    tt.return 
  } 
}

#Scenario 2
module {
  tt.func public @fn(%arg0: !tt.ptr<f16> , %arg1: !tt.ptr<f16> ) attributes {noinline = false} {
    %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x!tt.ptr<f16>> 
    %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr<f16>>, tensor<32xi32> 
    %3 = tt.load %2 : tensor<32x!tt.ptr<f16>> 
    %4 = "tt.reduce"(%3) <{axis = 0 : i32}> ({
    ^bb0(%arg2: f16 , %arg3: f16 ):
      %9 = arith.addf %arg2, %arg3 : f16 
      tt.reduce.return %9 : f16 
    }) : (tensor<32xf16>) -> f16 
    %5 = tt.splat %4 : f16 -> tensor<32xf16> 
    %6 = arith.subf %3, %5 : tensor<32xf16> 
    %7 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x!tt.ptr<f16>> 
    %8 = tt.addptr %7, %0 : tensor<32x!tt.ptr<f16>>, tensor<32xi32> 
    tt.store %8, %6 : tensor<32x!tt.ptr<f16>> 
    tt.return 
  } 
}
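For checking the lowered kernels once these crashes are fixed, the dtype flow visible in the two IR listings can be mirrored in plain Python. This is a hypothetical host-side reference, not part of triton-shared or Triton: the helper names (`to_f16`, `to_f32`, `f32_div`, `f16_sum`, `reference_scenario1`, `reference_scenario2`) are made up for illustration, half-precision rounding uses the standard library's `struct` `'e'` format, and the fp16 sum assumes a left-to-right accumulation order. The real `tt.reduce` lowering may combine elements in a different order, so low-order bits of the sum (and therefore of the outputs) can differ.

```python
import math
import struct
from typing import List


def to_f16(v: float) -> float:
    """Round to the nearest IEEE 754 half-precision value (ties to even)."""
    try:
        return struct.unpack("<e", struct.pack("<e", v))[0]
    except OverflowError:
        # struct raises instead of rounding to infinity; IEEE 754 gives +/-inf.
        return math.copysign(math.inf, v)


def to_f32(v: float) -> float:
    """Round to the nearest IEEE 754 single-precision value."""
    try:
        return struct.unpack("<f", struct.pack("<f", v))[0]
    except OverflowError:
        return math.copysign(math.inf, v)


def f32_div(a: float, b: float) -> float:
    """IEEE-style f32 division, including division by zero (mirrors arith.divf on f32)."""
    if b == 0.0:
        if a == 0.0 or math.isnan(a):
            return math.nan
        return math.copysign(math.inf, a) * math.copysign(1.0, b)
    # Dividing in f64 and rounding once to f32 gives the correctly rounded f32 quotient.
    return to_f32(a / b)


def f16_sum(xs: List[float]) -> float:
    """Mirror the f16 tt.reduce combine region (arith.addf on f16), left to right."""
    acc = xs[0]
    for v in xs[1:]:
        # The sum of two finite f16 values is exact in f64, so one rounding to f16 suffices.
        acc = to_f16(acc + v)
    return acc


def reference_scenario1(xs: List[float]) -> List[float]:
    """y = x / sum(x): extf to f32, divf in f32, truncf back to f16."""
    x = [to_f16(v) for v in xs]  # tt.load of f16 data
    denom = f16_sum(x)           # %4: f16 reduction
    return [to_f16(f32_div(v, denom)) for v in x]  # %5..%11


def reference_scenario2(xs: List[float]) -> List[float]:
    """y = x - sum(x): subf performed directly in f16."""
    x = [to_f16(v) for v in xs]
    denom = f16_sum(x)
    return [to_f16(v - denom) for v in x]  # %6: one rounding to f16
```

With 32 inputs of 4096.0, for instance, the fp16 accumulation in this model overflows to inf, so Scenario 1 returns zeros and Scenario 2 returns -inf. Inputs whose running sum stays within the fp16 range (max 65504) are easier to compare against.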

Crash log

#Scenario 1
triton-shared-opt --triton-to-linalg-experimental res.ttir 
bin/res.ttir:7:10: error: failed to materialize conversion for result #0 of operation 'tt.reduce' that remained live after conversion
    %4 = "tt.reduce"(%3) <{axis = 0 : i32}> ({
         ^
res.ttir:7:10: note: see current operation: 
%8 = "tt.reduce"(%1) <{axis = 0 : i32}> ({
^bb0(%arg17: f16, %arg18: f16):
  %25 = "arith.addf"(%arg17, %arg18) <{fastmath = #arith.fastmath<none>}> : (f16, f16) -> f16
  "tt.reduce.return"(%25) : (f16) -> ()
}) : (tensor<32xf16>) -> f16
bin/softmax-alt2.ttir:13:10: note: see existing live user here: %8 = arith.extf %4 : f16 to f32
    %6 = arith.extf %4 : f16 to f32 
         ^
#Scenario 2
<unknown>:0: error: 'linalg.yield' op type of yield operand 1 ('bf16') doesn't match the element type of the enclosing linalg.generic op ('f16')
<unknown>:0: note: see current operation: "linalg.yield"(%arg11) : (bf16) -> ()

Additional information

Triton-shared branch: nhat/dep (specifically this commit)

I am reporting two different errors from two kernels that differ only slightly (Scenario 1 divides by the reduced sum, Scenario 2 subtracts it) in one ticket, since both are related to fp16 computation.

nhat-nguyen commented 2 weeks ago

@haishanzzzz I believe this is similar (or related) to one of the bf16 issues your team saw, right?