Open manbearian opened 9 months ago
@nhat-nguyen Would you mind taking a look at this issue?
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr2, xnumel, rnumel):
    xnumel = 16384
    XBLOCK: tl.constexpr = 1
    rnumel = 384
    RBLOCK: tl.constexpr = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (384*x0)), rmask, other=0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0)
    tmp8 = tl.load(in_ptr2 + (r1 + (384*x0)), rmask, other=0)
    tmp9 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp11 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
    tmp18 = tl.load(in_out_ptr0 + (r1 + (384*x0)), rmask, other=0)
    tmp27 = tl.load(in_ptr5 + (r1 + (384*x0)), rmask)
    tmp32 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp1 * tmp2
    tmp4 = tl.broadcast_to(tmp3, [RBLOCK])
    tmp6 = tl.where(rmask, tmp4, 0)
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))
    tmp10 = tmp8 - tmp9
    tmp12 = tmp10 * tmp11
    tmp13 = tmp3 * tmp12
    tmp14 = tl.broadcast_to(tmp13, [RBLOCK])
    tmp16 = tl.where(rmask, tmp14, 0)
    tmp17 = triton_helpers.promote_to_tensor(tl.sum(tmp16, 0))
    tmp19 = 384.0
    tmp20 = tmp11 / tmp19
    tmp21 = tmp3 * tmp19
    tmp22 = tmp21 - tmp7
    tmp23 = tmp12 * tmp17
    tmp24 = tmp22 - tmp23
    tmp25 = tmp20 * tmp24
    tmp26 = tmp18 + tmp25
    tmp28 = tmp27.to(tl.float32)
    tmp29 = 1.25
    tmp30 = tmp28 * tmp29
    tmp31 = tmp26 * tmp30
    tmp33 = tl.where(tmp32 < 0, tmp32 + 65, tmp32)
    tmp34 = tl.full([1], -1, tl.int64)
    tmp35 = tmp32 == tmp34
    tmp36 = 0.0
    tmp37 = tl.where(tmp35, tmp36, tmp31)
    tl.store(in_out_ptr0 + (r1 + (384*x0)), tmp31, rmask)
    tl.atomic_add(out_ptr2 + (tl.broadcast_to(r1 + (384*tmp33), [RBLOCK])), tmp37, rmask)
module {
tt.func public @triton__0d1d2d3d4d5d6d7d8d9de10de(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<bf16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<i1, 1> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64, 1> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
%c384_i32 = arith.constant 384 : i32
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant dense<384> : tensor<1xi64>
%cst_1 = arith.constant dense<65> : tensor<1xi64>
%cst_2 = arith.constant dense<0> : tensor<1xi64>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<512xf32>
%cst_4 = arith.constant dense<0.000000e+00> : tensor<512xbf16>
%cst_5 = arith.constant dense<-1> : tensor<1xi64>
%cst_6 = arith.constant dense<1.250000e+00> : tensor<512xf32>
%cst_7 = arith.constant dense<3.840000e+02> : tensor<512xf32>
%cst_8 = arith.constant dense<3.840000e+02> : tensor<1xf32>
%cst_9 = arith.constant dense<384> : tensor<512xi32>
%0 = tt.get_program_id x : i32
%1 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
%2 = arith.cmpi slt, %1, %cst_9 : tensor<512xi32>
%3 = arith.muli %0, %c384_i32 : i32
%4 = tt.splat %3 : (i32) -> tensor<512xi32>
%5 = arith.addi %1, %4 : tensor<512xi32>
%6 = tt.splat %arg1 : (!tt.ptr<bf16, 1>) -> tensor<512x!tt.ptr<bf16, 1>>
%7 = tt.addptr %6, %5 : tensor<512x!tt.ptr<bf16, 1>>, tensor<512xi32>
%8 = tt.load %7, %2, %cst_4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xbf16>
%9 = arith.extf %8 : tensor<512xbf16> to tensor<512xf32>
%10 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<512x!tt.ptr<f32, 1>>
%11 = tt.addptr %10, %1 : tensor<512x!tt.ptr<f32, 1>>, tensor<512xi32>
%12 = tt.load %11, %2, %cst_3 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<512xf32>
%13 = tt.splat %arg3 : (!tt.ptr<f32, 1>) -> tensor<512x!tt.ptr<f32, 1>>
%14 = tt.addptr %13, %5 : tensor<512x!tt.ptr<f32, 1>>, tensor<512xi32>
%15 = tt.load %14, %2, %cst_3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xf32>
%16 = tt.addptr %arg4, %0 : !tt.ptr<f32, 1>, i32
%17 = tt.splat %16 : (!tt.ptr<f32, 1>) -> tensor<1x!tt.ptr<f32, 1>>
%18 = tt.load %17 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1xf32>
%19 = tt.addptr %arg5, %0 : !tt.ptr<f32, 1>, i32
%20 = tt.splat %19 : (!tt.ptr<f32, 1>) -> tensor<1x!tt.ptr<f32, 1>>
%21 = tt.load %20 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1xf32>
%22 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<512x!tt.ptr<f32, 1>>
%23 = tt.addptr %22, %5 : tensor<512x!tt.ptr<f32, 1>>, tensor<512xi32>
%24 = tt.load %23, %2, %cst_3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xf32>
%25 = tt.splat %arg6 : (!tt.ptr<i1, 1>) -> tensor<512x!tt.ptr<i1, 1>>
%26 = tt.addptr %25, %5 : tensor<512x!tt.ptr<i1, 1>>, tensor<512xi32>
%27 = tt.bitcast %26 : tensor<512x!tt.ptr<i1, 1>> -> tensor<512x!tt.ptr<i8, 1>>
%28 = tt.load %27, %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xi8>
%29 = tt.addptr %arg7, %0 : !tt.ptr<i64, 1>, i32
%30 = tt.splat %29 : (!tt.ptr<i64, 1>) -> tensor<1x!tt.ptr<i64, 1>>
%31 = tt.load %30 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1xi64>
%32 = arith.mulf %9, %12 : tensor<512xf32>
%33 = arith.select %2, %32, %cst_3 : tensor<512xi1>, tensor<512xf32>
%34 = "tt.reduce"(%33) <{axis = 0 : i32}> ({
^bb0(%arg11: f32, %arg12: f32):
%70 = arith.addf %arg11, %arg12 : f32
tt.reduce.return %70 : f32
}) : (tensor<512xf32>) -> f32
%35 = arith.addf %34, %cst : f32
%36 = tt.broadcast %18 : (tensor<1xf32>) -> tensor<512xf32>
%37 = arith.subf %15, %36 : tensor<512xf32>
%38 = tt.broadcast %21 : (tensor<1xf32>) -> tensor<512xf32>
%39 = arith.mulf %37, %38 : tensor<512xf32>
%40 = arith.mulf %32, %39 : tensor<512xf32>
%41 = arith.select %2, %40, %cst_3 : tensor<512xi1>, tensor<512xf32>
%42 = "tt.reduce"(%41) <{axis = 0 : i32}> ({
^bb0(%arg11: f32, %arg12: f32):
%70 = arith.addf %arg11, %arg12 : f32
tt.reduce.return %70 : f32
}) : (tensor<512xf32>) -> f32
%43 = arith.addf %42, %cst : f32
%44 = arith.divf %21, %cst_8 : tensor<1xf32>
%45 = arith.mulf %32, %cst_7 : tensor<512xf32>
%46 = tt.splat %35 : (f32) -> tensor<512xf32>
%47 = arith.subf %45, %46 : tensor<512xf32>
%48 = tt.splat %43 : (f32) -> tensor<512xf32>
%49 = arith.mulf %39, %48 : tensor<512xf32>
%50 = arith.subf %47, %49 : tensor<512xf32>
%51 = tt.broadcast %44 : (tensor<1xf32>) -> tensor<512xf32>
%52 = arith.mulf %51, %50 : tensor<512xf32>
%53 = arith.addf %24, %52 : tensor<512xf32>
%54 = arith.sitofp %28 : tensor<512xi8> to tensor<512xf32>
%55 = arith.mulf %54, %cst_6 : tensor<512xf32>
%56 = arith.mulf %53, %55 : tensor<512xf32>
%57 = arith.cmpi slt, %31, %cst_2 : tensor<1xi64>
%58 = arith.addi %31, %cst_1 : tensor<1xi64>
%59 = arith.select %57, %58, %31 : tensor<1xi1>, tensor<1xi64>
%60 = arith.cmpi eq, %31, %cst_5 : tensor<1xi64>
%61 = tt.broadcast %60 : (tensor<1xi1>) -> tensor<512xi1>
%62 = arith.select %61, %cst_3, %56 : tensor<512xi1>, tensor<512xf32>
tt.store %23, %56, %2 {cache = 1 : i32, evict = 1 : i32} : tensor<512xf32>
%63 = arith.muli %59, %cst_0 : tensor<1xi64>
%64 = tt.broadcast %63 : (tensor<1xi64>) -> tensor<512xi64>
%65 = arith.extsi %1 : tensor<512xi32> to tensor<512xi64>
%66 = arith.addi %65, %64 : tensor<512xi64>
%67 = tt.splat %arg8 : (!tt.ptr<f32, 1>) -> tensor<512x!tt.ptr<f32, 1>>
%68 = tt.addptr %67, %66 : tensor<512x!tt.ptr<f32, 1>>, tensor<512xi64>
%69 = "tt.atomic_rmw"(%68, %62, %2) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 4 : i32}> : (tensor<512x!tt.ptr<f32, 1>>, tensor<512xf32>, tensor<512xi1>) -> tensor<512xf32>
tt.return
}
}
We should support:
@yuanfz98 Sorry for the late response.
The case you're hitting is basically the limit of our PtrAnalysis pass. We attempt to convert triton pointer loads on a best-effort basis: pointers that point to contiguous memory, where that can be determined at compile time, are converted to a single memref load. The common cases almost always involve having tl.arange as the base case of the recursion; that's where we set up the initial offsets, strides, and sizes of the memref to load from.
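To make the distinction concrete, here is a minimal sketch (hypothetical kernels, not taken from the repro) of a load pattern PtrAnalysis can convert versus one it cannot:

# Hypothetical minimal kernels, for illustration only.
import triton
import triton.language as tl

# Analyzable case: offsets come directly from tl.arange, so the pointers
# are known to be contiguous at compile time and the load can be turned
# into a single memref load.
@triton.jit
def contiguous_load(in_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(in_ptr + offs)
    tl.store(out_ptr + offs, x)

# Non-analyzable case: the offsets are themselves loaded from memory
# (a gather; idx_ptr is assumed to hold integer indices), so contiguity
# cannot be determined at compile time.
@triton.jit
def gather_load(in_ptr, idx_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = tl.load(idx_ptr + offs)
    x = tl.load(in_ptr + idx)
    tl.store(out_ptr + offs, x)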
For this particular case, the offsets used by tl.atomic_add(out_ptr2 + (tl.broadcast_to(r1 + (384*tmp33), [RBLOCK])), tmp37, rmask) are obtained from another load: tmp32 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last'). It is impossible to know at compile time whether tmp32 only contains offsets that are contiguous, so we can't really support this case with the current approach. If we somehow had a user annotation that guarantees tmp32 is contiguous, then it would be possible.
Another feature we have in mind, though it isn't the priority, is a fallback mode; this mode would support the cases where PtrAnalysis fails, at the expense of performance. It would probably involve loading each element individually and storing it into a local buffer. It will take me some time to give more thought to how we can make this possible (I'll be away all of December). In the meantime, if you have any ideas or further thoughts on this, we would love to hear them in the Discussions tab.
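As a rough illustration of that idea (plain Python/NumPy as a conceptual model, not the code the pass would actually generate), the fallback would do something like this for an unanalyzable gather:

import numpy as np

def fallback_gather(src: np.ndarray, offsets: np.ndarray) -> np.ndarray:
    # Conceptual model only: when contiguity can't be proven, copy each
    # element one at a time into a local buffer, then operate on the buffer.
    buf = np.empty(len(offsets), dtype=src.dtype)
    for i, off in enumerate(offsets):
        buf[i] = src[off]
    return buf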
Now, for the select op, technically we can support it provided that neither the true value nor the false value ends up being produced by a load. I think the code would look something like:
visitSelect() {
  PtrState trueValState;
  visit(trueVal, trueValState);

  PtrState falseValState;
  visit(falseVal, falseValState);

  PtrState condState;
  visit(cond, condState);

  // Now we need to generate code that selects the correct strides and
  // offsets from trueValState / falseValState depending on the condition.
}
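For reference, here is a reduced, hypothetical kernel showing the kind of select that visitSelect could handle: both branches of the tl.where are arange-based rather than produced by a load, and only the choice between them depends on a runtime condition (uses triton / tl as imported in the earlier sketch).

@triton.jit
def select_in_offsets(in_ptr, out_ptr, flag, BLOCK: tl.constexpr):
    base = tl.arange(0, BLOCK)
    # Both candidate offset expressions are analyzable; the select only
    # picks which contiguous block to read.
    offs = tl.where(flag != 0, base, base + BLOCK)
    x = tl.load(in_ptr + offs)
    tl.store(out_ptr + base, x)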
Created from #7.
I don't know what the original Triton code that produced this looked like, but there is a select in the address expression.
It's not in the error message, but when I looked at the provided test case I also saw an atomic operation, which is likewise not currently supported.
repro.zip
triton-shared-opt -triton-to-linalg 43.mlir
Error output: