
Hitting an assertion in `RemoveLayoutConversions` Pass. Relevant for both cuda and hip backends. #4178

Closed: ravil-mobile closed this issue 4 weeks ago

ravil-mobile commented 1 month ago

Hi all,

I am working on a kernel that hits an assertion in the `RemoveLayoutConversions` pass during the IR rewrite (on the latest main branch). The bug affects both the cuda and hip backends.

Reproducer

Here is a small reproducer (note that the code does not make semantic sense: many parts were removed and the loop bounds were simplified):

from triton.backends.compiler import GPUTarget
import triton
import triton.language as tl
import argparse

@triton.jit()
def kernel(
        A, B, C, P,
        M, N, num_sms,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
    pid = tl.program_id(0)
    acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32
    #acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)

    for i in range(3):
        rm = BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        rn = BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        rk = tl.arange(0, BLOCK_SIZE_K)
        A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_SIZE_K * stride_ak
        B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_SIZE_K * stride_bk
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
        #acc = acc * 0.0

        for current_iter in range(8):
            a = tl.load(A_BASE)
            b = tl.load(B_BASE)
            acc += tl.dot(a, b)
            A_BASE += BLOCK_SIZE_K * stride_ak
            B_BASE += BLOCK_SIZE_K * stride_bk

        next_pid = pid + 1
        while (next_pid < num_sms):
            rm1 = tl.arange(0, BLOCK_SIZE_M)
            rn1 = tl.arange(0, BLOCK_SIZE_N)
            P_ = P + next_pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]
            acc1 = tl.load(P_)
            acc += acc1
            next_pid += 1

        rm = BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        rn = BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        tl.store(C_, acc, mask=mask)

signature = {'A': '*fp16', 'B': '*fp16', 'C': '*fp16', 'P': '*fp32',
             'M': 'i32', 'N': 'i32',
             'num_sms': 'i32', 'stride_am': 'i32', 'stride_ak': 'i32',
             'stride_bk': 'i32', 'stride_bn': 'i32', 'stride_cm': 'i32',
             'stride_cn': 'i32'}

constants = {'stride_ak': 1, 'stride_bk': 1,
             'stride_cn': 1, 'BLOCK_SIZE_M': 256,
             'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}

parser = argparse.ArgumentParser(
  prog="persistent stream-k gemm",
  description="",
  allow_abbrev=False,
)

parser.add_argument("-b", "--backend", choices=['cuda', 'hip'],
  default='hip',
  help="backend")
args = parser.parse_args()

if args.backend == 'cuda':
  curr_target = GPUTarget("cuda", 90, 32)
elif args.backend == 'hip':
  curr_target = GPUTarget("hip", 'gfx942', 64)

print(f'{curr_target=}')
src = triton.compiler.ASTSource(fn=kernel, signature=signature, constants=constants)
k = triton.compile(src, target=curr_target)

And here is the error:

python: /home/ravil/work/triton/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp:517: Value mlir::triton::gpu::(anonymous namespace)::LayoutPropagation::getValueAs(Value, Attribute): Assertion `rewrittenValue' failed.
 #0 0x00007f895939ad07 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) Signals.cpp:0:0
 #1 0x00007f895939878c llvm::sys::RunSignalHandlers() Signals.cpp:0:0
 #2 0x00007f895939b3df SignalHandler(int) Signals.cpp:0:0
 #3 0x00007f895a20c520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x00007f895a2609fc __pthread_kill_implementation ./nptl/pthread_kill.c:44:76
 #5 0x00007f895a2609fc __pthread_kill_internal ./nptl/pthread_kill.c:78:10
 #6 0x00007f895a2609fc pthread_kill ./nptl/pthread_kill.c:89:10
 #7 0x00007f895a20c476 gsignal ./signal/../sysdeps/posix/raise.c:27:6
 #8 0x00007f895a1f27f3 abort ./stdlib/abort.c:81:7
 #9 0x00007f895a1f271b _nl_load_domain ./intl/loadmsgcat.c:1177:9
#10 0x00007f895a203e96 (/lib/x86_64-linux-gnu/libc.so.6+0x39e96)
#11 0x00007f89553e7c21 mlir::triton::gpu::(anonymous namespace)::LayoutPropagation::getValueAs(mlir::Value, mlir::Attribute) /home/ravil/work/triton/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp:518:47
#12 0x00007f89553e6958 mlir::triton::gpu::(anonymous namespace)::LayoutPropagation::rewriteRegion(mlir::Region&) /home/ravil/work/triton/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp:485:30
#13 0x00007f89553e3876 mlir::triton::gpu::(anonymous namespace)::LayoutPropagation::rewrite() /home/ravil/work/triton/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp:434:74
#14 0x00007f89553ee0c1 mlir::triton::gpu::TritonGPURemoveLayoutConversionsPass::runOnOperation()::'lambda'(mlir::triton::FuncOp)::operator()(mlir::triton::FuncOp) const /home/ravil/work/triton/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp:1218:5
#15 0x00007f89553ee053 std::enable_if<!llvm::is_one_of<mlir::triton::FuncOp, mlir::Operation*, mlir::Region*, mlir::Block*>::value && std::is_same<void, void>::value, void>::type mlir::detail::walk<(mlir::WalkOrder)1, mlir::ForwardIterator, mlir::triton::gpu::TritonGPURemoveLayoutConversionsPass::runOnOperation()::'lambda'(mlir::triton::FuncOp), mlir::triton::FuncOp, void>(mlir::Operation*, mlir::triton::gpu::TritonGPURemoveLayoutConversionsPass::runOnOperation()::'lambda'(mlir::triton::FuncOp)&&)::'lambda'(mlir::Operation*)::operator()(mlir::Operation*) const /home/ravil/.triton/llvm/llvm-e4790ce2-ubuntu-x64/include/mlir/IR/Visitors.h:340:3
...

Fix 1

The commented-out lines in the code snippet above are essentially the fix. For example, the following code compiles successfully:

@triton.jit()
def kernel(
        A, B, C, P,
        M, N, num_sms,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
    pid = tl.program_id(0)
    acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)

    for i in range(3):
        rm = BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        rn = BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        rk = tl.arange(0, BLOCK_SIZE_K)
        A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_SIZE_K * stride_ak
        B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_SIZE_K * stride_bk
        acc = acc * 0.0

        for current_iter in range(8):
            a = tl.load(A_BASE)
            b = tl.load(B_BASE)
            acc += tl.dot(a, b)
            A_BASE += BLOCK_SIZE_K * stride_ak
            B_BASE += BLOCK_SIZE_K * stride_bk

        next_pid = pid + 1
        while (next_pid < num_sms):
            rm1 = tl.arange(0, BLOCK_SIZE_M)
            rn1 = tl.arange(0, BLOCK_SIZE_N)
            P_ = P + next_pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]
            acc1 = tl.load(P_)
            acc += acc1
            next_pid += 1

        rm = BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        rn = BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        tl.store(C_, acc, mask=mask)

Fix 2

The original code can also be fixed by replacing the while-loop with a for-loop. The logic becomes different, but at least it compiles:

@triton.jit()
def kernel(
        A, B, C, P,
        M, N, num_sms,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
    pid = tl.program_id(0)
    acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32

    for i in range(3):
        rm = BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        rn = BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        rk = tl.arange(0, BLOCK_SIZE_K)
        A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_SIZE_K * stride_ak
        B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_SIZE_K * stride_bk
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)

        for current_iter in range(8):
            a = tl.load(A_BASE)
            b = tl.load(B_BASE)
            acc += tl.dot(a, b)
            A_BASE += BLOCK_SIZE_K * stride_ak
            B_BASE += BLOCK_SIZE_K * stride_bk

        next_pid = pid + 1
        for ii in range(10):
            rm1 = tl.arange(0, BLOCK_SIZE_M)
            rn1 = tl.arange(0, BLOCK_SIZE_N)
            P_ = P + next_pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]
            acc1 = tl.load(P_)
            acc += acc1
            next_pid += 1

        rm = BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        rn = BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        tl.store(C_, acc, mask=mask)

Reference to the source code

One can find the assertion in the source code here: https://github.com/triton-lang/triton/blob/a5b3783491a42a61ed5b8cb32a1178eb08e7b085/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp#L503-L517

Comment

Note that `rewrittenValue = rewriteMapping[{value, encodingPicked}];` leaves `rewrittenValue` equal to `nullptr`. I suspect that `rewriteWhileOp` is not called for the corresponding while-loop.

I would appreciate it if somebody could help me fix this bug.

PS: I suspect that something is going wrong in the `propagateToUsers` method.
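For context, here is a minimal sketch (standard library only, not the actual pass code) of why that lookup produces a null value: `std::map` stands in for `rewriteMapping` and a plain pointer stands in for `mlir::Value`. Assuming `rewriteMapping` is an `llvm::DenseMap` keyed on (value, encoding) pairs, as the indexing suggests, its `operator[]` behaves the same way: a missing key default-constructs the mapped value, so `rewrittenValue` comes back null when the while-loop results were never registered.

// Minimal analogue of the failing lookup. std::map plays the role of
// rewriteMapping; const char* plays the role of mlir::Value (null means
// "no rewritten value"). Build with e.g.: g++ -std=c++17 sketch.cpp
#include <cassert>
#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<std::string, const char*> rewriteMapping;

  // Results produced by the scf.for loop get registered by the pass ...
  rewriteMapping["acc yielded by scf.for"] = "rewritten acc";

  // ... so looking them up works as expected.
  const char* forLoopAcc = rewriteMapping["acc yielded by scf.for"];
  std::cout << forLoopAcc << "\n";

  // If the scf.while results were never registered (my suspicion: because
  // rewriteWhileOp is not invoked for that loop), operator[] on the missing
  // key value-initializes the mapped type, i.e. yields a null pointer ...
  const char* whileLoopAcc = rewriteMapping["acc carried by scf.while"];

  // ... and the equivalent of `assert(rewrittenValue)` aborts here.
  assert(whileLoopAcc && "rewrittenValue");
  std::cout << whileLoopAcc << "\n";
  return 0;
}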

ravil-mobile commented 1 month ago

MLIR Difference

Reproducer

// -----// IR Dump Before TritonGPURemoveLayoutConversions (tritongpu-remove-layout-conversions) ('builtin.module' operation) //----- //
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked6 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked7 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#loc = loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0)
#loc1 = loc(unknown)
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func public @kernel(%arg0: !tt.ptr<f16> loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg1: !tt.ptr<f16> loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg2: !tt.ptr<f16> loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg3: !tt.ptr<f32> loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg4: i32 loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg5: i32 loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg6: i32 loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg7: i32 loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg8: i32 loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg9: i32 loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0)) attributes {noinline = false} {
    %c32768_i32 = arith.constant 32768 : i32 loc(#loc1)
    %c8_i32 = arith.constant 8 : i32 loc(#loc1)
    %c3_i32 = arith.constant 3 : i32 loc(#loc1)
    %c0_i32 = arith.constant 0 : i32 loc(#loc1)
    %cst = arith.constant dense<128> : tensor<256x1xi32, #blocked> loc(#loc1)
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #blocked1> loc(#loc1)
    %cst_1 = arith.constant dense<64> : tensor<64x128xi32, #blocked1> loc(#loc1)
    %cst_2 = arith.constant dense<64> : tensor<256x64xi32, #blocked2> loc(#loc1)
    %cst_3 = arith.constant dense<128> : tensor<128xi32, #blocked3> loc(#loc1)
    %cst_4 = arith.constant dense<256> : tensor<256xi32, #blocked3> loc(#loc1)
    %c1_i32 = arith.constant 1 : i32 loc(#loc1)
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked3> loc(#loc3)
    %2 = arith.addi %1, %cst_4 : tensor<256xi32, #blocked3> loc(#loc4)
    %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3> loc(#loc5)
    %4 = arith.addi %3, %cst_3 : tensor<128xi32, #blocked3> loc(#loc6)
    %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3> loc(#loc7)
    %6 = triton_gpu.convert_layout %2 : tensor<256xi32, #blocked3> -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> loc(#loc8)
    %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1xi32, #blocked4> loc(#loc8)
    %8 = triton_gpu.convert_layout %7 : tensor<256x1xi32, #blocked4> -> tensor<256x1xi32, #blocked> loc(#loc9)
    %9 = tt.splat %arg7 : i32 -> tensor<256x1xi32, #blocked> loc(#loc9)
    %10 = arith.muli %8, %9 : tensor<256x1xi32, #blocked> loc(#loc9)
    %11 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked> loc(#loc10)
    %12 = tt.addptr %11, %10 : tensor<256x1x!tt.ptr<f16>, #blocked>, tensor<256x1xi32, #blocked> loc(#loc10)
    %13 = triton_gpu.convert_layout %5 : tensor<64xi32, #blocked3> -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> loc(#loc11)
    %14 = tt.expand_dims %13 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> -> tensor<1x64xi32, #blocked5> loc(#loc11)
    %15 = triton_gpu.convert_layout %14 : tensor<1x64xi32, #blocked5> -> tensor<1x64xi32, #blocked2> loc(#loc12)
    %16 = tt.broadcast %12 : tensor<256x1x!tt.ptr<f16>, #blocked> -> tensor<256x64x!tt.ptr<f16>, #blocked> loc(#loc12)
    %17 = triton_gpu.convert_layout %16 : tensor<256x64x!tt.ptr<f16>, #blocked> -> tensor<256x64x!tt.ptr<f16>, #blocked2> loc(#loc12)
    %18 = tt.broadcast %15 : tensor<1x64xi32, #blocked2> -> tensor<256x64xi32, #blocked2> loc(#loc12)
    %19 = tt.addptr %17, %18 : tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<256x64xi32, #blocked2> loc(#loc12)
    %20 = tt.addptr %19, %cst_2 : tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<256x64xi32, #blocked2> loc(#loc13)
    %21 = triton_gpu.convert_layout %5 : tensor<64xi32, #blocked3> -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> loc(#loc14)
    %22 = tt.expand_dims %21 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1xi32, #blocked4> loc(#loc14)
    %23 = triton_gpu.convert_layout %22 : tensor<64x1xi32, #blocked4> -> tensor<64x1xi32, #blocked> loc(#loc15)
    %24 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked> loc(#loc15)
    %25 = tt.addptr %24, %23 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked> loc(#loc15)
    %26 = triton_gpu.convert_layout %4 : tensor<128xi32, #blocked3> -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> loc(#loc16)
    %27 = tt.expand_dims %26 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> -> tensor<1x128xi32, #blocked5> loc(#loc16)
    %28 = triton_gpu.convert_layout %27 : tensor<1x128xi32, #blocked5> -> tensor<1x128xi32, #blocked1> loc(#loc17)
    %29 = tt.splat %arg8 : i32 -> tensor<1x128xi32, #blocked1> loc(#loc17)
    %30 = arith.muli %28, %29 : tensor<1x128xi32, #blocked1> loc(#loc17)
    %31 = tt.broadcast %25 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked> loc(#loc18)
    %32 = triton_gpu.convert_layout %31 : tensor<64x128x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked1> loc(#loc18)
    %33 = tt.broadcast %30 : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1> loc(#loc18)
    %34 = tt.addptr %32, %33 : tensor<64x128x!tt.ptr<f16>, #blocked1>, tensor<64x128xi32, #blocked1> loc(#loc18)
    %35 = tt.addptr %34, %cst_1 : tensor<64x128x!tt.ptr<f16>, #blocked1>, tensor<64x128xi32, #blocked1> loc(#loc19)
    %36 = arith.addi %0, %c1_i32 : i32 loc(#loc20)
    %37 = triton_gpu.convert_layout %1 : tensor<256xi32, #blocked3> -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> loc(#loc21)
    %38 = tt.expand_dims %37 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1xi32, #blocked4> loc(#loc21)
    %39 = triton_gpu.convert_layout %38 : tensor<256x1xi32, #blocked4> -> tensor<256x1xi32, #blocked> loc(#loc22)
    %40 = arith.muli %39, %cst : tensor<256x1xi32, #blocked> loc(#loc22)
    %41 = triton_gpu.convert_layout %3 : tensor<128xi32, #blocked3> -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> loc(#loc23)
    %42 = tt.expand_dims %41 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> -> tensor<1x128xi32, #blocked5> loc(#loc23)
    %43 = triton_gpu.convert_layout %42 : tensor<1x128xi32, #blocked5> -> tensor<1x128xi32, #blocked1> loc(#loc24)
    %44 = tt.broadcast %43 : tensor<1x128xi32, #blocked1> -> tensor<256x128xi32, #blocked1> loc(#loc24)
    %45 = tt.splat %arg9 : i32 -> tensor<256x1xi32, #blocked> loc(#loc25)
    %46 = arith.muli %8, %45 : tensor<256x1xi32, #blocked> loc(#loc25)
    %47 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked> loc(#loc26)
    %48 = tt.addptr %47, %46 : tensor<256x1x!tt.ptr<f16>, #blocked>, tensor<256x1xi32, #blocked> loc(#loc26)
    %49 = tt.broadcast %48 : tensor<256x1x!tt.ptr<f16>, #blocked> -> tensor<256x128x!tt.ptr<f16>, #blocked> loc(#loc27)
    %50 = triton_gpu.convert_layout %49 : tensor<256x128x!tt.ptr<f16>, #blocked> -> tensor<256x128x!tt.ptr<f16>, #blocked1> loc(#loc27)
    %51 = tt.broadcast %28 : tensor<1x128xi32, #blocked1> -> tensor<256x128xi32, #blocked1> loc(#loc27)
    %52 = tt.addptr %50, %51 : tensor<256x128x!tt.ptr<f16>, #blocked1>, tensor<256x128xi32, #blocked1> loc(#loc27)
    %53 = tt.splat %arg4 : i32 -> tensor<256xi32, #blocked3> loc(#loc28)
    %54 = arith.cmpi slt, %2, %53 : tensor<256xi32, #blocked3> loc(#loc28)
    %55 = triton_gpu.convert_layout %54 : tensor<256xi1, #blocked3> -> tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> loc(#loc29)
    %56 = tt.expand_dims %55 {axis = 1 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1xi1, #blocked4> loc(#loc29)
    %57 = triton_gpu.convert_layout %56 : tensor<256x1xi1, #blocked4> -> tensor<256x1xi1, #blocked> loc(#loc30)
    %58 = tt.splat %arg5 : i32 -> tensor<128xi32, #blocked3> loc(#loc31)
    %59 = arith.cmpi slt, %4, %58 : tensor<128xi32, #blocked3> loc(#loc31)
    %60 = triton_gpu.convert_layout %59 : tensor<128xi1, #blocked3> -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> loc(#loc32)
    %61 = tt.expand_dims %60 {axis = 0 : i32} : tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> -> tensor<1x128xi1, #blocked5> loc(#loc32)
    %62 = triton_gpu.convert_layout %61 : tensor<1x128xi1, #blocked5> -> tensor<1x128xi1, #blocked1> loc(#loc30)
    %63 = tt.broadcast %57 : tensor<256x1xi1, #blocked> -> tensor<256x128xi1, #blocked> loc(#loc30)
    %64 = triton_gpu.convert_layout %63 : tensor<256x128xi1, #blocked> -> tensor<256x128xi1, #blocked1> loc(#loc30)
    %65 = tt.broadcast %62 : tensor<1x128xi1, #blocked1> -> tensor<256x128xi1, #blocked1> loc(#loc30)
    %66 = arith.andi %64, %65 : tensor<256x128xi1, #blocked1> loc(#loc30)
    scf.for %arg10 = %c0_i32 to %c3_i32 step %c1_i32  : i32 {
      %67:3 = scf.for %arg11 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg12 = %cst_0, %arg13 = %20, %arg14 = %35) -> (tensor<256x128xf32, #blocked1>, tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<64x128x!tt.ptr<f16>, #blocked1>)  : i32 {
        %73 = triton_gpu.convert_layout %arg13 : tensor<256x64x!tt.ptr<f16>, #blocked2> -> tensor<256x64x!tt.ptr<f16>, #blocked2> loc(#loc35)
        %74 = tt.load %73 : tensor<256x64x!tt.ptr<f16>, #blocked2> loc(#loc35)
        %75 = triton_gpu.convert_layout %arg14 : tensor<64x128x!tt.ptr<f16>, #blocked1> -> tensor<64x128x!tt.ptr<f16>, #blocked6> loc(#loc36)
        %76 = tt.load %75 : tensor<64x128x!tt.ptr<f16>, #blocked6> loc(#loc36)
        %77 = triton_gpu.convert_layout %76 : tensor<64x128xf16, #blocked6> -> tensor<64x128xf16, #blocked1> loc(#loc36)
        %78 = triton_gpu.convert_layout %74 : tensor<256x64xf16, #blocked2> -> tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> loc(#loc35)
        %79 = triton_gpu.convert_layout %77 : tensor<64x128xf16, #blocked1> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> loc(#loc36)
        %80 = triton_gpu.convert_layout %arg12 : tensor<256x128xf32, #blocked1> -> tensor<256x128xf32, #blocked7> loc(#loc37)
        %81 = tt.dot %78, %79, %80 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> -> tensor<256x128xf32, #blocked7> loc(#loc38)
        %82 = triton_gpu.convert_layout %81 : tensor<256x128xf32, #blocked7> -> tensor<256x128xf32, #blocked1> loc(#loc39)
        %83 = tt.addptr %arg13, %cst_2 : tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<256x64xi32, #blocked2> loc(#loc40)
        %84 = tt.addptr %arg14, %cst_1 : tensor<64x128x!tt.ptr<f16>, #blocked1>, tensor<64x128xi32, #blocked1> loc(#loc41)
        scf.yield %82, %83, %84 : tensor<256x128xf32, #blocked1>, tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<64x128x!tt.ptr<f16>, #blocked1> loc(#loc39)
      } loc(#loc34)
      %68:2 = scf.while (%arg11 = %67#0, %arg12 = %36) : (tensor<256x128xf32, #blocked1>, i32) -> (tensor<256x128xf32, #blocked1>, i32) {
        %73 = arith.cmpi slt, %arg12, %arg6 : i32 loc(#loc43)
        scf.condition(%73) %arg11, %arg12 : tensor<256x128xf32, #blocked1>, i32 loc(#loc43)
      } do {
      ^bb0(%arg11: tensor<256x128xf32, #blocked1> loc(unknown), %arg12: i32 loc(unknown)):
        %73 = arith.muli %arg12, %c32768_i32 : i32 loc(#loc44)
        %74 = tt.addptr %arg3, %73 : !tt.ptr<f32>, i32 loc(#loc45)
        %75 = tt.splat %74 : !tt.ptr<f32> -> tensor<256x1x!tt.ptr<f32>, #blocked> loc(#loc46)
        %76 = tt.addptr %75, %40 : tensor<256x1x!tt.ptr<f32>, #blocked>, tensor<256x1xi32, #blocked> loc(#loc46)
        %77 = tt.broadcast %76 : tensor<256x1x!tt.ptr<f32>, #blocked> -> tensor<256x128x!tt.ptr<f32>, #blocked> loc(#loc24)
        %78 = triton_gpu.convert_layout %77 : tensor<256x128x!tt.ptr<f32>, #blocked> -> tensor<256x128x!tt.ptr<f32>, #blocked1> loc(#loc24)
        %79 = tt.addptr %78, %44 : tensor<256x128x!tt.ptr<f32>, #blocked1>, tensor<256x128xi32, #blocked1> loc(#loc24)
        %80 = triton_gpu.convert_layout %79 : tensor<256x128x!tt.ptr<f32>, #blocked1> -> tensor<256x128x!tt.ptr<f32>, #blocked1> loc(#loc47)
        %81 = tt.load %80 : tensor<256x128x!tt.ptr<f32>, #blocked1> loc(#loc47)
        %82 = arith.addf %arg11, %81 : tensor<256x128xf32, #blocked1> loc(#loc48)
        %83 = arith.addi %arg12, %c1_i32 : i32 loc(#loc49)
        scf.yield %82, %83 : tensor<256x128xf32, #blocked1>, i32 loc(#loc50)
      } loc(#loc42)
      %69 = arith.truncf %68#0 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> loc(#loc51)
      %70 = triton_gpu.convert_layout %52 : tensor<256x128x!tt.ptr<f16>, #blocked1> -> tensor<256x128x!tt.ptr<f16>, #blocked1> loc(#loc51)
      %71 = triton_gpu.convert_layout %69 : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #blocked1> loc(#loc51)
      %72 = triton_gpu.convert_layout %66 : tensor<256x128xi1, #blocked1> -> tensor<256x128xi1, #blocked1> loc(#loc51)
      tt.store %70, %71, %72 : tensor<256x128x!tt.ptr<f16>, #blocked1> loc(#loc51)
    } loc(#loc33)
    tt.return loc(#loc52)
  } loc(#loc)
} loc(#loc)

Fix 1

// -----// IR Dump Before TritonGPURemoveLayoutConversions (tritongpu-remove-layout-conversions) ('builtin.module' operation) //----- //
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked6 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked7 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#loc = loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0)
#loc1 = loc(unknown)
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func public @kernel(%arg0: !tt.ptr<f16> loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg1: !tt.ptr<f16> loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg2: !tt.ptr<f16> loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg3: !tt.ptr<f32> loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg4: i32 loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg5: i32 loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg6: i32 loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg7: i32 loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg8: i32 loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0), %arg9: i32 loc("/home/ravil/work/triton/triton/.vscode/streamk/./reproducible.py":8:0)) attributes {noinline = false} {
    %c32768_i32 = arith.constant 32768 : i32 loc(#loc1)
    %c8_i32 = arith.constant 8 : i32 loc(#loc1)
    %c3_i32 = arith.constant 3 : i32 loc(#loc1)
    %c0_i32 = arith.constant 0 : i32 loc(#loc1)
    %cst = arith.constant dense<128> : tensor<256x1xi32, #blocked> loc(#loc1)
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #blocked1> loc(#loc1)
    %cst_1 = arith.constant dense<64> : tensor<64x128xi32, #blocked1> loc(#loc1)
    %cst_2 = arith.constant dense<64> : tensor<256x64xi32, #blocked2> loc(#loc1)
    %cst_3 = arith.constant dense<128> : tensor<128xi32, #blocked3> loc(#loc1)
    %cst_4 = arith.constant dense<256> : tensor<256xi32, #blocked3> loc(#loc1)
    %c1_i32 = arith.constant 1 : i32 loc(#loc1)
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked3> loc(#loc3)
    %2 = arith.addi %1, %cst_4 : tensor<256xi32, #blocked3> loc(#loc4)
    %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3> loc(#loc5)
    %4 = arith.addi %3, %cst_3 : tensor<128xi32, #blocked3> loc(#loc6)
    %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3> loc(#loc7)
    %6 = triton_gpu.convert_layout %2 : tensor<256xi32, #blocked3> -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> loc(#loc8)
    %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1xi32, #blocked4> loc(#loc8)
    %8 = triton_gpu.convert_layout %7 : tensor<256x1xi32, #blocked4> -> tensor<256x1xi32, #blocked> loc(#loc9)
    %9 = tt.splat %arg7 : i32 -> tensor<256x1xi32, #blocked> loc(#loc9)
    %10 = arith.muli %8, %9 : tensor<256x1xi32, #blocked> loc(#loc9)
    %11 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked> loc(#loc10)
    %12 = tt.addptr %11, %10 : tensor<256x1x!tt.ptr<f16>, #blocked>, tensor<256x1xi32, #blocked> loc(#loc10)
    %13 = triton_gpu.convert_layout %5 : tensor<64xi32, #blocked3> -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> loc(#loc11)
    %14 = tt.expand_dims %13 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> -> tensor<1x64xi32, #blocked5> loc(#loc11)
    %15 = triton_gpu.convert_layout %14 : tensor<1x64xi32, #blocked5> -> tensor<1x64xi32, #blocked2> loc(#loc12)
    %16 = tt.broadcast %12 : tensor<256x1x!tt.ptr<f16>, #blocked> -> tensor<256x64x!tt.ptr<f16>, #blocked> loc(#loc12)
    %17 = triton_gpu.convert_layout %16 : tensor<256x64x!tt.ptr<f16>, #blocked> -> tensor<256x64x!tt.ptr<f16>, #blocked2> loc(#loc12)
    %18 = tt.broadcast %15 : tensor<1x64xi32, #blocked2> -> tensor<256x64xi32, #blocked2> loc(#loc12)
    %19 = tt.addptr %17, %18 : tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<256x64xi32, #blocked2> loc(#loc12)
    %20 = tt.addptr %19, %cst_2 : tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<256x64xi32, #blocked2> loc(#loc13)
    %21 = triton_gpu.convert_layout %5 : tensor<64xi32, #blocked3> -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> loc(#loc14)
    %22 = tt.expand_dims %21 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1xi32, #blocked4> loc(#loc14)
    %23 = triton_gpu.convert_layout %22 : tensor<64x1xi32, #blocked4> -> tensor<64x1xi32, #blocked> loc(#loc15)
    %24 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked> loc(#loc15)
    %25 = tt.addptr %24, %23 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked> loc(#loc15)
    %26 = triton_gpu.convert_layout %4 : tensor<128xi32, #blocked3> -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> loc(#loc16)
    %27 = tt.expand_dims %26 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> -> tensor<1x128xi32, #blocked5> loc(#loc16)
    %28 = triton_gpu.convert_layout %27 : tensor<1x128xi32, #blocked5> -> tensor<1x128xi32, #blocked1> loc(#loc17)
    %29 = tt.splat %arg8 : i32 -> tensor<1x128xi32, #blocked1> loc(#loc17)
    %30 = arith.muli %28, %29 : tensor<1x128xi32, #blocked1> loc(#loc17)
    %31 = tt.broadcast %25 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked> loc(#loc18)
    %32 = triton_gpu.convert_layout %31 : tensor<64x128x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked1> loc(#loc18)
    %33 = tt.broadcast %30 : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1> loc(#loc18)
    %34 = tt.addptr %32, %33 : tensor<64x128x!tt.ptr<f16>, #blocked1>, tensor<64x128xi32, #blocked1> loc(#loc18)
    %35 = tt.addptr %34, %cst_1 : tensor<64x128x!tt.ptr<f16>, #blocked1>, tensor<64x128xi32, #blocked1> loc(#loc19)
    %36 = arith.addi %0, %c1_i32 : i32 loc(#loc20)
    %37 = triton_gpu.convert_layout %1 : tensor<256xi32, #blocked3> -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> loc(#loc21)
    %38 = tt.expand_dims %37 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1xi32, #blocked4> loc(#loc21)
    %39 = triton_gpu.convert_layout %38 : tensor<256x1xi32, #blocked4> -> tensor<256x1xi32, #blocked> loc(#loc22)
    %40 = arith.muli %39, %cst : tensor<256x1xi32, #blocked> loc(#loc22)
    %41 = triton_gpu.convert_layout %3 : tensor<128xi32, #blocked3> -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> loc(#loc23)
    %42 = tt.expand_dims %41 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> -> tensor<1x128xi32, #blocked5> loc(#loc23)
    %43 = triton_gpu.convert_layout %42 : tensor<1x128xi32, #blocked5> -> tensor<1x128xi32, #blocked1> loc(#loc24)
    %44 = tt.broadcast %43 : tensor<1x128xi32, #blocked1> -> tensor<256x128xi32, #blocked1> loc(#loc24)
    %45 = tt.splat %arg9 : i32 -> tensor<256x1xi32, #blocked> loc(#loc25)
    %46 = arith.muli %8, %45 : tensor<256x1xi32, #blocked> loc(#loc25)
    %47 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked> loc(#loc26)
    %48 = tt.addptr %47, %46 : tensor<256x1x!tt.ptr<f16>, #blocked>, tensor<256x1xi32, #blocked> loc(#loc26)
    %49 = tt.broadcast %48 : tensor<256x1x!tt.ptr<f16>, #blocked> -> tensor<256x128x!tt.ptr<f16>, #blocked> loc(#loc27)
    %50 = triton_gpu.convert_layout %49 : tensor<256x128x!tt.ptr<f16>, #blocked> -> tensor<256x128x!tt.ptr<f16>, #blocked1> loc(#loc27)
    %51 = tt.broadcast %28 : tensor<1x128xi32, #blocked1> -> tensor<256x128xi32, #blocked1> loc(#loc27)
    %52 = tt.addptr %50, %51 : tensor<256x128x!tt.ptr<f16>, #blocked1>, tensor<256x128xi32, #blocked1> loc(#loc27)
    %53 = tt.splat %arg4 : i32 -> tensor<256xi32, #blocked3> loc(#loc28)
    %54 = arith.cmpi slt, %2, %53 : tensor<256xi32, #blocked3> loc(#loc28)
    %55 = triton_gpu.convert_layout %54 : tensor<256xi1, #blocked3> -> tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> loc(#loc29)
    %56 = tt.expand_dims %55 {axis = 1 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1xi1, #blocked4> loc(#loc29)
    %57 = triton_gpu.convert_layout %56 : tensor<256x1xi1, #blocked4> -> tensor<256x1xi1, #blocked> loc(#loc30)
    %58 = tt.splat %arg5 : i32 -> tensor<128xi32, #blocked3> loc(#loc31)
    %59 = arith.cmpi slt, %4, %58 : tensor<128xi32, #blocked3> loc(#loc31)
    %60 = triton_gpu.convert_layout %59 : tensor<128xi1, #blocked3> -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> loc(#loc32)
    %61 = tt.expand_dims %60 {axis = 0 : i32} : tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> -> tensor<1x128xi1, #blocked5> loc(#loc32)
    %62 = triton_gpu.convert_layout %61 : tensor<1x128xi1, #blocked5> -> tensor<1x128xi1, #blocked1> loc(#loc30)
    %63 = tt.broadcast %57 : tensor<256x1xi1, #blocked> -> tensor<256x128xi1, #blocked> loc(#loc30)
    %64 = triton_gpu.convert_layout %63 : tensor<256x128xi1, #blocked> -> tensor<256x128xi1, #blocked1> loc(#loc30)
    %65 = tt.broadcast %62 : tensor<1x128xi1, #blocked1> -> tensor<256x128xi1, #blocked1> loc(#loc30)
    %66 = arith.andi %64, %65 : tensor<256x128xi1, #blocked1> loc(#loc30)
    %67 = scf.for %arg10 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg11 = %cst_0) -> (tensor<256x128xf32, #blocked1>)  : i32 {
      %68 = arith.mulf %arg11, %cst_0 : tensor<256x128xf32, #blocked1> loc(#loc34)
      %69:3 = scf.for %arg12 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg13 = %68, %arg14 = %20, %arg15 = %35) -> (tensor<256x128xf32, #blocked1>, tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<64x128x!tt.ptr<f16>, #blocked1>)  : i32 {
        %75 = triton_gpu.convert_layout %arg14 : tensor<256x64x!tt.ptr<f16>, #blocked2> -> tensor<256x64x!tt.ptr<f16>, #blocked2> loc(#loc36)
        %76 = tt.load %75 : tensor<256x64x!tt.ptr<f16>, #blocked2> loc(#loc36)
        %77 = triton_gpu.convert_layout %arg15 : tensor<64x128x!tt.ptr<f16>, #blocked1> -> tensor<64x128x!tt.ptr<f16>, #blocked6> loc(#loc37)
        %78 = tt.load %77 : tensor<64x128x!tt.ptr<f16>, #blocked6> loc(#loc37)
        %79 = triton_gpu.convert_layout %78 : tensor<64x128xf16, #blocked6> -> tensor<64x128xf16, #blocked1> loc(#loc37)
        %80 = triton_gpu.convert_layout %76 : tensor<256x64xf16, #blocked2> -> tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> loc(#loc36)
        %81 = triton_gpu.convert_layout %79 : tensor<64x128xf16, #blocked1> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> loc(#loc37)
        %82 = triton_gpu.convert_layout %arg13 : tensor<256x128xf32, #blocked1> -> tensor<256x128xf32, #blocked7> loc(#loc34)
        %83 = tt.dot %80, %81, %82 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> -> tensor<256x128xf32, #blocked7> loc(#loc38)
        %84 = triton_gpu.convert_layout %83 : tensor<256x128xf32, #blocked7> -> tensor<256x128xf32, #blocked1> loc(#loc39)
        %85 = tt.addptr %arg14, %cst_2 : tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<256x64xi32, #blocked2> loc(#loc40)
        %86 = tt.addptr %arg15, %cst_1 : tensor<64x128x!tt.ptr<f16>, #blocked1>, tensor<64x128xi32, #blocked1> loc(#loc41)
        scf.yield %84, %85, %86 : tensor<256x128xf32, #blocked1>, tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<64x128x!tt.ptr<f16>, #blocked1> loc(#loc39)
      } loc(#loc35)
      %70:2 = scf.while (%arg12 = %69#0, %arg13 = %36) : (tensor<256x128xf32, #blocked1>, i32) -> (tensor<256x128xf32, #blocked1>, i32) {
        %75 = arith.cmpi slt, %arg13, %arg6 : i32 loc(#loc43)
        scf.condition(%75) %arg12, %arg13 : tensor<256x128xf32, #blocked1>, i32 loc(#loc43)
      } do {
      ^bb0(%arg12: tensor<256x128xf32, #blocked1> loc(unknown), %arg13: i32 loc(unknown)):
        %75 = arith.muli %arg13, %c32768_i32 : i32 loc(#loc44)
        %76 = tt.addptr %arg3, %75 : !tt.ptr<f32>, i32 loc(#loc45)
        %77 = tt.splat %76 : !tt.ptr<f32> -> tensor<256x1x!tt.ptr<f32>, #blocked> loc(#loc46)
        %78 = tt.addptr %77, %40 : tensor<256x1x!tt.ptr<f32>, #blocked>, tensor<256x1xi32, #blocked> loc(#loc46)
        %79 = tt.broadcast %78 : tensor<256x1x!tt.ptr<f32>, #blocked> -> tensor<256x128x!tt.ptr<f32>, #blocked> loc(#loc24)
        %80 = triton_gpu.convert_layout %79 : tensor<256x128x!tt.ptr<f32>, #blocked> -> tensor<256x128x!tt.ptr<f32>, #blocked1> loc(#loc24)
        %81 = tt.addptr %80, %44 : tensor<256x128x!tt.ptr<f32>, #blocked1>, tensor<256x128xi32, #blocked1> loc(#loc24)
        %82 = triton_gpu.convert_layout %81 : tensor<256x128x!tt.ptr<f32>, #blocked1> -> tensor<256x128x!tt.ptr<f32>, #blocked1> loc(#loc47)
        %83 = tt.load %82 : tensor<256x128x!tt.ptr<f32>, #blocked1> loc(#loc47)
        %84 = arith.addf %arg12, %83 : tensor<256x128xf32, #blocked1> loc(#loc48)
        %85 = arith.addi %arg13, %c1_i32 : i32 loc(#loc49)
        scf.yield %84, %85 : tensor<256x128xf32, #blocked1>, i32 loc(#loc50)
      } loc(#loc42)
      %71 = arith.truncf %70#0 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> loc(#loc51)
      %72 = triton_gpu.convert_layout %52 : tensor<256x128x!tt.ptr<f16>, #blocked1> -> tensor<256x128x!tt.ptr<f16>, #blocked1> loc(#loc51)
      %73 = triton_gpu.convert_layout %71 : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #blocked1> loc(#loc51)
      %74 = triton_gpu.convert_layout %66 : tensor<256x128xi1, #blocked1> -> tensor<256x128xi1, #blocked1> loc(#loc51)
      tt.store %72, %73, %74 : tensor<256x128x!tt.ptr<f16>, #blocked1> loc(#loc51)
      scf.yield %70#0 : tensor<256x128xf32, #blocked1> loc(#loc52)
    } loc(#loc33)
    tt.return loc(#loc53)
  } loc(#loc)
} loc(#loc)

Diff: Reproducer vs Fix 1

92,110c92,111
<     scf.for %arg10 = %c0_i32 to %c3_i32 step %c1_i32  : i32 {
<       %67:3 = scf.for %arg11 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg12 = %cst_0, %arg13 = %20, %arg14 = %35) -> (tensor<256x128xf32, #blocked1>, tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<64x128x!tt.ptr<f16>, #blocked1>)  : i32 {
<         %73 = triton_gpu.convert_layout %arg13 : tensor<256x64x!tt.ptr<f16>, #blocked2> -> tensor<256x64x!tt.ptr<f16>, #blocked2> loc(#loc35)
<         %74 = tt.load %73 : tensor<256x64x!tt.ptr<f16>, #blocked2> loc(#loc35)
<         %75 = triton_gpu.convert_layout %arg14 : tensor<64x128x!tt.ptr<f16>, #blocked1> -> tensor<64x128x!tt.ptr<f16>, #blocked6> loc(#loc36)
<         %76 = tt.load %75 : tensor<64x128x!tt.ptr<f16>, #blocked6> loc(#loc36)
<         %77 = triton_gpu.convert_layout %76 : tensor<64x128xf16, #blocked6> -> tensor<64x128xf16, #blocked1> loc(#loc36)
<         %78 = triton_gpu.convert_layout %74 : tensor<256x64xf16, #blocked2> -> tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> loc(#loc35)
<         %79 = triton_gpu.convert_layout %77 : tensor<64x128xf16, #blocked1> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> loc(#loc36)
<         %80 = triton_gpu.convert_layout %arg12 : tensor<256x128xf32, #blocked1> -> tensor<256x128xf32, #blocked7> loc(#loc37)
<         %81 = tt.dot %78, %79, %80 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> -> tensor<256x128xf32, #blocked7> loc(#loc38)
<         %82 = triton_gpu.convert_layout %81 : tensor<256x128xf32, #blocked7> -> tensor<256x128xf32, #blocked1> loc(#loc39)
<         %83 = tt.addptr %arg13, %cst_2 : tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<256x64xi32, #blocked2> loc(#loc40)
<         %84 = tt.addptr %arg14, %cst_1 : tensor<64x128x!tt.ptr<f16>, #blocked1>, tensor<64x128xi32, #blocked1> loc(#loc41)
<         scf.yield %82, %83, %84 : tensor<256x128xf32, #blocked1>, tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<64x128x!tt.ptr<f16>, #blocked1> loc(#loc39)
<       } loc(#loc34)
<       %68:2 = scf.while (%arg11 = %67#0, %arg12 = %36) : (tensor<256x128xf32, #blocked1>, i32) -> (tensor<256x128xf32, #blocked1>, i32) {
<         %73 = arith.cmpi slt, %arg12, %arg6 : i32 loc(#loc43)
<         scf.condition(%73) %arg11, %arg12 : tensor<256x128xf32, #blocked1>, i32 loc(#loc43)
---
>     %67 = scf.for %arg10 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg11 = %cst_0) -> (tensor<256x128xf32, #blocked1>)  : i32 {
>       %68 = arith.mulf %arg11, %cst_0 : tensor<256x128xf32, #blocked1> loc(#loc34)
>       %69:3 = scf.for %arg12 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg13 = %68, %arg14 = %20, %arg15 = %35) -> (tensor<256x128xf32, #blocked1>, tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<64x128x!tt.ptr<f16>, #blocked1>)  : i32 {
>         %75 = triton_gpu.convert_layout %arg14 : tensor<256x64x!tt.ptr<f16>, #blocked2> -> tensor<256x64x!tt.ptr<f16>, #blocked2> loc(#loc36)
>         %76 = tt.load %75 : tensor<256x64x!tt.ptr<f16>, #blocked2> loc(#loc36)
>         %77 = triton_gpu.convert_layout %arg15 : tensor<64x128x!tt.ptr<f16>, #blocked1> -> tensor<64x128x!tt.ptr<f16>, #blocked6> loc(#loc37)
>         %78 = tt.load %77 : tensor<64x128x!tt.ptr<f16>, #blocked6> loc(#loc37)
>         %79 = triton_gpu.convert_layout %78 : tensor<64x128xf16, #blocked6> -> tensor<64x128xf16, #blocked1> loc(#loc37)
>         %80 = triton_gpu.convert_layout %76 : tensor<256x64xf16, #blocked2> -> tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> loc(#loc36)
>         %81 = triton_gpu.convert_layout %79 : tensor<64x128xf16, #blocked1> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> loc(#loc37)
>         %82 = triton_gpu.convert_layout %arg13 : tensor<256x128xf32, #blocked1> -> tensor<256x128xf32, #blocked7> loc(#loc34)
>         %83 = tt.dot %80, %81, %82 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> -> tensor<256x128xf32, #blocked7> loc(#loc38)
>         %84 = triton_gpu.convert_layout %83 : tensor<256x128xf32, #blocked7> -> tensor<256x128xf32, #blocked1> loc(#loc39)
>         %85 = tt.addptr %arg14, %cst_2 : tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<256x64xi32, #blocked2> loc(#loc40)
>         %86 = tt.addptr %arg15, %cst_1 : tensor<64x128x!tt.ptr<f16>, #blocked1>, tensor<64x128xi32, #blocked1> loc(#loc41)
>         scf.yield %84, %85, %86 : tensor<256x128xf32, #blocked1>, tensor<256x64x!tt.ptr<f16>, #blocked2>, tensor<64x128x!tt.ptr<f16>, #blocked1> loc(#loc39)
>       } loc(#loc35)
>       %70:2 = scf.while (%arg12 = %69#0, %arg13 = %36) : (tensor<256x128xf32, #blocked1>, i32) -> (tensor<256x128xf32, #blocked1>, i32) {
>         %75 = arith.cmpi slt, %arg13, %arg6 : i32 loc(#loc43)
>         scf.condition(%75) %arg12, %arg13 : tensor<256x128xf32, #blocked1>, i32 loc(#loc43)
112,124c113,125
<       ^bb0(%arg11: tensor<256x128xf32, #blocked1> loc(unknown), %arg12: i32 loc(unknown)):
<         %73 = arith.muli %arg12, %c32768_i32 : i32 loc(#loc44)
<         %74 = tt.addptr %arg3, %73 : !tt.ptr<f32>, i32 loc(#loc45)
<         %75 = tt.splat %74 : !tt.ptr<f32> -> tensor<256x1x!tt.ptr<f32>, #blocked> loc(#loc46)
<         %76 = tt.addptr %75, %40 : tensor<256x1x!tt.ptr<f32>, #blocked>, tensor<256x1xi32, #blocked> loc(#loc46)
<         %77 = tt.broadcast %76 : tensor<256x1x!tt.ptr<f32>, #blocked> -> tensor<256x128x!tt.ptr<f32>, #blocked> loc(#loc24)
<         %78 = triton_gpu.convert_layout %77 : tensor<256x128x!tt.ptr<f32>, #blocked> -> tensor<256x128x!tt.ptr<f32>, #blocked1> loc(#loc24)
<         %79 = tt.addptr %78, %44 : tensor<256x128x!tt.ptr<f32>, #blocked1>, tensor<256x128xi32, #blocked1> loc(#loc24)
<         %80 = triton_gpu.convert_layout %79 : tensor<256x128x!tt.ptr<f32>, #blocked1> -> tensor<256x128x!tt.ptr<f32>, #blocked1> loc(#loc47)
<         %81 = tt.load %80 : tensor<256x128x!tt.ptr<f32>, #blocked1> loc(#loc47)
<         %82 = arith.addf %arg11, %81 : tensor<256x128xf32, #blocked1> loc(#loc48)
<         %83 = arith.addi %arg12, %c1_i32 : i32 loc(#loc49)
<         scf.yield %82, %83 : tensor<256x128xf32, #blocked1>, i32 loc(#loc50)
---
>       ^bb0(%arg12: tensor<256x128xf32, #blocked1> loc(unknown), %arg13: i32 loc(unknown)):
>         %75 = arith.muli %arg13, %c32768_i32 : i32 loc(#loc44)
>         %76 = tt.addptr %arg3, %75 : !tt.ptr<f32>, i32 loc(#loc45)
>         %77 = tt.splat %76 : !tt.ptr<f32> -> tensor<256x1x!tt.ptr<f32>, #blocked> loc(#loc46)
>         %78 = tt.addptr %77, %40 : tensor<256x1x!tt.ptr<f32>, #blocked>, tensor<256x1xi32, #blocked> loc(#loc46)
>         %79 = tt.broadcast %78 : tensor<256x1x!tt.ptr<f32>, #blocked> -> tensor<256x128x!tt.ptr<f32>, #blocked> loc(#loc24)
>         %80 = triton_gpu.convert_layout %79 : tensor<256x128x!tt.ptr<f32>, #blocked> -> tensor<256x128x!tt.ptr<f32>, #blocked1> loc(#loc24)
>         %81 = tt.addptr %80, %44 : tensor<256x128x!tt.ptr<f32>, #blocked1>, tensor<256x128xi32, #blocked1> loc(#loc24)
>         %82 = triton_gpu.convert_layout %81 : tensor<256x128x!tt.ptr<f32>, #blocked1> -> tensor<256x128x!tt.ptr<f32>, #blocked1> loc(#loc47)
>         %83 = tt.load %82 : tensor<256x128x!tt.ptr<f32>, #blocked1> loc(#loc47)
>         %84 = arith.addf %arg12, %83 : tensor<256x128xf32, #blocked1> loc(#loc48)
>         %85 = arith.addi %arg13, %c1_i32 : i32 loc(#loc49)
>         scf.yield %84, %85 : tensor<256x128xf32, #blocked1>, i32 loc(#loc50)
126,130c127,132
<       %69 = arith.truncf %68#0 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> loc(#loc51)
<       %70 = triton_gpu.convert_layout %52 : tensor<256x128x!tt.ptr<f16>, #blocked1> -> tensor<256x128x!tt.ptr<f16>, #blocked1> loc(#loc51)
<       %71 = triton_gpu.convert_layout %69 : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #blocked1> loc(#loc51)
<       %72 = triton_gpu.convert_layout %66 : tensor<256x128xi1, #blocked1> -> tensor<256x128xi1, #blocked1> loc(#loc51)
<       tt.store %70, %71, %72 : tensor<256x128x!tt.ptr<f16>, #blocked1> loc(#loc51)
---
>       %71 = arith.truncf %70#0 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> loc(#loc51)
>       %72 = triton_gpu.convert_layout %52 : tensor<256x128x!tt.ptr<f16>, #blocked1> -> tensor<256x128x!tt.ptr<f16>, #blocked1> loc(#loc51)
>       %73 = triton_gpu.convert_layout %71 : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #blocked1> loc(#loc51)
>       %74 = triton_gpu.convert_layout %66 : tensor<256x128xi1, #blocked1> -> tensor<256x128xi1, #blocked1> loc(#loc51)
>       tt.store %72, %73, %74 : tensor<256x128x!tt.ptr<f16>, #blocked1> loc(#loc51)
>       scf.yield %70#0 : tensor<256x128xf32, #blocked1> loc(#loc52)
132c134
<     tt.return loc(#loc52)
---
>     tt.return loc(#loc53)
ravil-mobile commented 1 month ago

`LayoutPropagation` dump after `layoutPropagation.propagateLayout()`

Reproducer

Value: %74 = tt.load %73 : tensor<256x64x!tt.ptr<f16>, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: %76 = tt.load %75 : tensor<64x128x!tt.ptr<f16>, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
--
Value: %81 = tt.dot %78, %79, %80 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>}>> -> tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: %81 = tt.load %80 : tensor<256x128x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
--

Fix 1


Value: %76 = tt.load %75 : tensor<256x64x!tt.ptr<f16>, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: %78 = tt.load %77 : tensor<64x128x!tt.ptr<f16>, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
--
Value: %83 = tt.dot %80, %81, %82 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>}>> -> tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: %83 = tt.load %82 : tensor<256x128x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
--
Value: %84 = arith.addf %arg12, %83 : tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: <block argument> of type 'tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>>' at index: 0 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: %69:3 = scf.for %arg12 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg13 = %68, %arg14 = %20, %arg15 = %35) -> (tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>>, tensor<256x64x!tt.ptr<f16>, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>>, tensor<64x128x!tt.ptr<f16>, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>>)  : i32 {...} 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: <block argument> of type 'tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>>' at index: 0 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: %70:2 = scf.while (%arg12 = %69#0, %arg13 = %36) : (tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>>, i32) -> (tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>>, i32) {...} do {...} 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: %67 = scf.for %arg10 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg11 = %cst_0) -> (tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>>)  : i32 {...} 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: <block argument> of type 'tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>>' at index: 1 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: %71 = arith.truncf %70#0 : tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>> to tensor<256x128xf16, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: %73 = triton_gpu.convert_layout %71 : tensor<256x128xf16, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>> -> tensor<256x128xf16, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: %68 = arith.mulf %arg11, %cst_0 : tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: <block argument> of type 'tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>>' at index: 1 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: %82 = triton_gpu.convert_layout %arg13 : tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>> -> tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: %84 = triton_gpu.convert_layout %83 : tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>> -> tensor<256x128xf32, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
--
Value: %79 = triton_gpu.convert_layout %78 : tensor<64x128xf16, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>> -> tensor<64x128xf16, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
--
Value: %81 = triton_gpu.convert_layout %79 : tensor<64x128xf16, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
--
Value: %80 = triton_gpu.convert_layout %76 : tensor<256x64xf16, #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>> -> tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>}>> 
 encoding:
#triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
--