triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

[Bug] Why doesn't the ptr type of tt.atomic_rmw allow TT_TensorPtr? #4672

Open tfruan2000 opened 1 week ago

tfruan2000 commented 1 week ago

Hi, guys~

I am a bit confused about the definition of tt.atomic_rmw in TritonOps.td.

Currently, the type verification for the ptr and val operands is done using getPointerTypeSameShape.

def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
  SameOperandsAndResultShape,
  SameOperandsAndResultEncoding,
  MemoryEffects<[MemRead<GlobalMemory>]>,
  MemoryEffects<[MemWrite<GlobalMemory>]>,
  TypesMatchWith<"ptr type matches value type", "val", "ptr", 
                 "getPointerTypeSameShape($_self)">, // here, used `getPointerTypeSameShape`
  ...

However, this op behaves much like tt.load and tt.store, yet those two ops use getPointeeType for verification instead:

def TT_StoreOp : TT_Op<"store", [
  SameLoadStoreOperandsShape,
  SameLoadStoreOperandsEncoding,
  MemoryEffects<[MemWrite<GlobalMemory>]>,
  TypesMatchWith<"value type matches ptr type", "ptr", "value",
                 "getPointeeType($_self)">,  // here, used `getPointeeType`

As a result, the following IR using tt.store is valid:

  %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16>
  %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : !tt.ptr<tensor<128x32xf16>>
  tt.store %0, %cst : !tt.ptr<tensor<128x32xf16>>

while the corresponding IR for tt.atomic_rmw is rejected:

error: 'tt.atomic_rmw' op failed to verify that ptr type matches value type
  %1 = tt.atomic_rmw fadd, relaxed, gpu, %0, %cst, %mask : (!tt.ptr<tensor<128x32xf16>>, tensor<128x32xf16>, tensor<128x32xi1>) -> tensor<128x32xf16>
       ^
tmp.mlir:12:8: note: see current operation: %10 = "tt.atomic_rmw"(%9, %8, %arg1) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 1 : i32}> : (!tt.ptr<tensor<128x32xf16>>, tensor<128x32xf16>, tensor<128x32xi1>) -> tensor<128x32xf16>
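
For reference, the form that does pass the current tt.atomic_rmw verifier uses a tensor of pointers rather than a block pointer (a sketch; the construction of %ptrs via tt.splat/tt.addptr is elided):

```mlir
// ptr is tensor<128x32x!tt.ptr<f16>>, which matches
// getPointerTypeSameShape(val) for val : tensor<128x32xf16>:
%r = tt.atomic_rmw fadd, relaxed, gpu, %ptrs, %cst, %mask :
       (tensor<128x32x!tt.ptr<f16>>, tensor<128x32xf16>, tensor<128x32xi1>) -> tensor<128x32xf16>
```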

Also, the ptr operand of tt.atomic_rmw is declared as TT_PtrLike only, which does not admit TT_TensorPtr (!tt.ptr<tensor<...>>).

Could someone explain why tt.atomic_rmw is not defined the same way as tt.load? For example:

def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
  SameLoadStoreOperandsAndResultShape, // before: SameOperandsAndResultShape
  SameLoadStoreOperandsAndResultEncoding, // before:  SameOperandsAndResultEncoding,
  MemoryEffects<[MemRead<GlobalMemory>]>,
  MemoryEffects<[MemWrite<GlobalMemory>]>,
  TypesMatchWith<"ptr type matches value type", "ptr", "val",
                 "getPointeeType($_self)">, // before:  "val", "ptr",  "getPointerTypeSameShape($_self)"
  ...
]> {
    ...
    let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op,
                      AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, // before: TT_PtrLike:$ptr

thx~

tfruan2000 commented 1 week ago

I can make the following IR legal by changing the .td file as described above:

%1 = tt.atomic_rmw fadd, relaxed, gpu, %0, %cst, %mask : (!tt.ptr<tensor<128x32xf16>>, tensor<128x32xf16>, tensor<128x32xi1>) -> tensor<128x32xf16>

but I’m not sure if this change is correct

tfruan2000 commented 1 week ago

I noticed that the PR "Support block pointer semantics" #1392 was the first to add the TT_TensorPtr type to tt.load and tt.store. Could it be that the pointer type for atomic ops was simply overlooked, or does it not make sense for atomic ops to support TT_TensorPtr the way load and store do?