bxyb opened this issue 3 months ago
I haven't been able to identify which part of the TritonToTritonGPUPass inserts the convert_layout operation.
Run with MLIR_ENABLE_DUMP=1 to see the MLIR after each pass. (This and other tips are in the README.)
I have already confirmed via MLIR_ENABLE_DUMP that the triton_gpu.convert_layout op was added by ConvertTritonToTritonGPU. But the rewrite pattern for the join op does not add convert_layout directly:
struct TritonJoinOpPattern : public OpConversionPattern<triton::JoinOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(JoinOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
    // Simply rely on type inference for this op.  (Notably, GenericOpPattern
    // does not do this, instead it assigns the default layout to the ins and
    // outs.)
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::JoinOp>(
                      op, adaptor.getLhs(), adaptor.getRhs()),
                  adaptor.getAttributes());
    return success();
  }
};
I recently noticed that @jlebar has made significant contributions to the join operation. Could you provide some insight into the mechanism by which convert_layout is triggered by the join operation?
Ah, sorry. To see which part of the pass is inserting the operations, dump the function at various points in the pass.
But I think your concern should be with the pass that eliminates unnecessary layout conversions, not the pass that adds them. After TritonToTritonGPUPass, there will be many unnecessary layout conversions, then a later pass should remove unnecessary ones.
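To see where the conversions pile up, one low-tech approach is to save the MLIR_ENABLE_DUMP output and count convert_layout ops at each pass boundary. This is just a sketch; the file name and the sample IR below are placeholders, not taken from the actual dump:

```python
import re

def count_converts(mlir_text):
    # Count occurrences of convert_layout ops in a chunk of dumped MLIR.
    return len(re.findall(r"\bconvert_layout\b", mlir_text))

# Placeholder IR illustrating the idea; in practice, read the text you
# redirected MLIR_ENABLE_DUMP's stderr into and split it per pass.
sample = """
%0 = triton_gpu.convert_layout %arg0 : ...
%1 = tt.join %0, %0 : ...
%2 = triton_gpu.convert_layout %1 : ...
"""
print(count_converts(sample))  # 2
```

Comparing the count after ConvertTritonToTritonGPU with the count after TritonGPURemoveLayoutConversions shows which converts survive the cleanup pass.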
Hmm, are there any specific constraints on the layout of the operands to be joined? From your perspective, was this layout switch here (for the tree-like join) unnecessary?
I'll try to add handling for cases like this to the elimination in TritonGPURemoveLayoutConversionsPass.
Hmm, are there any specific constraints on the layout of the operands to be joined?
See the MLIR verifier for the join op.
From your perspective, was this layout switch here (for tree-like join) unnecessary?
I agree with you that we should probably only need to do it once, at the end. But the reproducer is not a full repro (the Python function returns a tensor, which top-level JIT functions do not), and the IR is after a pass where we would expect to see lots of unnecessary converts, so I can't say anything for sure.
Here is the whole repro:
import triton
from triton import language as tl


@triton.jit
def rand8bit(seed, offset):
    a, b, c, d = tl.randint4x(seed, offset, 1)
    aa, ab, ac, ad, \
    ba, bb, bc, bd, \
    ca, cb, cc, cd, \
    da, db, dc, dd \
        = tl.inline_asm_elementwise(
            asm="""
            {
                // Unpack `a` into `ai`.
                .reg .b8 tmp<4>;
                mov.b32 {tmp0, tmp1, tmp2, tmp3}, $16;
                cvt.u16.u8 $0, tmp0;
                cvt.u16.u8 $1, tmp1;
                cvt.u16.u8 $2, tmp2;
                cvt.u16.u8 $3, tmp3;
                mov.b32 {tmp0, tmp1, tmp2, tmp3}, $17;
                cvt.u16.u8 $4, tmp0;
                cvt.u16.u8 $5, tmp1;
                cvt.u16.u8 $6, tmp2;
                cvt.u16.u8 $7, tmp3;
                mov.b32 {tmp0, tmp1, tmp2, tmp3}, $18;
                cvt.u16.u8 $8, tmp0;
                cvt.u16.u8 $9, tmp1;
                cvt.u16.u8 $10, tmp2;
                cvt.u16.u8 $11, tmp3;
                mov.b32 {tmp0, tmp1, tmp2, tmp3}, $19;
                cvt.u16.u8 $12, tmp0;
                cvt.u16.u8 $13, tmp1;
                cvt.u16.u8 $14, tmp2;
                cvt.u16.u8 $15, tmp3;
            }
            """,
            constraints=(
                "=r,=r,=r,=r,"
                "=r,=r,=r,=r,"
                "=r,=r,=r,=r,"
                "=r,=r,=r,=r,"
                "r,r,r,r"),
            args=[a, b, c, d],
            dtype=[
                tl.uint16, tl.uint16, tl.uint16, tl.uint16,
                tl.uint16, tl.uint16, tl.uint16, tl.uint16,
                tl.uint16, tl.uint16, tl.uint16, tl.uint16,
                tl.uint16, tl.uint16, tl.uint16, tl.uint16,
            ],
            is_pure=True,
            pack=1,
        )
    return tl.join(
        tl.join(
            tl.join(tl.join(aa, ab), tl.join(ac, ad)),
            tl.join(tl.join(ba, bb), tl.join(bc, bd)),
        ),
        tl.join(
            tl.join(tl.join(ca, cb), tl.join(cc, cd)),
            tl.join(tl.join(da, db), tl.join(dc, dd)),
        ),
    )


if __name__ == "__main__":
    import torch

    # test triton mask
    @triton.jit
    def gen_rand(
        seed,
        rand_v,
        shape_x, shape_y,
        BLOCK_X: tl.constexpr,
        BLOCK_Y: tl.constexpr,
    ):
        idx = tl.program_id(0)
        idy = tl.program_id(1)
        off_x = idx * BLOCK_X + tl.arange(0, BLOCK_X)
        off_y = idy * BLOCK_Y + tl.arange(0, BLOCK_Y // 16) * 16
        off = off_x[:, None] * shape_y + off_y[None, :]
        for _ in range(0, 1, 1):
            got = tl.reshape(rand8bit(seed, off), [BLOCK_X, BLOCK_Y])
            v = tl.make_block_ptr(
                base=rand_v,
                shape=(shape_x, shape_y),
                strides=(shape_y, 1),
                offsets=(idx * BLOCK_X, idy * BLOCK_Y),
                block_shape=(BLOCK_X, BLOCK_Y),
                order=(0, 1),
            )
            tl.store(v, got.to(tl.int16))

    shape_x, shape_y = 2048, 4096
    seed = 1991
    grid = lambda META: (triton.cdiv(shape_x, META["BLOCK_X"]),
                         triton.cdiv(shape_y, META["BLOCK_Y"]))
    rand_v_1 = torch.empty([shape_x, shape_y], dtype=torch.int16, device="cuda")
    gen_rand[grid](seed, rand_v_1, shape_x, shape_y,
                   BLOCK_X=128, BLOCK_Y=32)
    rand_v_2 = torch.empty([shape_x, shape_y], dtype=torch.int16, device="cuda")
    print("diff before:", (rand_v_1 - rand_v_2).abs().sum())
    gen_rand[grid](seed, rand_v_2, shape_x, shape_y,
                   BLOCK_X=32, BLOCK_Y=128)
    print("diff after:", (rand_v_1 - rand_v_2).abs().sum())
    print(rand_v_2)
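For reference, the value semantics of the join tree above can be modeled in plain NumPy, assuming tl.join stacks two same-shape tensors along a new minor (trailing) axis, which matches its documented behavior. Each of the four join levels adds one size-2 axis, so 16 values end up in a trailing 2x2x2x2 block; note that under this model the later reshape yields the values in bit-reversed order, not 0..15 (which is harmless for random bytes):

```python
import numpy as np

def join(a, b):
    # Model of tl.join: stack along a new minor (trailing) axis, so that
    # result[..., 0] == a and result[..., 1] == b.
    return np.stack([a, b], axis=-1)

# Stand-ins for the 16 uint16 values produced by the inline asm above.
(aa, ab, ac, ad, ba, bb, bc, bd,
 ca, cb, cc, cd, da, db, dc, dd) = (
    np.full((4, 4), i, dtype=np.uint16) for i in range(16))

t = join(
    join(join(join(aa, ab), join(ac, ad)),
         join(join(ba, bb), join(bc, bd))),
    join(join(join(ca, cb), join(cc, cd)),
         join(join(da, db), join(dc, dd))),
)
print(t.shape)  # (4, 4, 2, 2, 2, 2): each join level adds one size-2 axis
flat = t.reshape(4, 4, 16)
print(flat[0, 0].tolist())  # bit-reversed order, not 0..15
```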
Yeah, I don't see why we shouldn't be able to do this with only one layout transformation, during reshape.
When using test.txt (as MLIR input) to test tritongpu-remove-layout-conversions, all convert_layout ops could be removed. But with the original input (the commented-out block settings in test.txt), it could only remove blocked 5 and 6 (joining blocked 4 directly to blocked 6). It could not remove all the layout transforms because the final block setting was fixed: with sizePerThread=[1, 1, 1, 2, 2, 2], a thread cannot hold all 16 joined elements without a transform.
Why is it 8 per thread? TritonGPUCoalesce sets it.
So, if I only join 8 elements, it performs as expected (no layout transfer).
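A quick sanity check on that arithmetic, using the sizePerThread quoted above: with this blocked layout each thread holds 8 elements, while a 4-level join tree produces 16 values per source element, so the threads must exchange data, hence the convert_layout.

```python
from math import prod

# sizePerThread from the blocked layout chosen by TritonGPUCoalesce.
size_per_thread = [1, 1, 1, 2, 2, 2]
elems_per_thread = prod(size_per_thread)  # 8 elements per thread

# A 4-level tree of tl.join ops appends four size-2 dims: 2**4 values
# per original element.
joined_elems = 2 ** 4  # 16

print(elems_per_thread, joined_elems)
# 16 > 8: a single thread cannot hold all joined values in this layout,
# so the join forces a data exchange (a layout conversion).
assert joined_elems > elems_per_thread
```

This is also why an 8-element join fits without any conversion: 2**3 == 8 matches the per-thread element count exactly.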
I'm trying to make use of the whole randint4x result (4 random int32 values) and generate 4x 8-bit random numbers from each 32-bit random number (a trick similar to the dropout in FlashAttention).
I found that the generated MLIR is quite inefficient: the layout of the joined tensor is converted multiple times.
Focusing on the join part, here is the generated MLIR before TritonGPUCoalesce (after ConvertTritonToTritonGPU):
The final generated PTX / SASS preserves the multiple layout conversions as shared-memory stores and loads.
Could the tree-like join convert the layout only once, at the end? I'm interested in improving this, but I haven't been able to identify which part of the TritonToTritonGPUPass inserts the convert_layout operation.