microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License
165 stars 34 forks source link

[Bug]: failed to legalize operation 'tt.splat' marked as erased #65

Open yuanfz98 opened 10 months ago

yuanfz98 commented 10 months ago

Triton Python code

def triton_(in_out_ptr0, in_ptr0, in_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    """Reproducer kernel for the "'tt.splat' marked as erased" crash.

    Loads RBLOCK (= 2) elements each from ``in_ptr0`` and ``in_ptr1``
    (masked by ``rmask``), sums each along axis 1, divides the first sum by
    the second sum cast to float32, and stores the result at offset 0 of
    ``in_out_ptr0``.

    Notes:
        * The ``xnumel`` and ``rnumel`` arguments are ignored; they are
          overwritten with the constants 1 and 2 on the first lines.
        * No ``@triton.jit`` decorator is shown here; presumably it was
          applied in the original source — TODO confirm.
    """
    # Problem sizes are hard-coded, shadowing the incoming arguments.
    xnumel = 1
    rnumel = 2
    RBLOCK: tl.constexpr = 2
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    # NOTE(review): xmask is computed but never used below (no op in the IR
    # dump corresponds to it).
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r0 = rindex
    # Masked loads; out-of-range lanes read as 0. In the IR dump in_ptr1 is
    # an i64 pointer.
    tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0)
    tmp5 = tl.load(in_ptr1 + (r0), rmask, other=0)
    # Zero masked-out lanes before reducing, so they do not affect the sums.
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.where(rmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
    tmp8 = tl.where(rmask, tmp6, 0)
    tmp9 = tl.sum(tmp8, 1)[:, None]
    # Cast the second sum to float32 before dividing (sitofp in the IR).
    tmp10 = tmp9.to(tl.float32)
    tmp11 = tmp4 / tmp10
    tl.debug_barrier()
    # Scalar base pointer plus a constant-0 offset tensor. In the IR dump this
    # lowers to a scalar tt.addptr (%17) followed by tt.splat (%18). %18 is the
    # op the crash log reports as "marked as erased". The mask is None, so the
    # store is unmasked.
    tl.store(in_out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp11, None)

Triton IR

module {  // Triton IR for triton_ above; trailing comments only, so line numbers still match the crash log
  tt.func public @triton__0d1d2d34(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64, 1> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32) attributes {noinline = false} {  // %arg0..%arg2 correspond positionally to in_out_ptr0, in_ptr0, in_ptr1; %arg3/%arg4 (xnumel/rnumel) are unused in the body
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<0> : tensor<1x2xi64>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<1x2xf32>
    %cst_1 = arith.constant dense<2> : tensor<1x2xi32>
    %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32>
    %1 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<2xi32>) -> tensor<1x2xi32>
    %2 = arith.cmpi slt, %1, %cst_1 : tensor<1x2xi32>  // rmask: rindex < 2
    %3 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<1x2x!tt.ptr<f32, 1>>  // load path: splat the base pointer first ...
    %4 = tt.addptr %3, %1 : tensor<1x2x!tt.ptr<f32, 1>>, tensor<1x2xi32>  // ... then addptr on the tensor of pointers
    %5 = tt.load %4, %2, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x2xf32>
    %6 = tt.splat %arg2 : (!tt.ptr<i64, 1>) -> tensor<1x2x!tt.ptr<i64, 1>>
    %7 = tt.addptr %6, %1 : tensor<1x2x!tt.ptr<i64, 1>>, tensor<1x2xi32>
    %8 = tt.load %7, %2, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x2xi64>
    %9 = arith.select %2, %5, %cst_0 : tensor<1x2xi1>, tensor<1x2xf32>
    %10 = "tt.reduce"(%9) <{axis = 1 : i32}> ({
    ^bb0(%arg5: f32, %arg6: f32):
      %19 = arith.addf %arg5, %arg6 : f32
      tt.reduce.return %19 : f32
    }) : (tensor<1x2xf32>) -> tensor<1xf32>
    %11 = tt.expand_dims %10 {axis = 1 : i32} : (tensor<1xf32>) -> tensor<1x1xf32>
    %12 = arith.select %2, %8, %cst : tensor<1x2xi1>, tensor<1x2xi64>
    %13 = "tt.reduce"(%12) <{axis = 1 : i32}> ({
    ^bb0(%arg5: i64, %arg6: i64):
      %19 = arith.addi %arg5, %arg6 : i64
      tt.reduce.return %19 : i64
    }) : (tensor<1x2xi64>) -> tensor<1xi64>
    %14 = tt.expand_dims %13 {axis = 1 : i32} : (tensor<1xi64>) -> tensor<1x1xi64>
    %15 = arith.sitofp %14 : tensor<1x1xi64> to tensor<1x1xf32>
    %16 = arith.divf %11, %15 : tensor<1x1xf32>
    gpu.barrier  // from tl.debug_barrier()
    %17 = tt.addptr %arg0, %c0_i32 : !tt.ptr<f32, 1>, i32  // store path: scalar addptr on the base pointer FIRST ...
    %18 = tt.splat %17 : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>  // ... then splat (line 34 of the .ttir): the op the crash log reports as "marked as erased"
    tt.store %18, %16 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xf32>  // becomes the live user of %18 (memref.tensor_store in the crash log)
    tt.return
  }
}

Crash log

/workspace/hongjing/temp/jolwo38w/triton_.ttir:34:11: error: failed to legalize operation 'tt.splat' marked as erased
    %18 = tt.splat %17 : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>
          ^
/workspace/hongjing/temp/jolwo38w/triton_.ttir:34:11: note: see current operation: %93 = "tt.splat"(%92) {MetaUse} : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>
/workspace/hongjing/temp/jolwo38w/triton_.ttir:35:5: note: found live user of result #0: "memref.tensor_store"(%88, %93) : (tensor<1x1xf32>, tensor<1x1x!tt.ptr<f32, 1>>) -> ()
    tt.store %18, %16 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xf32>

Additional information

No response

yuanfz98 commented 10 months ago

@nhat-nguyen This error is caused by the placement of tt.splat. MetaOpConverter erases the tt.splat op first, and then AddPtrConverter and StoreConverter run (StoreConverter uses the adaptor to get the pointer rewritten by AddPtrConverter). A tt.splat that sits between the addptr and the store (as %18 does here) is therefore inconsistent: it has already been erased, but the converted store still uses it. We could:

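  1. In the canonicalizer pass, move tt.splat before tt.addptr, so that the base pointer is splatted first and the offset is then added to the tensor of pointers (the same ordering as %3/%4 for the loads).
  2. Do a bottom-up traversal in StoreConverter/LoadConverter instead.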

Related to #62. Thanks for your reply!