Closed — oulgen closed this issue 6 months ago
I would prefer that we always do the non-generic printing. We have lots of pretty-printers.
It seems like there's a bug where axis=0 isn't handled correctly or something and we fall back to the ugly-printer?
@jlebar I don't actually have a preference which one we choose, as long as we are consistent. Would you be willing to put a PR to set the flag?
And yes, there are inconsistencies between the two printing formats.
The only case where MLIR would print generic without you asking for it is if the verifier fails. In this case the entire IR is printing generically though: are you seeing this with one op in the middle of the function or the whole function is printed generically?
While playing with toy examples, I have seen 3 combinations.
@triton.jit
def kernel_with_label(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
if pid > 1:
return
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
prints
module {
tt.func public @kernel_with_label_0d1d2d3(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32} loc("/data/users/oulgen/pytorch/test/dynamo/test_triton_kernels.py":1228:0), %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32} loc("/data/users/oul
gen/pytorch/test/dynamo/test_triton_kernels.py":1228:0), %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32} loc("/data/users/oulgen/pytorch/test/dynamo/test_triton_kernels.py":1228:0)) attributes {noinline = false} {
%0 = tt.get_program_id x : i32 loc(#loc1)
%c1_i32 = arith.constant 1 : i32 loc(#loc2)
%1 = arith.cmpi sgt, %0, %c1_i32 : i32 loc(#loc2)
cf.cond_br %1, ^bb1, ^bb2 loc(#loc2)
^bb1: // pred: ^bb0
tt.return loc(#loc3)
^bb2: // pred: ^bb0
cf.br ^bb3 loc(#loc4)
^bb3: // pred: ^bb2
%c4_i32 = arith.constant 4 : i32 loc(#loc5)
%2 = arith.muli %0, %c4_i32 : i32 loc(#loc5)
%3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> loc(#loc6)
%4 = tt.splat %2 : (i32) -> tensor<4xi32> loc(#loc7)
%5 = arith.addi %4, %3 : tensor<4xi32> loc(#loc7)
%c4_i32_0 = arith.constant 4 : i32 loc(#loc8)
%cst = arith.constant dense<4> : tensor<4xi32> loc(#loc8)
%6 = arith.cmpi slt, %5, %cst : tensor<4xi32> loc(#loc8)
%7 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc9)
%8 = tt.addptr %7, %5 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32> loc(#loc9)
%9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4xf32> loc(#loc10)
%10 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc11)
%11 = tt.addptr %10, %5 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32> loc(#loc11)
%12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4xf32> loc(#loc12)
%13 = arith.addf %9, %12 : tensor<4xf32> loc(#loc13)
%14 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc14)
%15 = tt.addptr %14, %5 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32> loc(#loc14)
tt.store %15, %13, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<4xf32> loc(#loc15)
tt.return loc(#loc16)
} loc(#loc)
} loc(#loc)
but changing the if statement to be over tensor values like
@triton.jit
def kernel_with_label(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
if x > y:
return
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
gets us
"builtin.module"() ({
"tt.func"() <{arg_attrs = [{tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}], function_type = (!tt.ptr<f32, 1>, !tt.ptr<f32, 1>, !tt.ptr<f32, 1>) -> (), sym_name = "kernel_with_label_0d1d2d3", sym_visibi
lity = "public"}> ({
^bb0(%arg0: !tt.ptr<f32, 1> loc("/data/users/oulgen/pytorch/test/dynamo/test_triton_kernels.py":1228:0), %arg1: !tt.ptr<f32, 1> loc("/data/users/oulgen/pytorch/test/dynamo/test_triton_kernels.py":1228:0), %arg2: !tt.ptr<f32, 1> loc("/data/user
s/oulgen/pytorch/test/dynamo/test_triton_kernels.py":1228:0)):
%0 = "tt.get_program_id"() <{axis = 0 : i32}> : () -> i32 loc(#loc1)
%1 = "arith.constant"() <{value = 4 : i32}> : () -> i32 loc(#loc2)
%2 = "arith.muli"(%0, %1) : (i32, i32) -> i32 loc(#loc2)
%3 = "tt.make_range"() <{end = 4 : i32, start = 0 : i32}> : () -> tensor<4xi32> loc(#loc3)
%4 = "tt.splat"(%2) : (i32) -> tensor<4xi32> loc(#loc4)
%5 = "arith.addi"(%4, %3) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> loc(#loc4)
%6 = "arith.constant"() <{value = 4 : i32}> : () -> i32 loc(#loc5)
%7 = "arith.constant"() <{value = dense<4> : tensor<4xi32>}> : () -> tensor<4xi32> loc(#loc5)
%8 = "arith.cmpi"(%5, %7) <{predicate = 2 : i64}> : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> loc(#loc5)
%9 = "tt.splat"(%arg0) : (!tt.ptr<f32, 1>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc6)
%10 = "tt.addptr"(%9, %5) : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc6)
%11 = "tt.load"(%10, %8) <{cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 0>}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xi1>) -> tensor<4xf32> loc(#loc7)
%12 = "tt.splat"(%arg1) : (!tt.ptr<f32, 1>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc8)
%13 = "tt.addptr"(%12, %5) : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc8)
%14 = "tt.load"(%13, %8) <{cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 0>}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xi1>) -> tensor<4xf32> loc(#loc9)
%15 = "arith.constant"() <{value = 1 : i32}> : () -> i32 loc(#loc10)
%16 = "arith.constant"() <{value = dense<1> : tensor<4xi32>}> : () -> tensor<4xi32> loc(#loc10)
%17 = "arith.sitofp"(%16) : (tensor<4xi32>) -> tensor<4xf32> loc(#loc10)
%18 = "arith.cmpf"(%11, %17) <{fastmath = #arith.fastmath<none>, predicate = 2 : i64}> : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> loc(#loc10)
"cf.cond_br"(%18)[^bb1, ^bb2] <{operandSegmentSizes = array<i32: 1, 0, 0>}> : (tensor<4xi1>) -> () loc(#loc10)
^bb1: // pred: ^bb0
"tt.return"() : () -> () loc(#loc11)
^bb2: // pred: ^bb0
"cf.br"()[^bb3] : () -> () loc(#loc12)
^bb3: // pred: ^bb2
%19 = "arith.addf"(%11, %14) <{fastmath = #arith.fastmath<none>}> : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc(#loc13)
%20 = "tt.splat"(%arg2) : (!tt.ptr<f32, 1>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc14)
%21 = "tt.addptr"(%20, %5) : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>) -> tensor<4x!tt.ptr<f32, 1>> loc(#loc14)
"tt.store"(%21, %19, %8) <{cache = 1 : i32, evict = 1 : i32}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xf32>, tensor<4xi1>) -> () loc(#loc15)
"tt.return"() : () -> () loc(#loc16)
}) {noinline = false} : () -> () loc(#loc)
}) : () -> () loc(#loc)
What's worse is that if I include a tl.inline_asm_elementwise
in the kernel, I get both formats at the same time in the same kernel.
There's also some inconsistency that I am not exactly certain where it comes from but printing the exact same kernel with same triton rev results in different formats occasionally. Although, I don't have a consistent repro.
As pointed out by Mehdi here, it uses the generic printer because the verifier fails.
Ah, you're right. Calling module.verify()
on the second example results in
error: 'cf.cond_br' op operand #0 must be 1-bit signless integer, but got 'tensor<4xi1>'
I assume in these cases, I can discard it as bad source code?
if x > y: # x and y are tensors
return
It seems like the Triton frontend should raise an error if you write this Python code, and that's the bug you're hitting here.
Can we close this bug?
I think so
I have noticed that when dumping the TTIR from triton, the format is inconsistent. OPs are sometimes printed in their pretty format and sometimes in their generic format, e.g.
versus
The printed form changes if you slightly modify the source. I propose that we set generic printing flag so that the output is always consistent?