triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License
12.69k stars 1.53k forks source link

Tree-like joins with repeated layout conversions are quite slow #4030

Open bxyb opened 3 months ago

bxyb commented 3 months ago

I'm trying to make use of all four results of `tl.randint4x` (four int32 random numbers) and generate four 8-bit random numbers from each 32-bit random number (similar to the trick FlashAttention uses for dropout).

@triton.jit
def rand8bit(seed, offset):
    """Turn one ``tl.randint4x`` call into 16 byte-sized random values.

    ``tl.randint4x`` yields four 32-bit random tensors. Inline PTX splits
    each of them into its four bytes and widens every byte to ``uint16``,
    giving 16 values in [0, 255] per element of ``offset``. The 16 tensors
    are then stacked with nested ``tl.join`` calls, which append four
    trailing size-2 dimensions to the input shape.

    NOTE(review): each ``tl.join`` appends its new dimension last, so the
    outermost join varies fastest. Presumably a row-major reshape of the
    result therefore interleaves the values (aa, ca, ba, da, ac, ...)
    rather than ordering them aa, ab, ac, ... -- verify if order matters.
    """
    # Four random tensors shaped like `offset`. The trailing `1` is
    # randint4x's third argument (presumably the number of rounds --
    # confirm against the tl.randint4x docs).
    a, b, c, d = tl.randint4x(seed, offset, 1)
    # Inline PTX operand map: $0..$15 are the 16 uint16 outputs, in the
    # order aa..dd below; $16..$19 are the inputs a, b, c, d.
    # `mov.b32 {tmp0, tmp1, tmp2, tmp3}, $N` splits one 32-bit input into
    # four bytes, and each `cvt.u16.u8` zero-extends one byte to 16 bits.
    # The comment inside the asm mentions only `a`, but all four inputs are
    # unpacked the same way.
    aa, ab, ac, ad, \
    ba, bb, bc, bd, \
    ca, cb, cc, cd, \
    da, db, dc, dd \
        = tl.inline_asm_elementwise(
        asm="""
        {
            // Unpack `a` into `ai`.
            .reg .b8 tmp<4>;
            mov.b32 {tmp0, tmp1, tmp2, tmp3}, $16;
            cvt.u16.u8 $0, tmp0;
            cvt.u16.u8 $1, tmp1;
            cvt.u16.u8 $2, tmp2;
            cvt.u16.u8 $3, tmp3;
            mov.b32 {tmp0, tmp1, tmp2, tmp3}, $17;
            cvt.u16.u8 $4, tmp0;
            cvt.u16.u8 $5, tmp1;
            cvt.u16.u8 $6, tmp2;
            cvt.u16.u8 $7, tmp3;
            mov.b32 {tmp0, tmp1, tmp2, tmp3}, $18;
            cvt.u16.u8 $8, tmp0;
            cvt.u16.u8 $9, tmp1;
            cvt.u16.u8 $10, tmp2;
            cvt.u16.u8 $11, tmp3;
            mov.b32 {tmp0, tmp1, tmp2, tmp3}, $19;
            cvt.u16.u8 $12, tmp0;
            cvt.u16.u8 $13, tmp1;
            cvt.u16.u8 $14, tmp2;
            cvt.u16.u8 $15, tmp3;
        }
        """,
        # 16 register outputs ("=r") followed by 4 register inputs ("r").
        constraints=(
            "=r,=r,=r,=r," 
            "=r,=r,=r,=r," 
            "=r,=r,=r,=r," 
            "=r,=r,=r,=r," 
            "r,r,r,r"),
        args=[a, b, c, d],
        # One dtype per output operand $0..$15.
        dtype=[
            tl.uint16, tl.uint16, tl.uint16, tl.uint16,
            tl.uint16, tl.uint16, tl.uint16, tl.uint16,
            tl.uint16, tl.uint16, tl.uint16, tl.uint16,
            tl.uint16, tl.uint16, tl.uint16, tl.uint16,
        ],
        # No side effects; one element of each operand per asm instance.
        is_pure=True,
        pack=1,
    )
    # Stack the 16 tensors with a balanced tree of tl.join calls; each join
    # appends a trailing size-2 dimension.
    # NOTE(review): after ConvertTritonToTritonGPU every intermediate join
    # here is followed by a triton_gpu.convert_layout, and those survive to
    # PTX as shared-memory round trips.
    return tl.join(
        tl.join(
            tl.join(tl.join(aa, ab), tl.join(ac, ad)),
            tl.join(tl.join(ba, bb), tl.join(bc, bd))
        ),
        tl.join(
            tl.join(tl.join(ca, cb), tl.join(cc, cd)),
            tl.join(tl.join(da, db), tl.join(dc, dd))
        ),
    )

I found that the generated MLIR is quite inefficient: the joined tensors go through multiple layout conversions.

Focusing on the join part, here is the generated MLIR before TritonGPUCoalesce (i.e. after ConvertTritonToTritonGPU):

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [4, 8, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
#blocked6 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 8, 2], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
#blocked7 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 2], threadsPerWarp = [2, 8, 2, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 2, 1, 0]}>
#blocked8 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [1, 8, 2, 2], warpsPerCTA = [4, 1, 1, 1], order = [3, 2, 1, 0]}>
#blocked9 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 2], threadsPerWarp = [1, 8, 2, 2, 1], warpsPerCTA = [4, 1, 1, 1, 1], order = [4, 3, 2, 1, 0]}>
#blocked10 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 1], threadsPerWarp = [1, 4, 2, 2, 2], warpsPerCTA = [2, 2, 1, 1, 1], order = [4, 3, 2, 1, 0]}>
#blocked11 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 1, 2], threadsPerWarp = [1, 4, 2, 2, 2, 1], warpsPerCTA = [2, 2, 1, 1, 1, 1], order = [5, 4, 3, 2, 1, 0]}>
#blocked12 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 2, 2, 2, 2, 2], warpsPerCTA = [1, 4, 1, 1, 1, 1], order = [5, 4, 3, 2, 1, 0]}>
#blocked13 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>

....
%44 = tt.join %43#0, %43#1 : tensor<32x8xi16, #blocked> -> tensor<32x8x2xi16, #blocked5> loc(#loc74)
%45 = triton_gpu.convert_layout %44 : tensor<32x8x2xi16, #blocked5> -> tensor<32x8x2xi16, #blocked6> loc(#loc75)
%46 = tt.join %43#2, %43#3 : tensor<32x8xi16, #blocked> -> tensor<32x8x2xi16, #blocked5> loc(#loc76)
%47 = triton_gpu.convert_layout %46 : tensor<32x8x2xi16, #blocked5> -> tensor<32x8x2xi16, #blocked6> loc(#loc75)
%48 = tt.join %45, %47 : tensor<32x8x2xi16, #blocked6> -> tensor<32x8x2x2xi16, #blocked7> loc(#loc75)
%49 = triton_gpu.convert_layout %48 : tensor<32x8x2x2xi16, #blocked7> -> tensor<32x8x2x2xi16, #blocked8> loc(#loc77)
%50 = tt.join %43#4, %43#5 : tensor<32x8xi16, #blocked> -> tensor<32x8x2xi16, #blocked5> loc(#loc78)
%51 = triton_gpu.convert_layout %50 : tensor<32x8x2xi16, #blocked5> -> tensor<32x8x2xi16, #blocked6> loc(#loc79)
%52 = tt.join %43#6, %43#7 : tensor<32x8xi16, #blocked> -> tensor<32x8x2xi16, #blocked5> loc(#loc80)
%53 = triton_gpu.convert_layout %52 : tensor<32x8x2xi16, #blocked5> -> tensor<32x8x2xi16, #blocked6> loc(#loc79)
%54 = tt.join %51, %53 : tensor<32x8x2xi16, #blocked6> -> tensor<32x8x2x2xi16, #blocked7> loc(#loc79)
%55 = triton_gpu.convert_layout %54 : tensor<32x8x2x2xi16, #blocked7> -> tensor<32x8x2x2xi16, #blocked8> loc(#loc77)
%56 = tt.join %49, %55 : tensor<32x8x2x2xi16, #blocked8> -> tensor<32x8x2x2x2xi16, #blocked9> loc(#loc77)
%57 = triton_gpu.convert_layout %56 : tensor<32x8x2x2x2xi16, #blocked9> -> tensor<32x8x2x2x2xi16, #blocked10> loc(#loc81)
%58 = tt.join %43#8, %43#9 : tensor<32x8xi16, #blocked> -> tensor<32x8x2xi16, #blocked5> loc(#loc82)
%59 = triton_gpu.convert_layout %58 : tensor<32x8x2xi16, #blocked5> -> tensor<32x8x2xi16, #blocked6> loc(#loc83)
%60 = tt.join %43#10, %43#11 : tensor<32x8xi16, #blocked> -> tensor<32x8x2xi16, #blocked5> loc(#loc84)
%61 = triton_gpu.convert_layout %60 : tensor<32x8x2xi16, #blocked5> -> tensor<32x8x2xi16, #blocked6> loc(#loc83)
%62 = tt.join %59, %61 : tensor<32x8x2xi16, #blocked6> -> tensor<32x8x2x2xi16, #blocked7> loc(#loc83)
%63 = triton_gpu.convert_layout %62 : tensor<32x8x2x2xi16, #blocked7> -> tensor<32x8x2x2xi16, #blocked8> loc(#loc85)
%64 = tt.join %43#12, %43#13 : tensor<32x8xi16, #blocked> -> tensor<32x8x2xi16, #blocked5> loc(#loc86)
%65 = triton_gpu.convert_layout %64 : tensor<32x8x2xi16, #blocked5> -> tensor<32x8x2xi16, #blocked6> loc(#loc87)
%66 = tt.join %43#14, %43#15 : tensor<32x8xi16, #blocked> -> tensor<32x8x2xi16, #blocked5> loc(#loc88)
%67 = triton_gpu.convert_layout %66 : tensor<32x8x2xi16, #blocked5> -> tensor<32x8x2xi16, #blocked6> loc(#loc87)
%68 = tt.join %65, %67 : tensor<32x8x2xi16, #blocked6> -> tensor<32x8x2x2xi16, #blocked7> loc(#loc87)
%69 = triton_gpu.convert_layout %68 : tensor<32x8x2x2xi16, #blocked7> -> tensor<32x8x2x2xi16, #blocked8> loc(#loc85)
%70 = tt.join %63, %69 : tensor<32x8x2x2xi16, #blocked8> -> tensor<32x8x2x2x2xi16, #blocked9> loc(#loc85)
%71 = triton_gpu.convert_layout %70 : tensor<32x8x2x2x2xi16, #blocked9> -> tensor<32x8x2x2x2xi16, #blocked10> loc(#loc81)
%72 = tt.join %57, %71 : tensor<32x8x2x2x2xi16, #blocked10> -> tensor<32x8x2x2x2x2xi16, #blocked11> loc(#loc81)
%73 = triton_gpu.convert_layout %72 : tensor<32x8x2x2x2x2xi16, #blocked11> -> tensor<32x8x2x2x2x2xi16, #blocked12> loc(#loc52)

The final generated PTX/SASS still contains these layout conversions, implemented as shared-memory stores and loads:

.loc    4 56 37
bar.sync    0;
mov.b32     %r574, {%rs240, %rs239};
st.shared.u32   [%r35], %r574;
bar.sync    0;
ld.shared.u16   %rs241, [%r36];
ld.shared.u16   %rs242, [%r37];
bar.sync    0;
mov.b32     %r575, {%rs238, %rs237};
st.shared.u32   [%r35], %r575;
bar.sync    0;
ld.shared.u16   %rs243, [%r36];
ld.shared.u16   %rs244, [%r37];
.loc    4 57 12
bar.sync    0;
mov.b32     %r576, {%rs241, %rs243};
st.shared.u32   [%r38], %r576;
mov.b32     %r577, {%rs242, %rs244};
st.shared.u32   [%r39], %r577;
bar.sync    0;
....

Could the tree-like join convert the layout only once, at the end? I'm interested in improving this, but I haven't been able to identify which part of TritonToTritonGPUPass inserts the convert_layout operations.

jlebar commented 3 months ago

I haven't been able to identify which part of the TritonToTritonGPUPass inserts the convert_layout operation.

Run with MLIR_ENABLE_DUMP=1 to see the MLIR after each pass. (This and other tips are in the README.)

bxyb commented 3 months ago

I haven't been able to identify which part of the TritonToTritonGPUPass inserts the convert_layout operation.

Run with MLIR_ENABLE_DUMP=1 to see the MLIR after each pass. (This and other tips are in the README.)

Using MLIR_ENABLE_DUMP, I have already confirmed that the triton_gpu.convert_layout ops are added by ConvertTritonToTritonGPU.

But the rewrite pattern for the join op does not add the convert_layout directly:

// Conversion pattern for tt.join during the Triton -> TritonGPU conversion.
// It rebuilds the join from the already-converted operands and lets the op
// infer its own (encoded) result type, then copies the original op's
// attributes onto the new op.
//
// NOTE(review): this pattern does not create triton_gpu.convert_layout ops
// itself; presumably any converts seen around joins after this pass are
// materialized by the conversion framework when operand/result types
// disagree -- confirm.
struct TritonJoinOpPattern : public OpConversionPattern<triton::JoinOp> {
  using OpConversionPattern::OpConversionPattern;

  // `override` makes the compiler reject this method if the base-class
  // signature ever changes, instead of silently adding a new overload that
  // the conversion driver would never call.
  LogicalResult
  matchAndRewrite(JoinOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply rely on type inference for this op.  (Notably, GenericOpPattern
    // does not do this, instead it assigns the default layout to the ins and
    // outs.)
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::JoinOp>(
                      op, adaptor.getLhs(), adaptor.getRhs()),
                  adaptor.getAttributes());
    return success();
  }
};

I recently noticed that @jlebar has made significant contributions to the 'join' operation. Could you provide some insights into the mechanism by which the 'convert_layout' is triggered by the 'join' operation?

jlebar commented 3 months ago

Ah, sorry. To see which part of the pass is inserting the operations, dump the function at various points in the pass.

But I think your concern should be with the pass that eliminates unnecessary layout conversions, not the pass that adds them. After TritonToTritonGPUPass, there will be many unnecessary layout conversions, then a later pass should remove unnecessary ones.

bxyb commented 3 months ago

Ah, sorry. To see which part of the pass is inserting the operations, dump the function at various points in the pass.

But I think your concern should be with the pass that eliminates unnecessary layout conversions, not the pass that adds them. After TritonToTritonGPUPass, there will be many unnecessary layout conversions, then a later pass should remove unnecessary ones.

Hmm, are there any specific constraints on the layout of the operands to be joined? From your perspective, was this layout switch here (for tree-like join) unnecessary?

I'll try to extend TritonGPURemoveLayoutConversionsPass to eliminate conversions in cases like this.

jlebar commented 3 months ago

Hmm, are there any specific constraints on the layout of the operands to be joined?

See the MLIR verifier for the join op.

From your perspective, was this layout switch here (for tree-like join) unnecessary?

I agree with you that we should probably only need to do it once, at the end. But the reproducer is not a full repro (the Python function returns a tensor, which top-level JIT functions do not), and the IR is after a pass where we would expect to see lots of unnecessary converts, so I can't say anything for sure.

bxyb commented 3 months ago

Hmm, are there any specific constraints on the layout of the operands to be joined?

See the MLIR verifier for the join op.

From your perspective, was this layout switch here (for tree-like join) unnecessary?

I agree with you that we should probably only need to do it once, at the end. But the reproducer is not a full repro (the Python function returns a tensor, which top-level JIT functions do not), and the IR is after a pass where we would expect to see lots of unnecessary converts, so I can't say anything for sure.

Here is the complete repro:

import triton
from triton import language as tl

@triton.jit
def rand8bit(seed, offset):
    """Turn one ``tl.randint4x`` call into 16 byte-sized random values.

    ``tl.randint4x`` yields four 32-bit random tensors. Inline PTX splits
    each of them into its four bytes and widens every byte to ``uint16``,
    giving 16 values in [0, 255] per element of ``offset``. The 16 tensors
    are then stacked with nested ``tl.join`` calls, which append four
    trailing size-2 dimensions to the input shape.

    NOTE(review): each ``tl.join`` appends its new dimension last, so the
    outermost join varies fastest. Presumably a row-major reshape of the
    result therefore interleaves the values (aa, ca, ba, da, ac, ...)
    rather than ordering them aa, ab, ac, ... -- verify if order matters.
    """
    # Four random tensors shaped like `offset`. The trailing `1` is
    # randint4x's third argument (presumably the number of rounds --
    # confirm against the tl.randint4x docs).
    a, b, c, d = tl.randint4x(seed, offset, 1)
    # Inline PTX operand map: $0..$15 are the 16 uint16 outputs, in the
    # order aa..dd below; $16..$19 are the inputs a, b, c, d.
    # `mov.b32 {tmp0, tmp1, tmp2, tmp3}, $N` splits one 32-bit input into
    # four bytes, and each `cvt.u16.u8` zero-extends one byte to 16 bits.
    # The comment inside the asm mentions only `a`, but all four inputs are
    # unpacked the same way.
    aa, ab, ac, ad, \
    ba, bb, bc, bd, \
    ca, cb, cc, cd, \
    da, db, dc, dd \
        = tl.inline_asm_elementwise(
        asm="""
        {
            // Unpack `a` into `ai`.
            .reg .b8 tmp<4>;
            mov.b32 {tmp0, tmp1, tmp2, tmp3}, $16;
            cvt.u16.u8 $0, tmp0;
            cvt.u16.u8 $1, tmp1;
            cvt.u16.u8 $2, tmp2;
            cvt.u16.u8 $3, tmp3;
            mov.b32 {tmp0, tmp1, tmp2, tmp3}, $17;
            cvt.u16.u8 $4, tmp0;
            cvt.u16.u8 $5, tmp1;
            cvt.u16.u8 $6, tmp2;
            cvt.u16.u8 $7, tmp3;
            mov.b32 {tmp0, tmp1, tmp2, tmp3}, $18;
            cvt.u16.u8 $8, tmp0;
            cvt.u16.u8 $9, tmp1;
            cvt.u16.u8 $10, tmp2;
            cvt.u16.u8 $11, tmp3;
            mov.b32 {tmp0, tmp1, tmp2, tmp3}, $19;
            cvt.u16.u8 $12, tmp0;
            cvt.u16.u8 $13, tmp1;
            cvt.u16.u8 $14, tmp2;
            cvt.u16.u8 $15, tmp3;
        }
        """,
        # 16 register outputs ("=r") followed by 4 register inputs ("r").
        constraints=(
            "=r,=r,=r,=r," 
            "=r,=r,=r,=r," 
            "=r,=r,=r,=r," 
            "=r,=r,=r,=r," 
            "r,r,r,r"),
        args=[a, b, c, d],
        # One dtype per output operand $0..$15.
        dtype=[
            tl.uint16, tl.uint16, tl.uint16, tl.uint16,
            tl.uint16, tl.uint16, tl.uint16, tl.uint16,
            tl.uint16, tl.uint16, tl.uint16, tl.uint16,
            tl.uint16, tl.uint16, tl.uint16, tl.uint16,
        ],
        # No side effects; one element of each operand per asm instance.
        is_pure=True,
        pack=1,
    )
    # Stack the 16 tensors with a balanced tree of tl.join calls; each join
    # appends a trailing size-2 dimension.
    # NOTE(review): after ConvertTritonToTritonGPU every intermediate join
    # here is followed by a triton_gpu.convert_layout, and those survive to
    # PTX as shared-memory round trips.
    return tl.join(
        tl.join(
            tl.join(tl.join(aa, ab), tl.join(ac, ad)),
            tl.join(tl.join(ba, bb), tl.join(bc, bd))
        ),
        tl.join(
            tl.join(tl.join(ca, cb), tl.join(cc, cd)),
            tl.join(tl.join(da, db), tl.join(dc, dd))
        ),
    )

if __name__ == "__main__":
    import torch

    # Fill an int16 [shape_x, shape_y] buffer with rand8bit output, one
    # (BLOCK_X, BLOCK_Y) tile per program. Each RNG offset produces 16
    # values, so offsets along y are spaced 16 apart. BLOCK_Y must therefore
    # be a multiple of 16, and BLOCK_Y // 16 a power of two (tl.arange).
    # There is no boundary check, so shape_x and shape_y must be multiples of
    # BLOCK_X and BLOCK_Y.
    @triton.jit
    def gen_rand(
        seed,
        rand_v,
        shape_x, shape_y,
        BLOCK_X: tl.constexpr,
        BLOCK_Y: tl.constexpr,
    ):
        idx = tl.program_id(0)
        idy = tl.program_id(1)
        off_x = idx * BLOCK_X + tl.arange(0, BLOCK_X)
        # One RNG offset per group of 16 output columns.
        off_y = idy * BLOCK_Y + tl.arange(0, BLOCK_Y//16) * 16
        off = off_x[:, None] * shape_y + off_y[None, :]
        # Single-iteration loop kept from the original repro (presumably to
        # exercise the in-loop code path -- confirm before removing).
        for _ in range(0, 1, 1):
            got = tl.reshape(rand8bit(seed, off), [BLOCK_X, BLOCK_Y])
            v = tl.make_block_ptr(
                base=rand_v,
                shape=(shape_x, shape_y),
                strides=(shape_y, 1),
                offsets=(idx * BLOCK_X, idy * BLOCK_Y),
                block_shape=(BLOCK_X, BLOCK_Y),
                # Row-major data (strides=(shape_y, 1)): dim 1 varies fastest.
                order=(1, 0),
            )
            tl.store(v, got.to(tl.int16))

    shape_x, shape_y = 2048, 4096
    seed = 1991

    def grid(META):
        # One program per (BLOCK_X, BLOCK_Y) tile.
        return (triton.cdiv(shape_x, META["BLOCK_X"]), triton.cdiv(shape_y, META["BLOCK_Y"]))

    rand_v_1 = torch.empty([shape_x, shape_y], dtype=torch.int16, device="cuda")
    gen_rand[grid](seed, rand_v_1, shape_x, shape_y,
             BLOCK_X=128, BLOCK_Y=32)

    # Zero-initialise so the "before" diff is deterministic; torch.empty
    # would expose uninitialised device memory.
    rand_v_2 = torch.zeros([shape_x, shape_y], dtype=torch.int16, device="cuda")
    print("diff before:", (rand_v_1 - rand_v_2).abs().sum())
    # Same seed and per-element offsets with a different tile shape: the
    # output should not depend on tiling, so the expected diff is 0.
    gen_rand[grid](seed, rand_v_2, shape_x, shape_y,
             BLOCK_X=32, BLOCK_Y=128)
    print("diff after:", (rand_v_1 - rand_v_2).abs().sum())
    print(rand_v_2)
jlebar commented 3 months ago

Yeah, I don't see why we shouldn't be able to do this with only one layout transformation, during reshape.

On Wed, May 29, 2024 at 10:38 AM bxyb @.***> wrote:

Hmm, are there any specific constraints on the layout of the operands to be joined?

See the MLIR verifier for the join op.

From your perspective, was this layout switch here (for tree-like join) unnecessary?

I agree with you that we should probably only need to do it once, at the end. But the reproducer is not a full repro (the Python function returns a tensor, which top-level JIT functions do not), and the IR is after a pass where we would expect to see lots of unnecessary converts, so I can't say anything for sure.

here is the whole repro.

import triton
from triton import language as tl


@triton.jit
def rand8bit(seed, offset):
    """Turn one ``tl.randint4x`` call into 16 byte-sized random values.

    ``tl.randint4x`` yields four 32-bit random tensors. Inline PTX splits
    each of them into its four bytes and widens every byte to ``uint16``,
    giving 16 values in [0, 255] per element of ``offset``. The 16 tensors
    are then stacked with nested ``tl.join`` calls, which append four
    trailing size-2 dimensions to the input shape.

    NOTE(review): each ``tl.join`` appends its new dimension last, so the
    outermost join varies fastest. Presumably a row-major reshape of the
    result therefore interleaves the values (aa, ca, ba, da, ac, ...)
    rather than ordering them aa, ab, ac, ... -- verify if order matters.
    """
    # Four random tensors shaped like `offset`. The trailing `1` is
    # randint4x's third argument (presumably the number of rounds --
    # confirm against the tl.randint4x docs).
    a, b, c, d = tl.randint4x(seed, offset, 1)
    # Inline PTX operand map: $0..$15 are the 16 uint16 outputs, in the
    # order aa..dd below; $16..$19 are the inputs a, b, c, d.
    # `mov.b32 {tmp0, tmp1, tmp2, tmp3}, $N` splits one 32-bit input into
    # four bytes, and each `cvt.u16.u8` zero-extends one byte to 16 bits.
    aa, ab, ac, ad, \
    ba, bb, bc, bd, \
    ca, cb, cc, cd, \
    da, db, dc, dd \
        = tl.inline_asm_elementwise(
        asm="""
        {
            // Unpack `a` into `ai`.
            .reg .b8 tmp<4>;
            mov.b32 {tmp0, tmp1, tmp2, tmp3}, $16;
            cvt.u16.u8 $0, tmp0;
            cvt.u16.u8 $1, tmp1;
            cvt.u16.u8 $2, tmp2;
            cvt.u16.u8 $3, tmp3;
            mov.b32 {tmp0, tmp1, tmp2, tmp3}, $17;
            cvt.u16.u8 $4, tmp0;
            cvt.u16.u8 $5, tmp1;
            cvt.u16.u8 $6, tmp2;
            cvt.u16.u8 $7, tmp3;
            mov.b32 {tmp0, tmp1, tmp2, tmp3}, $18;
            cvt.u16.u8 $8, tmp0;
            cvt.u16.u8 $9, tmp1;
            cvt.u16.u8 $10, tmp2;
            cvt.u16.u8 $11, tmp3;
            mov.b32 {tmp0, tmp1, tmp2, tmp3}, $19;
            cvt.u16.u8 $12, tmp0;
            cvt.u16.u8 $13, tmp1;
            cvt.u16.u8 $14, tmp2;
            cvt.u16.u8 $15, tmp3;
        }
        """,
        # 16 register outputs ("=r") followed by 4 register inputs ("r").
        constraints=(
            "=r,=r,=r,=r,"
            "=r,=r,=r,=r,"
            "=r,=r,=r,=r,"
            "=r,=r,=r,=r,"
            "r,r,r,r"),
        args=[a, b, c, d],
        dtype=[
            tl.uint16, tl.uint16, tl.uint16, tl.uint16,
            tl.uint16, tl.uint16, tl.uint16, tl.uint16,
            tl.uint16, tl.uint16, tl.uint16, tl.uint16,
            tl.uint16, tl.uint16, tl.uint16, tl.uint16,
        ],
        is_pure=True,
        pack=1,
    )
    # Balanced tree of joins; each join appends a trailing size-2 dimension.
    return tl.join(
        tl.join(
            tl.join(tl.join(aa, ab), tl.join(ac, ad)),
            tl.join(tl.join(ba, bb), tl.join(bc, bd))
        ),
        tl.join(
            tl.join(tl.join(ca, cb), tl.join(cc, cd)),
            tl.join(tl.join(da, db), tl.join(dc, dd))
        ),
    )


if __name__ == "__main__":
    import torch

    # Fill an int16 [shape_x, shape_y] buffer with rand8bit output, one
    # (BLOCK_X, BLOCK_Y) tile per program. Each RNG offset produces 16
    # values, so BLOCK_Y must be a multiple of 16, and BLOCK_Y // 16 a power
    # of two. There is no boundary check, so the shapes must divide evenly.
    @triton.jit
    def gen_rand(
        seed,
        rand_v,
        shape_x, shape_y,
        BLOCK_X: tl.constexpr,
        BLOCK_Y: tl.constexpr,
    ):
        idx = tl.program_id(0)
        idy = tl.program_id(1)
        off_x = idx * BLOCK_X + tl.arange(0, BLOCK_X)
        # One RNG offset per group of 16 output columns.
        off_y = idy * BLOCK_Y + tl.arange(0, BLOCK_Y//16) * 16
        off = off_x[:, None] * shape_y + off_y[None, :]
        for _ in range(0, 1, 1):
            got = tl.reshape(rand8bit(seed, off), [BLOCK_X, BLOCK_Y])
            v = tl.make_block_ptr(
                base=rand_v,
                shape=(shape_x, shape_y),
                strides=(shape_y, 1),
                offsets=(idx * BLOCK_X, idy * BLOCK_Y),
                block_shape=(BLOCK_X, BLOCK_Y),
                # Row-major data (strides=(shape_y, 1)): dim 1 varies fastest.
                order=(1, 0),
            )
            tl.store(v, got.to(tl.int16))

    shape_x, shape_y = 2048, 4096
    seed = 1991

    def grid(META):
        # One program per (BLOCK_X, BLOCK_Y) tile.
        return (triton.cdiv(shape_x, META["BLOCK_X"]), triton.cdiv(shape_y, META["BLOCK_Y"]))

    rand_v_1 = torch.empty([shape_x, shape_y], dtype=torch.int16, device="cuda")
    gen_rand[grid](seed, rand_v_1, shape_x, shape_y,
             BLOCK_X=128, BLOCK_Y=32)

    # Zero-initialise so the "before" diff is deterministic.
    rand_v_2 = torch.zeros([shape_x, shape_y], dtype=torch.int16, device="cuda")
    print("diff before:", (rand_v_1 - rand_v_2).abs().sum())
    # Same seed and offsets with a different tile shape: expected diff is 0.
    gen_rand[grid](seed, rand_v_2, shape_x, shape_y,
             BLOCK_X=32, BLOCK_Y=128)
    print("diff after:", (rand_v_1 - rand_v_2).abs().sum())
    print(rand_v_2)

— Reply to this email directly, view it on GitHub https://github.com/triton-lang/triton/issues/4030#issuecomment-2137938274, or unsubscribe https://github.com/notifications/unsubscribe-auth/AABEZBZCCXNRYRJEWGZYL7TZEYHALAVCNFSM6AAAAABIOXPNESVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDCMZXHEZTQMRXGQ . You are receiving this because you were mentioned.Message ID: @.***>

bxyb commented 3 months ago

test.txt

When I use test.txt as MLIR input to tritongpu-remove-layout-conversions, all convert_layout ops are removed. But with the original input (the block settings commented out in test.txt), the pass only removes blocked5 and blocked6 (joining blocked4 directly into blocked6). It cannot remove all of the layout conversions because the final blocked layout is fixed: sizePerThread=[1, 1, 1, 2, 2, 2] gives 8 elements per thread, so the 16 joined elements cannot be produced without a conversion.

Why is it 8 elements per thread? Because TritonGPUCoalesce sets it.

So if I only join 8 elements, it performs as expected (no layout conversion).

gen_rand_py.txt