triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

TTIR printing inconsistencies #3120

Closed: oulgen closed this issue 6 months ago

oulgen commented 7 months ago

I have noticed that when dumping the TTIR from Triton, the format is inconsistent. Ops are sometimes printed in their generic format and sometimes in their pretty format, e.g.

%0 = "tt.get_program_id"() <{axis = 0 : i32}> : () -> i32 loc(#loc1)

versus

%0 = tt.get_program_id x : i32 loc(#loc1)

The printed form changes if you slightly modify the source. I propose that we set the generic printing flag so that the output is always consistent.
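For reference, this is a standard MLIR printing flag. A minimal sketch of what forcing generic form looks like through upstream MLIR's Python bindings (this assumes the mlir package is available; it is not Triton's own API):

    # Illustration only: upstream MLIR's Python bindings, not Triton's.
    from mlir.ir import Context, Module

    ASM = """
    module {
      func.func @f() {
        return
      }
    }
    """

    with Context():
        module = Module.parse(ASM)
        # print_generic_op_form=True renders every op in the quoted
        # generic syntax, so the output no longer depends on whether
        # an op has a pretty-printer.
        print(module.operation.get_asm(print_generic_op_form=True))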

jlebar commented 7 months ago

I would prefer that we always do the non-generic printing. We have lots of pretty-printers.

It seems like there's a bug where axis=0 isn't handled correctly or something and we fall back to the ugly-printer?

oulgen commented 7 months ago

@jlebar I don't actually have a preference which one we choose, as long as we are consistent. Would you be willing to put up a PR to set the flag?

And yes, there are inconsistencies between the two printing formats.

joker-eph commented 7 months ago

The only case where MLIR prints the generic form without you asking for it is when the verifier fails. In that case the entire IR is printed generically, though: are you seeing this with one op in the middle of the function, or is the whole function printed generically?
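(As an aside, that fallback is easy to check for: run the verifier before printing. A hedged sketch with upstream MLIR's Python bindings; the dump helper is made up for illustration, and verify() raising MLIRError assumes a recent bindings version:)

    # Hypothetical helper: verify first, so you know whether MLIR's
    # printer will fall back to the all-generic form.
    from mlir.ir import MLIRError, Module

    def dump(module: Module) -> None:
        try:
            module.operation.verify()  # raises MLIRError on invalid IR
        except MLIRError as err:
            # An invalid module is printed entirely in generic form.
            print(f"// verifier failed, expect generic output: {err}")
        print(module)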

oulgen commented 7 months ago

While playing with toy examples, I have seen 3 combinations.

        import triton
        import triton.language as tl

        @triton.jit
        def kernel_with_label(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            if pid > 1:  # scalar condition, lowers to cf.cond_br on an i1
                return
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

prints

module {
  tt.func public @kernel_with_label_0d1d2d3(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32} loc("/data/users/oulgen/pytorch/test/dynamo/test_triton_kernels.py":1228:0), %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32} loc("/data/users/oulgen/pytorch/test/dynamo/test_triton_kernels.py":1228:0), %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32} loc("/data/users/oulgen/pytorch/test/dynamo/test_triton_kernels.py":1228:0)) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32 loc(#loc1)
    %c1_i32 = arith.constant 1 : i32 loc(#loc2)
    %1 = arith.cmpi sgt, %0, %c1_i32 : i32 loc(#loc2)
    cf.cond_br %1, ^bb1, ^bb2 loc(#loc2)
  ^bb1:  // pred: ^bb0
    tt.return loc(#loc3)
  ^bb2:  // pred: ^bb0
    cf.br ^bb3 loc(#loc4)
  ^bb3:  // pred: ^bb2
    %c4_i32 = arith.constant 4 : i32 loc(#loc5)
    %2 = arith.muli %0, %c4_i32 : i32 loc(#loc5)
    %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> loc(#loc6)
    %4 = tt.splat %2 : (i32) -> tensor<4xi32> loc(#loc7)
    %5 = arith.addi %4, %3 : tensor<4xi32> loc(#loc7)
    %c4_i32_0 = arith.constant 4 : i32 loc(#loc8)
    %cst = arith.constant dense<4> : tensor<4xi32> loc(#loc8)
    %6 = arith.cmpi slt, %5, %cst : tensor<4xi32> loc(#loc8)
    %7 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc9)
    %8 = tt.addptr %7, %5 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32> loc(#loc9)
    %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4xf32> loc(#loc10)
    %10 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc11)
    %11 = tt.addptr %10, %5 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32> loc(#loc11)
    %12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4xf32> loc(#loc12)
    %13 = arith.addf %9, %12 : tensor<4xf32> loc(#loc13)
    %14 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc14)
    %15 = tt.addptr %14, %5 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32> loc(#loc14)
    tt.store %15, %13, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<4xf32> loc(#loc15)
    tt.return loc(#loc16)
  } loc(#loc)
} loc(#loc)

but changing the if statement to branch on tensor values, like

        import triton
        import triton.language as tl

        @triton.jit
        def kernel_with_label(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            if x > y:  # tensor condition, yields cf.cond_br on tensor<4xi1>
                return
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

gets us

"builtin.module"() ({                                                                                                                                                                                                                      [386/4400]
  "tt.func"() <{arg_attrs = [{tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}], function_type = (!tt.ptr<f32, 1>, !tt.ptr<f32, 1>, !tt.ptr<f32, 1>) -> (), sym_name = "kernel_with_label_0d1d2d3", sym_visibi
lity = "public"}> ({
  ^bb0(%arg0: !tt.ptr<f32, 1> loc("/data/users/oulgen/pytorch/test/dynamo/test_triton_kernels.py":1228:0), %arg1: !tt.ptr<f32, 1> loc("/data/users/oulgen/pytorch/test/dynamo/test_triton_kernels.py":1228:0), %arg2: !tt.ptr<f32, 1> loc("/data/user
s/oulgen/pytorch/test/dynamo/test_triton_kernels.py":1228:0)):
    %0 = "tt.get_program_id"() <{axis = 0 : i32}> : () -> i32 loc(#loc1)
    %1 = "arith.constant"() <{value = 4 : i32}> : () -> i32 loc(#loc2)
    %2 = "arith.muli"(%0, %1) : (i32, i32) -> i32 loc(#loc2)
    %3 = "tt.make_range"() <{end = 4 : i32, start = 0 : i32}> : () -> tensor<4xi32> loc(#loc3)
    %4 = "tt.splat"(%2) : (i32) -> tensor<4xi32> loc(#loc4)
    %5 = "arith.addi"(%4, %3) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> loc(#loc4)
    %6 = "arith.constant"() <{value = 4 : i32}> : () -> i32 loc(#loc5)
    %7 = "arith.constant"() <{value = dense<4> : tensor<4xi32>}> : () -> tensor<4xi32> loc(#loc5)
    %8 = "arith.cmpi"(%5, %7) <{predicate = 2 : i64}> : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> loc(#loc5)
    %9 = "tt.splat"(%arg0) : (!tt.ptr<f32, 1>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc6)
    %10 = "tt.addptr"(%9, %5) : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc6)
    %11 = "tt.load"(%10, %8) <{cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 0>}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xi1>) -> tensor<4xf32> loc(#loc7)
    %12 = "tt.splat"(%arg1) : (!tt.ptr<f32, 1>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc8)
    %13 = "tt.addptr"(%12, %5) : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc8)
    %14 = "tt.load"(%13, %8) <{cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 0>}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xi1>) -> tensor<4xf32> loc(#loc9)
    %15 = "arith.constant"() <{value = 1 : i32}> : () -> i32 loc(#loc10)
    %16 = "arith.constant"() <{value = dense<1> : tensor<4xi32>}> : () -> tensor<4xi32> loc(#loc10)
    %17 = "arith.sitofp"(%16) : (tensor<4xi32>) -> tensor<4xf32> loc(#loc10)
    %18 = "arith.cmpf"(%11, %17) <{fastmath = #arith.fastmath<none>, predicate = 2 : i64}> : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> loc(#loc10)
    "cf.cond_br"(%18)[^bb1, ^bb2] <{operandSegmentSizes = array<i32: 1, 0, 0>}> : (tensor<4xi1>) -> () loc(#loc10)
  ^bb1:  // pred: ^bb0
    "tt.return"() : () -> () loc(#loc11)
  ^bb2:  // pred: ^bb0
    "cf.br"()[^bb3] : () -> () loc(#loc12)
  ^bb3:  // pred: ^bb2
    %19 = "arith.addf"(%11, %14) <{fastmath = #arith.fastmath<none>}> : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc(#loc13)
    %20 = "tt.splat"(%arg2) : (!tt.ptr<f32, 1>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc14)
    %21 = "tt.addptr"(%20, %5) : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc14)
    "tt.store"(%21, %19, %8) <{cache = 1 : i32, evict = 1 : i32}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xf32>, tensor<4xi1>) -> () loc(#loc15)
    "tt.return"() : () -> () loc(#loc16)
  }) {noinline = false} : () -> () loc(#loc)
}) : () -> () loc(#loc)

What's worse, if I include a tl.inline_asm_elementwise in the kernel, I get both formats at the same time in the same kernel.

oulgen commented 7 months ago

There's also some inconsistency that I'm not exactly sure where it comes from: printing the exact same kernel with the same Triton rev occasionally results in different formats. I don't have a consistent repro, though.

ThomasRaoux commented 7 months ago

As pointed out by Mehdi here, it uses the generic printer because the verifier fails.

oulgen commented 7 months ago

Ah, you're right. Calling module.verify() on the second example results in

error: 'cf.cond_br' op operand #0 must be 1-bit signless integer, but got 'tensor<4xi1>'

I assume that in these cases I can discard it as bad source code?

jlebar commented 7 months ago

            if x > y: # x and y are tensors
                return

It seems like the Triton frontend should raise an error if you write this Python code, and that's the bug you're hitting here.
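If the intent was to skip the lanes where x > y, one way to express that without branching on a tensor is to fold the condition into the store mask. A sketch under that assumed intent:

    import triton
    import triton.language as tl

    @triton.jit
    def kernel_with_label(
        in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        # Elementwise masking replaces `if x > y: return`, which is not
        # expressible as control flow over tensor values.
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask & (x <= y))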

joker-eph commented 6 months ago

Can we close this bug?

ThomasRaoux commented 6 months ago

I think so