microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License
134 stars 26 forks source link

[WIP] add a canonicalizer before triton-to-linalg #62

Closed yuanfz98 closed 7 months ago

yuanfz98 commented 7 months ago

This PR is a work in progress and aims to add a canonicalizer that runs before triton-to-linalg. It decouples the mutation of TTIR from triton-to-linalg. RemsiCanonicalizer postpones expand_dims{axis=1} until after the remainder op, which provides valid input for PtrAnalysis, as the latter asserts rank == 1:

void PtrAnalysis::visitOperandRem(
    arith::RemSIOp remOp, PtrState &state, const Location loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, PtrState> &knownPtrs) {
  assert(state.isEmpty());
  visitOperand(remOp.getLhs(), state, loc, rewriter, knownPtrs);
  assert(state.getRank() == 1 && !state.modulos.back().has_value() &&
         "No support for multiple modulos within an expression");

After RemsiCanonicalizer, %11 = arith.remsi %5, %cst_11 : tensor<8x1xi32> becomes:

%11 = arith.remsi %5_clone, %cst_clone : tensor<8xi32>
%12 = tt.expand_dims %11 {axis = 1 : i32} : (tensor<8xi32>) -> tensor<8x1xi32>

Thus it is no longer a rank 2 tensor for PtrAnalysis.

yuanfz98 commented 7 months ago

Attaching an example of the TTIR & AST that we are dealing with:

def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    # Generated reduction kernel: for each x position it accumulates the
    # element-wise product of values loaded from in_ptr0 and in_ptr1 over the
    # r range, in chunks of RBLOCK, then stores the per-row sum to out_ptr0.
    # The xnumel/rnumel arguments are overwritten with constants right away.
    xnumel = 24576
    rnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    # [:, None] adds a trailing unit axis, so xindex is 2-D *before* the
    # modulo/division below are applied.
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    # NOTE(review): xmask is computed but not referenced again in this kernel.
    xmask = xindex < xnumel
    # [None, :] adds a leading unit axis to the r range.
    rbase = tl.arange(0, RBLOCK)[None, :]
    # Split xindex by 384 into remainder (x0) and quotient (x1); both feed the
    # load address computation inside the loop.
    x0 = xindex % 384
    x1 = (xindex // 384)
    # Running accumulator, zero-initialised as float32 of shape [XBLOCK, RBLOCK].
    _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        # Guards the r positions of this chunk against running past rnumel.
        rmask = rindex < rnumel
        r2 = rindex
        # Both loads use the same offset expression: x0 + 384*r2 + 98304*x1.
        # Masked-out lanes read as 0 (other=0).
        tmp0 = tl.load(in_ptr0 + (x0 + (384*r2) + (98304*x1)), rmask, other=0).to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (x0 + (384*r2) + (98304*x1)), rmask, other=0)
        # NOTE(review): tmp0 was already converted to float32 on the line that
        # loads it, so this second conversion looks redundant.
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp1 * tmp2
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp6 = _tmp5 + tmp4
        # Where rmask is false, keep the previous accumulator value.
        _tmp5 = tl.where(rmask, tmp6, _tmp5)
    # Reduce along axis 1, then restore a trailing unit axis so the result
    # lines up with xindex (x3) for the store.
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    # Stored with mask None.
    tl.store(out_ptr0 + (x3), tmp5, None)
module {
  tt.func public @triton__0d1d2d3de4de(%arg0: !tt.ptr<bf16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<128x8xbf16>
    %c8_i32 = arith.constant 8 : i32
    %c256_i32 = arith.constant 256 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst_0 = arith.constant dense<98304> : tensor<128x1xi32>
    %cst_1 = arith.constant dense<384> : tensor<1x8xi32>
    %cst_2 = arith.constant dense<256> : tensor<1x8xi32>
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x8xf32>
    %cst_4 = arith.constant dense<384> : tensor<128x1xi32>
    %c128_i32 = arith.constant 128 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c128_i32 : i32
    %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
    %3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
    %4 = tt.splat %1 : (i32) -> tensor<128x1xi32>
    %5 = arith.addi %4, %3 : tensor<128x1xi32>
    %6 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32>
    %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<8xi32>) -> tensor<1x8xi32>
    %8 = arith.remsi %5, %cst_4 : tensor<128x1xi32>
    %9 = arith.divsi %5, %cst_4 : tensor<128x1xi32>
    %10 = tt.broadcast %8 : (tensor<128x1xi32>) -> tensor<128x8xi32>
    %11 = arith.muli %9, %cst_0 : tensor<128x1xi32>
    %12 = tt.broadcast %11 : (tensor<128x1xi32>) -> tensor<128x8xi32>
    %13 = tt.splat %arg0 : (!tt.ptr<bf16, 1>) -> tensor<128x8x!tt.ptr<bf16, 1>>
    %14 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<128x8x!tt.ptr<f32, 1>>
    %15 = scf.for %arg5 = %c0_i32 to %c256_i32 step %c8_i32 iter_args(%arg6 = %cst_3) -> (tensor<128x8xf32>)  : i32 {
      %20 = tt.splat %arg5 : (i32) -> tensor<1x8xi32>
      %21 = arith.addi %20, %7 : tensor<1x8xi32>
      %22 = arith.cmpi slt, %21, %cst_2 : tensor<1x8xi32>
      %23 = arith.muli %21, %cst_1 : tensor<1x8xi32>
      %24 = tt.broadcast %23 : (tensor<1x8xi32>) -> tensor<128x8xi32>
      %25 = arith.addi %10, %24 : tensor<128x8xi32>
      %26 = arith.addi %25, %12 : tensor<128x8xi32>
      %27 = tt.addptr %13, %26 : tensor<128x8x!tt.ptr<bf16, 1>>, tensor<128x8xi32>
      %28 = tt.broadcast %22 : (tensor<1x8xi1>) -> tensor<128x8xi1>
      %29 = tt.load %27, %28, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x8xbf16>
      %30 = arith.extf %29 : tensor<128x8xbf16> to tensor<128x8xf32>
      %31 = tt.addptr %14, %26 : tensor<128x8x!tt.ptr<f32, 1>>, tensor<128x8xi32>
      %32 = tt.load %31, %28, %cst_3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x8xf32>
      %33 = arith.mulf %30, %32 : tensor<128x8xf32>
      %34 = arith.addf %arg6, %33 : tensor<128x8xf32>
      %35 = arith.select %28, %34, %arg6 : tensor<128x8xi1>, tensor<128x8xf32>
      scf.yield %35 : tensor<128x8xf32>
    }
    %16 = "tt.reduce"(%15) <{axis = 1 : i32}> ({
    ^bb0(%arg5: f32, %arg6: f32):
      %20 = arith.addf %arg5, %arg6 : f32
      tt.reduce.return %20 : f32
    }) : (tensor<128x8xf32>) -> tensor<128xf32>
    %17 = tt.expand_dims %16 {axis = 1 : i32} : (tensor<128xf32>) -> tensor<128x1xf32>
    %18 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<128x1x!tt.ptr<f32, 1>>
    %19 = tt.addptr %18, %5 : tensor<128x1x!tt.ptr<f32, 1>>, tensor<128x1xi32>
    tt.store %19, %17 {cache = 1 : i32, evict = 1 : i32} : tensor<128x1xf32>
    tt.return
  }
}
yuanfz98 commented 7 months ago

We may find that the indexing patterns can be represented like:

def f(programId, x, factor0, factor1):
    """Print the index computed for each of the x lanes of one program.

    Each lane e in range(x) has the global position programId * x + e. That
    position is split by factor0 into a quotient and a remainder; the quotient
    is scaled by factor1 and the remainder is added back.
    """
    positions = (programId * x + lane for lane in range(0, x))
    print([(pos // factor0) * factor1 + pos % factor0 for pos in positions])

With f(0, 256, 7, 72) we get:

[0, 1, 2, 3, 4, 5, 6, 72, 73, 74, 75, 76...]

While with f(1, 256, 7, 72) we get:

[2596, 2597, 2598, 2664, 2665, 2666, 2667, 2668, 2669, 2670, 2736, 2737, 2738...]

whose pattern isn't obvious to determine.

In fact, if we set programId to 0, the function simplifies to:

def f(programId, x, factor0, factor1):
    """Print the per-lane index for the programId == 0 case.

    programId is kept in the signature but is not used: each lane e in
    range(x) is split directly by factor0, the quotient is scaled by factor1
    and the remainder is added back.
    """
    split_lanes = (divmod(lane, factor0) for lane in range(0, x))
    indices = [quotient * factor1 + remainder for quotient, remainder in split_lanes]
    print(indices)

I think we should make a compromise here and create a tt.assert. If you have a better solution, please don't hesitate to share it.

nhat-nguyen commented 7 months ago

Thank you for the contribution. I have a small patch to support the modulo pattern that torch inductor generates together with some other small fixes. I will take a look at your proposal for canonicalizing the division operator. Thanks!

nhat-nguyen commented 7 months ago

I really like the idea of canonicalizing before TritonToLinalg so that PtrAnalysis only needs to take care of one pattern. I think for now though, to support the case that you're interested in, the code itself is quite short so we can still have it in PtrAnalysis to keep the complexity low. If this keeps growing, let's revisit the idea. I have a working branch over at nhat/modulo if you're interested in checking out early. I just need to do a bit of cleanup before publishing the PR. Here's the code to support your case:

if (state.getRank() == 1) {
    // Apply the modulo before expanding shape, the common pattern is
    // offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    // a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] *
    // stride_ak)

    assert(!state.modulos.back().has_value() &&
           "No support for multiple modulos within an expression");

    state.modulos.back() = ModuloState{rhsState.scalar};

  } else if (state.getRank() == 2) {
    // torch inductor expands the tensor shape before applying the modulo
    // operator.
    //
    // We only support either:
    // - (tl.arange(0, end)[:, None] % mod), or
    // - (tl.arange(0, end)[None, :] % mod)
    //
    // In both cases, we apply the modulo to the non-singleton dimension.
    auto shape = cast<TensorType>(remOp.getResult().getType()).getShape();
    if (shape[0] == 1) {
      state.modulos[1] = ModuloState{rhsState.scalar};
    } else if (shape[1] == 1) {
      state.modulos[0] = ModuloState{rhsState.scalar};
    } else {
      assert(false && "Do not support taking modulo on a 2D tensor with no "
                      "singleton dimension");
    }
  } else {
    assert(false && "Unsupported modulo pattern");
  }

Also, I would very much appreciate it if you could add me as a reviewer on future PRs so I can take a look at them in a timely manner. I don't get notifications otherwise. Thanks again!