intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

Hand rewrite LLVM IR for GEMM using SIMT instructions #526

Closed whitneywhtsang closed 7 months ago

whitneywhtsang commented 7 months ago

Create an LLVM IR file that uses DPAS, 2D block read/write, and prefetch instructions to get an initial performance estimate for the SIMT path.

One could use the compiler to generate the LLVM IR file, but skip implementing the analysis portion and hard-code the transformation portion as needed.
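
For reference, below is a minimal Triton-level sketch (assumed argument names and tile sizes, not the exact kernel from this experiment) of the block-pointer GEMM being hand-tuned here: tl.load/tl.store on block pointers are the ops expected to lower to 2D block reads/writes, and tl.dot to DPAS.

```python
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Block pointers describe a 2D tile; loads/stores through them are the
    # candidates for 2D block read/write messages on the SIMT path.
    a_tile_ptr = tl.make_block_ptr(a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                   offsets=(pid_m * BLOCK_M, 0),
                                   block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_tile_ptr = tl.make_block_ptr(b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                   offsets=(0, pid_n * BLOCK_N),
                                   block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_tile_ptr)              # expected to lower to 2D block reads
        b = tl.load(b_tile_ptr)
        acc += tl.dot(a, b)                  # expected to lower to DPAS
        a_tile_ptr = tl.advance(a_tile_ptr, (0, BLOCK_K))
        b_tile_ptr = tl.advance(b_tile_ptr, (BLOCK_K, 0))
    c_tile_ptr = tl.make_block_ptr(c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                   offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                   block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tl.store(c_tile_ptr, acc)                # assumes C is float32; convert acc otherwise
```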

chengjunlu commented 7 months ago

Generating a Triton GEMM kernel with 2D loads and DPAS is finished. The code branches:
GENX LLVM: https://github.com/chengjunlu/llvm/tree/chengjun/genx
Triton: https://github.com/intel/intel-xpu-backend-for-triton/tree/chengjun/llvm-target-dpas

The performance is not good.

matmul-performance:
      M     N     K  TRANS_A  TRANS_B      oneDNN     Triton
0  2048   512   512    False    False    3.076229   2.619895
1  2048  1024  1024    False    False   11.766426   7.293279
2  2048  2048  2048    False    False   38.950051  12.688430
3  2048  4096  4096    False    False  112.458204  15.564876

I will add the prefetch ops and test it again.

etiotto commented 7 months ago

@chengjunlu have you used 2D stores or only 2D loads in this experiment?

chengjunlu commented 7 months ago

@chengjunlu have you used 2D stores or only 2D loads in this experiment?

Only 2D loads for now. I will continue to enable the 2D store and 2D prefetching.

chengjunlu commented 7 months ago

I have enabled the 2D prefetch op, but the final Gen binary does not contain the 2D prefetch ops as expected.

The prefetch GenISA intrinsics are present in the optimized LLVM IR that the scalar compiler uses to generate the code.

// Prefetching in the pre-epilogue.
  %50 = ptrtoint half addrspace(1)* %43 to i64, !dbg !392
  call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %50, i32 %45, i32 %46, i32 1023, i32 %48, i32 %49, i32 32, i32 3, i32 15, i32 1, i1 false, i1 false, i32 4) #4, !dbg !392
  %51 = sext i32 %26 to i64, !dbg !393
  %52 = getelementptr half, half addrspace(1)* %1, i64 %51, !dbg !393
  %53 = shl i32 %4, 1, !dbg !393
  %54 = add i32 %53, -1, !dbg !393
  %55 = add i32 %5, -1, !dbg !393
  %56 = ptrtoint half addrspace(1)* %52 to i64, !dbg !393
  call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %56, i32 %54, i32 %55, i32 1023, i32 %48, i32 %49, i32 32, i32 3, i32 15, i32 1, i1 false, i1 false, i32 4) #4, !dbg !393
  %57 = icmp sgt i32 %5, 0, !dbg !394
  br i1 %57, label %.lr.ph, label %._crit_edge38, !dbg !394

._crit_edge38:                                    ; preds = %cond-add-join41
  %.pre = and i32 %27, 56, !dbg !395
  br label %121, !dbg !394

...

// Prefetching in the loop body.
 %79 = getelementptr half, half addrspace(1)* %43, i64 %78, !dbg !392
  %80 = ptrtoint half addrspace(1)* %79 to i64, !dbg !392
  call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %80, i32 %45, i32 %46, i32 1023, i32 %48, i32 %49, i32 32, i32 3, i32 15, i32 1, i1 false, i1 false, i32 4) #4, !dbg !392
  %81 = shl i32 %77, 9, !dbg !393
  %82 = sext i32 %81 to i64, !dbg !393
  %83 = getelementptr half, half addrspace(1)* %52, i64 %82, !dbg !393
  %84 = ptrtoint half addrspace(1)* %83 to i64, !dbg !393
  call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %84, i32 %54, i32 %55, i32 1023, i32 %48, i32 %49, i32 32, i32 3, i32 15, i32 1, i1 false, i1 false, i32 4) #4, !dbg !393
  %85 = call <8 x i16> @llvm.genx.GenISA.LSC2DBlockRead.v8i16(i64 %60, i32 %45, i32 %46, i32 1023, i32 %66, i32 %59, i32 16, i32 16, i32 8, i32 1, i1 false, i1 false, i32 0) #4, !dbg !392
  %86 = or i32 %66, 16, !dbg !392

There are no prefetch ops in the final Gen assembly between the first instruction of the tt.load and the first 2D load.

// Line 86:  b = tl.load(b_tile_ptr)
(W)     mov (1|M0)               r4.0<1>:d     56:w                               {Compacted}        //  ALU pipe: int; $191
(W)     mov (1|M0)               r6.0<1>:d     48:w                               {Compacted}        //  ALU pipe: int; $193

// Line 85:  a = tl.load(a_tile_ptr)
(W)     shl (1|M0)               r1.8<1>:d     r5.2<0;1,0>:d     1:w                                 //  ALU pipe: int; $184

// Line 86:  b = tl.load(b_tile_ptr)
(W)     shl (1|M0)               r2.0<1>:d     r5.1<0;1,0>:d     1:w               {Compacted}       //  ALU pipe: int; $188

// Line 84:  for k in range(0, K, BLOCK_SIZE_K):
        mov (16|M0)              r94.0<1>:f    0x0:f                               {Compacted}       //  ALU pipe: float; $198
        mov (16|M0)              r116.0<1>:f   0x0:f                               {Compacted}       //  ALU pipe: float; $199
        mov (16|M0)              r95.0<1>:f    0x0:f                               {Compacted}       //  ALU pipe: float; $200
        mov (16|M0)              r115.0<1>:f   0x0:f                               {Compacted}       //  ALU pipe: float; $201
        mov (16|M0)              r96.0<1>:f    0x0:f                               {Compacted}       //  ALU pipe: float; $202
        mov (16|M0)              r114.0<1>:f   0x0:f                               {Compacted}       //  ALU pipe: float; $203
        mov (16|M0)              r97.0<1>:f    0x0:f                               {Compacted}       //  ALU pipe: float; $204
        mov (16|M0)              r113.0<1>:f   0x0:f                               {Compacted}       //  ALU pipe: float; $205

// Line 86:  b = tl.load(b_tile_ptr)
        bfn.(s0&s1|s2) (16|M0)   r93.0<1>:ud   r124.0<1;0>:ud    r4.0<0;0>:ud      r125.1<0>:ud     {I@4} //  ALU pipe: int; $192
        bfn.(s0&s1|s2) (16|M0)   r117.0<1>:ud  r51.0<1;0>:ud     r6.0<0;0>:ud      r125.0<0>:ud     {I@4} //  ALU pipe: int; $194

// Line 85:  a = tl.load(a_tile_ptr)
(W)     add (1|M0)               r126.9<1>:d   r5.0<0;1,0>:d     -1:w                                //  ALU pipe: int; $186

// Line 86:  b = tl.load(b_tile_ptr)
(W)     add (1|M0)               r50.13<1>:d   r5.2<0;1,0>:d     -1:w                                //  ALU pipe: int; $190

// Line 84:  for k in range(0, K, BLOCK_SIZE_K):
(W)     mov (2|M0)               r50.8<1>:d    0:w                               {Compacted}         //  ALU pipe: int; $196
(W)     mov (1|M0)               r50.4<1>:d    0:w                               {Compacted}         //  ALU pipe: int; $206

// Line 85:  a = tl.load(a_tile_ptr)
(W)     add (1|M0)               r126.8<1>:d   r1.8<0;1,0>:d     -1:w               {I@7}            //  ALU pipe: int; $185

// Line 86:  b = tl.load(b_tile_ptr)
(W)     add (1|M0)               r50.12<1>:d   r2.0<0;1,0>:d     -1:w               {I@7}            //  ALU pipe: int; $189
// B021: Preds:{B022, B020},  Succs:{B022, B023}
_0_047:

// Line 85:  a = tl.load(a_tile_ptr)
(W)     mov (1|M0)               r2.0<1>:uq    r4.5<0;1,0>:q                                         //  ALU pipe: int; $209
(W)     mov (2|M0)               r2.2<1>:ud    r126.8<1;1,0>:d                  {I@3}                //  ALU pipe: int; $209
(W)     mov (1|M0)               r2.4<1>:ud    1023:w                                                //  ALU pipe: int; $209
(W)     mov (1|M0)               r2.5<1>:ud    r50.9<0;1,0>:d                                        //  ALU pipe: int; $209
(W)     mov (1|M0)               r2.6<1>:ud    r93.0<0;1,0>:d                                        //  ALU pipe: int; $209
(W)     mov (1|M0)               r2.7<1>:ud    0x70F:uw                                              //  ALU pipe: int; $209

// Line 86:  b = tl.load(b_tile_ptr)
(W)     mov (1|M0)               r29.0<1>:uq   r4.6<0;1,0>:q                                         //  ALU pipe: int; $217
(W)     mov (2|M0)               r29.2<1>:ud   r50.12<1;1,0>:d                  {I@7}                //  ALU pipe: int; $217
(W)     mov (1|M0)               r29.4<1>:ud   1023:w                                                //  ALU pipe: int; $217
(W)     mov (1|M0)               r29.5<1>:ud   r117.0<0;1,0>:d                                       //  ALU pipe: int; $217
(W)     mov (1|M0)               r29.6<1>:ud   r50.8<0;1,0>:d                                        //  ALU pipe: int; $217
(W)     mov (1|M0)               r29.7<1>:ud   0xF0F:uw                                              //  ALU pipe: int; $217

// Line 85:  a = tl.load(a_tile_ptr)
        load_block2d.ugm.d16.a64 (16|M0)  r7:4  [r2:1]             {I@7,$4} // ex_desc:0x0; desc:0x2400203 // $209

I will debug the issue.

The performance doesn't change, because the prefetch ops do not appear in the final binary.

matmul-performance:
      M     N     K  TRANS_A  TRANS_B      oneDNN     Triton
0  2048   512   512    False    False    3.024580   2.580859
1  2048  1024  1024    False    False   11.607216   7.217307
2  2048  2048  2048    False    False   38.595390  12.376777
3  2048  4096  4096    False    False  111.028650  15.548922
chengjunlu commented 7 months ago

The attributes of the prefetch GenISA intrinsic:

call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %50, i32 %45, i32 %46, i32 1023, i32 %48, i32 %49, i32 32, i32 3, i32 15, i32 1, i1 false, i1 false, i32 4) #4, !dbg !392
attributes #4 = { nounwind }

The GenISA intrinsic is created in the MLIR LLVM dialect instead of in LLVM IR. I am not sure whether this could cause the problem.

chengjunlu commented 7 months ago

The prefetch ops are generated as expected after updating IGC to the new version.

load_block2d.ugm.d32.a64.ca.ca (1|M0)  null:0 [r11:1]      {I@2,$4} // ex_desc:0x0; desc:0x2080403 // $180

But the new IGC has a correctness issue with the Triton kernel: the kernel doesn't return the correct GEMM result.

The performance doesn't make sense with the new driver (oneDNN performance decreases a lot).

matmul-performance:
      M     N     K  TRANS_A  TRANS_B     oneDNN    Triton
0  2048   512   512    False    False   0.285490  0.323128
1  2048  1024  1024    False    False   1.337173  1.167114
2  2048  2048  2048    False    False   5.218729  3.638170
3  2048  4096  4096    False    False  19.639573  8.156155
chengjunlu commented 7 months ago

After reviewing the disassembled code, there are some points that could be improved for performance:

  1. Use the 2D store instead of the scattered store to remove the ConvertLayout needed to align the store value with the tensor of pointers (see the sketch after this list).
  2. Use the packed type for the Dot operands layout encoding in TypeConvert to remove the redundant pack/unpack register movements.
  3. The send message payload of the load/prefetch ops seems to be composed repeatedly. We can improve this.
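
For point 1, here is a minimal sketch of the two store forms at the Triton level, reusing the names from the kernel sketch in the issue description (so it is a fragment of a kernel body, not a standalone function). A store through a block pointer is the candidate for a single 2D block write, while a store through a tensor of computed pointers lowers to scattered stores and forces a layout conversion of the accumulator first.

```python
# Block-pointer store: candidate for a 2D block write; no ConvertLayout of acc needed.
c_block_ptr = tl.make_block_ptr(c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
tl.store(c_block_ptr, acc)

# Scattered store: a tensor of per-element pointers plus a mask; the accumulator has to be
# rearranged (ConvertLayout) to match the pointer tensor's layout before the store.
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
```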
chengjunlu commented 7 months ago

The 2D load/store and prefetch ops have been enabled on the GEMM kernel.

The known points where we can improve performance:

  1. Caching improvements:

    • The IGC driver doesn't generate the prefetch ops in the binary even when we add the GenISA 2D prefetch ops. We need to wait for the IGC driver update.
    • Add a named barrier to synchronize execution across the different threads, so the data locality is better and cache competition is avoided.
  2. Use the packed type for the Dot operands layout encoding in TypeConvert to remove the redundant pack/unpack register movements.

  3. The send message payload of the load/prefetch ops seems to be composed repeatedly. We can improve this. (This is only available in SIMD for now. We need to investigate asking IGC to optimize this in the scalar compiler backend.)

The performance for now:

matmul-performance:
        M       N       K      cuBLAS     Triton
0  2048.0   512.0   512.0    3.095257   2.726150
1  2048.0  1024.0  1024.0   11.607216   9.541525
2  2048.0  2048.0  2048.0   34.085901  23.184554
3  2048.0  4096.0  4096.0   88.441355  34.846204
4  2048.0  8192.0  8192.0  157.030987  32.682423
whitneywhtsang commented 7 months ago

As discussed, please work with @Dewei-Wang-sh to add a column with the Triton (SIMD) measurement, and a column with the Triton (SIMD) measurement without prefetch, to your example.
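
A minimal sketch of how the extra columns could be added to the benchmark report; the provider names and the run_gemm dispatch helper are placeholders, not the actual benchmark code:

```python
import triton
import triton.testing


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N', 'K'],                  # N = K in each row; M is fixed via args below
        x_vals=[512, 1024, 2048, 4096, 8192],
        line_arg='provider',
        # One report column per provider; the SIMD entries are the requested additions.
        line_vals=['onednn', 'triton-simt', 'triton-simd', 'triton-simd-noprefetch'],
        line_names=['oneDNN', 'Triton (SIMT)', 'Triton (SIMD)', 'Triton (SIMD, no prefetch)'],
        ylabel='TFLOPS',
        plot_name='matmul-performance',
        args={'M': 2048},
    ))
def benchmark(N, K, provider, M):
    # run_gemm is a placeholder that dispatches to the kernel build selected by `provider`.
    ms = triton.testing.do_bench(lambda: run_gemm(provider, M, N, K))
    return 2 * M * N * K * 1e-12 / (ms * 1e-3)  # TFLOPS


# benchmark.run(print_data=True)
```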

etiotto commented 7 months ago

The 2D load/store and prefetch ops have been enabled on the GEMM kernel.

The known points where we can improve performance:

  1. Caching improvements:
  • The IGC driver doesn't generate the prefetch ops in the binary even when we add the GenISA 2D prefetch ops. We need to wait for the IGC driver update.
  • Add a named barrier to synchronize execution across the different threads, so the data locality is better and cache competition is avoided.
  2. Use the packed type for the Dot operands layout encoding in TypeConvert to remove the redundant pack/unpack register movements.
  3. The send message payload of the load/prefetch ops seems to be composed repeatedly. We can improve this. (This is only available in SIMD for now. We need to investigate asking IGC to optimize this in the scalar compiler backend.)

The performance for now:

matmul-performance:
        M       N       K      cuBLAS     Triton
0  2048.0   512.0   512.0    3.095257   2.726150
1  2048.0  1024.0  1024.0   11.607216   9.541525
2  2048.0  2048.0  2048.0   34.085901  23.184554
3  2048.0  4096.0  4096.0   88.441355  34.846204
4  2048.0  8192.0  8192.0  157.030987  32.682423

I see you are comparing to cuBLAS performance. I think we should compare against oneDNN performance (same platform). I'd like to understand the performance for the following scenarios:

  1. just using the DPAS instruction
  2. using the DPAS instruction + 2D block load
  3. using the DPAS instruction + 2D block load and 2D block stores
  4. using the DPAS instruction + 2D block load and 2D block stores + prefetch

For 4 you will have to wait for IGC to fix the problem you found (I hope they are aware of this problem, correct?).

etiotto commented 7 months ago

Add a named barrier to synchronize execution across the different threads, so the data locality is better and cache competition is avoided.

Can you explain in more detail what you mean by using the "named barrier" and why you think using that operation/instruction would give better performance?

chengjunlu commented 7 months ago

Add a named barrier to synchronize execution across the different threads, so the data locality is better and cache competition is avoided.

Can you explain in more detail what you mean by using the "named barrier" and why you think using that operation/instruction would give better performance?

The different physical threads perform the DOT operation on individual GEMM tiles in parallel. The threads may be in different stages (different iterations of the loop) and working on different regions of the GEMM. Since we use multiple threads to prefetch the data cooperatively and implicitly, without explicit synchronization the cache may contain unexpected data from different stages.

chengjunlu commented 7 months ago

I see you are comparing to cuBLAS performance. I think we should compare against oneDNN performance (same platform). I'd like to understand the performance for the following scenarios:

  1. just using the DPAS instruction
  2. using the DPAS instruction + 2D block load
  3. using the DPAS instruction + 2D block load and 2D block stores
  4. using the DPAS instruction + 2D block load and 2D block stores + prefetch

For 4 you will have to wait for IGC to fix the problem you found (I hope they are aware of this problem, correct?).

It is actually the oneDNN performance; the column is named cuBLAS by mistake.

Yes. The 4th point depends on the IGC fix for prefetching.

Dewei-Wang-sh commented 7 months ago

For 08-experimental-block-pointer.py with GEMM size 4K x 4K x 4K in fp16, with Triton lowering to SPIR-V (vc-intrinsics), the max perf is as below:
  • 295 TFLOPS with all the optimizations: 2D load, 2D store, 2D prefetch, etc. (MLIR team: 300 TFLOPS, XeTLA: 310 TFLOPS)
  • 150 TFLOPS without prefetch

The machine info is PVC (Max 1550): EUCount = 512, ThreadCount = 4096, SliceCount = 1, SubSliceCount = 64.

chengjunlu commented 7 months ago

I hit a correctness issue in the 2D block store.

The first time, I used the naive encoding with tile_height=8, tile_width=16, and v_blocks=1, the same as used in the 2D load encoding, but I could not get the correct result in memory.

genx.matrix.2Dblockstore %arg2, %654, %90, %656, %661, %662, %642 {elem_size_in_bits = 32 : i32, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false, v_blocks = 1 : i32, vnni_transform = false} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) loc(#loc28)

The second time, I used the encoding that follows the GFX spec, with tile_height=7, tile_width=15, and v_blocks=0, but it causes the GPU to hang.

genx.matrix.2Dblockstore %arg2, %654, %90, %656, %661, %662, %642 {elem_size_in_bits = 32 : i32, tile_height = 7 : i32, tile_width = 15 : i32, transpose = false, v_blocks = 0 : i32, vnni_transform = false} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) loc(#loc28)

I will create a JIRA ticket to track this 2D block store issue with a small reproducer for IGC. I think it should not impact the performance testing. Right?
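
For reference, a minimal sketch of the two attribute conventions tried above (the helper name is hypothetical): the naive style passes the raw values as in the 2D load op, while the GFX-spec style stores each field as the value minus one.

```python
def block2d_store_attrs(tile_height, tile_width, v_blocks, spec_encoding):
    """Return the 2D block store attributes under the two conventions discussed above."""
    if spec_encoding:
        # GFX-spec style: each field holds (value - 1).
        return {'tile_height': tile_height - 1,
                'tile_width': tile_width - 1,
                'v_blocks': v_blocks - 1}
    # Naive style: the same raw values as used for the 2D block load encoding.
    return {'tile_height': tile_height, 'tile_width': tile_width, 'v_blocks': v_blocks}


# The two attempts from this comment, for an 8x16 tile with one block:
print(block2d_store_attrs(8, 16, 1, spec_encoding=False))  # {'tile_height': 8, 'tile_width': 16, 'v_blocks': 1}
print(block2d_store_attrs(8, 16, 1, spec_encoding=True))   # {'tile_height': 7, 'tile_width': 15, 'v_blocks': 0}
```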

chengjunlu commented 7 months ago

After aligning with DeWei, we use the same Triton kernel for the performance benchmark.

The SIMT kernel runs with 2D block load/store and DPAS.

matmul-performance:
        M       N       K      oneDNN     Triton
0  2048.0   512.0   512.0    3.147170   2.781717
1  2048.0  1024.0  1024.0   11.766426   8.677456
2  2048.0  2048.0  2048.0   35.496352  17.951568
3  2048.0  4096.0  4096.0   89.484749  26.475945
4  4096.0  4096.0  4096.0  130.273616  29.551482
5  2048.0  8192.0  8192.0  153.008826  34.834623

The best performance is with this configuration:

triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2,
                      num_warps=64),
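
For context, a config like this is usually selected through Triton's autotuner; a minimal sketch of that usage (the config list and tuning key are illustrative), with the kernel body omitted:

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4},
                      num_stages=2, num_warps=64),
        # ... other candidate configs ...
    ],
    key=['M', 'N', 'K'],  # re-select the config when the problem shape changes
)
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                  BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    ...  # kernel body as in the GEMM sketch earlier in the thread
```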

The performance is not as good as expected: only about 20% of oneDNN on the JupyterHub platform. Scaled to the PVC 1550, the expected performance would be about 300 TFLOPS * 20% = 60 TFLOPS for now.

I quickly reviewed the kernel with DeWei. We think the redundant register movement instructions in the final binary may cause the performance issue. We need to further clean up those redundant register movement instructions.

LLVM IR for simple GEMM case with two stages loop pipelining ``` ; ------------------------------------------------ ; OCL_asmacaca0935f08179e_optimized.ll ; ------------------------------------------------ target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-n8:16:32" target triple = "spir64-unknown-unknown" ; Function Attrs: convergent nounwind define spir_kernel void @matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c(half addrspace(1)* %0, half addrspace(1)* %1, i8 addrspace(1)* %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8, <8 x i32> %r0, <8 x i32> %payloadHeader, i16 %localIdX, i16 %localIdY, i16 %localIdZ, i8* %privateBase, i32 %bufferOffset, i32 %bufferOffset1, i32 %bufferOffset2) #0 !dbg !370 { %10 = extractelement <8 x i32> %r0, i64 1 %11 = add i32 %4, 127, !dbg !374 %is-neg = icmp slt i32 %11, 0, !dbg !378 br i1 %is-neg, label %cond-add, label %cond-add-join, !dbg !378 cond-add: ; preds = %9 %12 = add i32 %11, 127, !dbg !378 br label %cond-add-join, !dbg !378 cond-add-join: ; preds = %9, %cond-add %13 = phi i32 [ %11, %9 ], [ %12, %cond-add ], !dbg !378 %qot = ashr i32 %13, 7, !dbg !378 %14 = add i32 %3, 127, !dbg !379 %is-neg62 = icmp slt i32 %14, 0, !dbg !381 br i1 %is-neg62, label %cond-add63, label %cond-add-join64, !dbg !381 cond-add63: ; preds = %cond-add-join %15 = add i32 %14, 127, !dbg !381 br label %cond-add-join64, !dbg !381 cond-add-join64: ; preds = %cond-add-join, %cond-add63 %16 = phi i32 [ %14, %cond-add-join ], [ %15, %cond-add63 ], !dbg !381 %qot65 = ashr i32 %16, 7, !dbg !381 %17 = shl nsw i32 %qot, 3, !dbg !382 %18 = sdiv i32 %10, %17, !dbg !383 %19 = shl i32 %18, 3, !dbg !384 %20 = sub i32 %qot65, %19, !dbg !385 %21 = icmp slt i32 %20, 8, !dbg !385 %22 = select i1 %21, i32 %20, i32 8, !dbg !386 %23 = srem i32 %10, %22, !dbg !388 %24 = add i32 %19, %23, !dbg !389 %25 = srem i32 %10, %17, !dbg !390 %26 = sdiv i32 %25, %22, !dbg !391 %27 = shl i32 %24, 7, !dbg !392 %28 = shl i32 %26, 7, !dbg !393 %localIdX3 = zext i16 %localIdX to i32, !dbg !394 %29 = lshr i32 %localIdX3, 4, !dbg !394 %30 = and i32 %29, 7, !dbg !394 %31 = lshr i32 %localIdX3, 7, !dbg !394 %32 = and i32 %31, 7, !dbg !394 %33 = mul i32 %27, %6, !dbg !394 %34 = sext i32 %33 to i64, !dbg !394 %35 = getelementptr half, half addrspace(1)* %0, i64 %34, !dbg !394 %36 = shl i32 %5, 1, !dbg !394 %37 = add i32 %36, -1, !dbg !394 %38 = add i32 %3, -1, !dbg !394 %39 = shl i32 %6, 1, !dbg !394 %40 = add i32 %39, -1, !dbg !394 %41 = shl nuw nsw i32 %30, 3, !dbg !394 %42 = shl nuw nsw i32 %32, 4, !dbg !394 %43 = ptrtoint half addrspace(1)* %35 to i64, !dbg !394 call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %43, i32 %37, i32 %38, i32 %40, i32 %41, i32 %42, i32 32, i32 3, i32 15, i32 1, i1 false, i1 false, i32 4) #3, !dbg !394 %44 = sext i32 %28 to i64, !dbg !395 %45 = getelementptr half, half addrspace(1)* %1, i64 %44, !dbg !395 %46 = shl i32 %4, 1, !dbg !395 %47 = add i32 %46, -1, !dbg !395 %48 = add i32 %5, -1, !dbg !395 %49 = shl i32 %7, 1, !dbg !395 %50 = add i32 %49, -1, !dbg !395 %51 = shl nuw nsw i32 %30, 4, !dbg !395 %52 = shl nuw nsw i32 %32, 3, !dbg !395 %53 = ptrtoint half addrspace(1)* %45 to i64, !dbg !395 call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %53, i32 %47, i32 %48, i32 %50, i32 %51, i32 %52, i32 32, i32 7, i32 7, i32 1, i1 false, i1 false, i32 4) #3, !dbg !395 %54 = icmp sgt i32 %5, 0, !dbg !396 br i1 %54, label 
%.lr.ph, label %._crit_edge53, !dbg !396 ._crit_edge53: ; preds = %cond-add-join64 %.pre = lshr i32 %localIdX3, 3, !dbg !397 %.pre54 = and i32 %.pre, 120, !dbg !397 %.pre56 = and i32 %localIdX3, 48, !dbg !397 %.pre58 = or i32 %.pre56, %28, !dbg !397 %.pre60 = or i32 %.pre54, %27, !dbg !397 br label %._crit_edge, !dbg !396 .lr.ph: ; preds = %cond-add-join64 %55 = lshr i32 %localIdX3, 3 %56 = and i32 %55, 120 %57 = or i32 %56, %27 %58 = ptrtoint half addrspace(1)* %0 to i64 %59 = and i32 %localIdX3, 48 %60 = or i32 %59, %28 %61 = ptrtoint half addrspace(1)* %1 to i64 %62 = or i32 %60, 64 br label %63, !dbg !396 63: ; preds = %.lr.ph, %63 %64 = phi i32 [ 0, %.lr.ph ], [ %84, %63 ] %65 = phi i32 [ 0, %.lr.ph ], [ %83, %63 ] %66 = phi float [ 0.000000e+00, %.lr.ph ], [ %149, %63 ] %67 = phi float [ 0.000000e+00, %.lr.ph ], [ %148, %63 ] %68 = phi float [ 0.000000e+00, %.lr.ph ], [ %147, %63 ] %69 = phi float [ 0.000000e+00, %.lr.ph ], [ %146, %63 ] %70 = phi float [ 0.000000e+00, %.lr.ph ], [ %145, %63 ] %71 = phi float [ 0.000000e+00, %.lr.ph ], [ %144, %63 ] %72 = phi float [ 0.000000e+00, %.lr.ph ], [ %143, %63 ] %73 = phi float [ 0.000000e+00, %.lr.ph ], [ %142, %63 ] %74 = phi float [ 0.000000e+00, %.lr.ph ], [ %140, %63 ] %75 = phi float [ 0.000000e+00, %.lr.ph ], [ %139, %63 ] %76 = phi float [ 0.000000e+00, %.lr.ph ], [ %138, %63 ] %77 = phi float [ 0.000000e+00, %.lr.ph ], [ %137, %63 ] %78 = phi float [ 0.000000e+00, %.lr.ph ], [ %136, %63 ] %79 = phi float [ 0.000000e+00, %.lr.ph ], [ %135, %63 ] %80 = phi float [ 0.000000e+00, %.lr.ph ], [ %134, %63 ] %81 = phi float [ 0.000000e+00, %.lr.ph ], [ %133, %63 ] %82 = phi i32 [ 0, %.lr.ph ], [ %150, %63 ] %83 = add i32 %65, 64, !dbg !398 %84 = add i32 %64, 64, !dbg !399 %85 = sext i32 %83 to i64, !dbg !394 %86 = getelementptr half, half addrspace(1)* %35, i64 %85, !dbg !394 %87 = ptrtoint half addrspace(1)* %86 to i64, !dbg !394 call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %87, i32 %37, i32 %38, i32 %40, i32 %41, i32 %42, i32 32, i32 3, i32 15, i32 1, i1 false, i1 false, i32 4) #3, !dbg !394 %88 = mul i32 %84, %7, !dbg !395 %89 = sext i32 %88 to i64, !dbg !395 %90 = getelementptr half, half addrspace(1)* %45, i64 %89, !dbg !395 %91 = ptrtoint half addrspace(1)* %90 to i64, !dbg !395 call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %91, i32 %47, i32 %48, i32 %50, i32 %51, i32 %52, i32 32, i32 7, i32 7, i32 1, i1 false, i1 false, i32 4) #3, !dbg !395 %92 = call <8 x i16> @llvm.genx.GenISA.LSC2DBlockRead.v8i16(i64 %58, i32 %37, i32 %38, i32 %40, i32 %65, i32 %57, i32 16, i32 16, i32 8, i32 1, i1 false, i1 false, i32 0) #3, !dbg !394 %93 = or i32 %65, 16, !dbg !394 %94 = call <8 x i16> @llvm.genx.GenISA.LSC2DBlockRead.v8i16(i64 %58, i32 %37, i32 %38, i32 %40, i32 %93, i32 %57, i32 16, i32 16, i32 8, i32 1, i1 false, i1 false, i32 0) #3, !dbg !394 %95 = or i32 %65, 32, !dbg !394 %96 = call <8 x i16> @llvm.genx.GenISA.LSC2DBlockRead.v8i16(i64 %58, i32 %37, i32 %38, i32 %40, i32 %95, i32 %57, i32 16, i32 16, i32 8, i32 1, i1 false, i1 false, i32 0) #3, !dbg !394 %97 = or i32 %65, 48, !dbg !394 %98 = call <8 x i16> @llvm.genx.GenISA.LSC2DBlockRead.v8i16(i64 %58, i32 %37, i32 %38, i32 %40, i32 %97, i32 %57, i32 16, i32 16, i32 8, i32 1, i1 false, i1 false, i32 0) #3, !dbg !394 %99 = call <8 x i32> @llvm.genx.GenISA.LSC2DBlockRead.v8i32(i64 %61, i32 %47, i32 %48, i32 %50, i32 %60, i32 %64, i32 16, i32 16, i32 16, i32 1, i1 false, i1 true, i32 0) #3, !dbg !395 %100 = or i32 %64, 16, !dbg !395 %101 = call <8 x i32> 
@llvm.genx.GenISA.LSC2DBlockRead.v8i32(i64 %61, i32 %47, i32 %48, i32 %50, i32 %60, i32 %100, i32 16, i32 16, i32 16, i32 1, i1 false, i1 true, i32 0) #3, !dbg !395 %102 = or i32 %64, 32, !dbg !395 %103 = call <8 x i32> @llvm.genx.GenISA.LSC2DBlockRead.v8i32(i64 %61, i32 %47, i32 %48, i32 %50, i32 %60, i32 %102, i32 16, i32 16, i32 16, i32 1, i1 false, i1 true, i32 0) #3, !dbg !395 %104 = or i32 %64, 48, !dbg !395 %105 = call <8 x i32> @llvm.genx.GenISA.LSC2DBlockRead.v8i32(i64 %61, i32 %47, i32 %48, i32 %50, i32 %60, i32 %104, i32 16, i32 16, i32 16, i32 1, i1 false, i1 true, i32 0) #3, !dbg !395 %106 = call <8 x i32> @llvm.genx.GenISA.LSC2DBlockRead.v8i32(i64 %61, i32 %47, i32 %48, i32 %50, i32 %62, i32 %64, i32 16, i32 16, i32 16, i32 1, i1 false, i1 true, i32 0) #3, !dbg !395 %107 = call <8 x i32> @llvm.genx.GenISA.LSC2DBlockRead.v8i32(i64 %61, i32 %47, i32 %48, i32 %50, i32 %62, i32 %100, i32 16, i32 16, i32 16, i32 1, i1 false, i1 true, i32 0) #3, !dbg !395 %108 = call <8 x i32> @llvm.genx.GenISA.LSC2DBlockRead.v8i32(i64 %61, i32 %47, i32 %48, i32 %50, i32 %62, i32 %102, i32 16, i32 16, i32 16, i32 1, i1 false, i1 true, i32 0) #3, !dbg !395 %109 = call <8 x i32> @llvm.genx.GenISA.LSC2DBlockRead.v8i32(i64 %61, i32 %47, i32 %48, i32 %50, i32 %62, i32 %104, i32 16, i32 16, i32 16, i32 1, i1 false, i1 true, i32 0) #3, !dbg !395 %110 = insertelement <8 x float> undef, float %81, i64 0, !dbg !400 %111 = insertelement <8 x float> %110, float %80, i64 1, !dbg !400 %112 = insertelement <8 x float> %111, float %79, i64 2, !dbg !400 %113 = insertelement <8 x float> %112, float %78, i64 3, !dbg !400 %114 = insertelement <8 x float> %113, float %77, i64 4, !dbg !400 %115 = insertelement <8 x float> %114, float %76, i64 5, !dbg !400 %116 = insertelement <8 x float> %115, float %75, i64 6, !dbg !400 %117 = insertelement <8 x float> %116, float %74, i64 7, !dbg !400 %118 = insertelement <8 x float> undef, float %73, i64 0, !dbg !400 %119 = insertelement <8 x float> %118, float %72, i64 1, !dbg !400 %120 = insertelement <8 x float> %119, float %71, i64 2, !dbg !400 %121 = insertelement <8 x float> %120, float %70, i64 3, !dbg !400 %122 = insertelement <8 x float> %121, float %69, i64 4, !dbg !400 %123 = insertelement <8 x float> %122, float %68, i64 5, !dbg !400 %124 = insertelement <8 x float> %123, float %67, i64 6, !dbg !400 %125 = insertelement <8 x float> %124, float %66, i64 7, !dbg !400 %126 = call <8 x float> @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32(<8 x float> %117, <8 x i16> %92, <8 x i32> %99, i32 10, i32 10, i32 8, i32 8, i1 false) #3, !dbg !400 %127 = call <8 x float> @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32(<8 x float> %125, <8 x i16> %92, <8 x i32> %106, i32 10, i32 10, i32 8, i32 8, i1 false) #3, !dbg !400 %128 = call <8 x float> @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32(<8 x float> %126, <8 x i16> %94, <8 x i32> %101, i32 10, i32 10, i32 8, i32 8, i1 false) #3, !dbg !400 %129 = call <8 x float> @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32(<8 x float> %127, <8 x i16> %94, <8 x i32> %107, i32 10, i32 10, i32 8, i32 8, i1 false) #3, !dbg !400 %130 = call <8 x float> @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32(<8 x float> %128, <8 x i16> %96, <8 x i32> %103, i32 10, i32 10, i32 8, i32 8, i1 false) #3, !dbg !400 %131 = call <8 x float> @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32(<8 x float> %129, <8 x i16> %96, <8 x i32> %108, i32 10, i32 10, i32 8, i32 8, i1 false) #3, !dbg !400 %132 = call <8 x float> 
@llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32(<8 x float> %130, <8 x i16> %98, <8 x i32> %105, i32 10, i32 10, i32 8, i32 8, i1 false) #3, !dbg !400 %133 = extractelement <8 x float> %132, i64 0, !dbg !400 %134 = extractelement <8 x float> %132, i64 1, !dbg !400 %135 = extractelement <8 x float> %132, i64 2, !dbg !400 %136 = extractelement <8 x float> %132, i64 3, !dbg !400 %137 = extractelement <8 x float> %132, i64 4, !dbg !400 %138 = extractelement <8 x float> %132, i64 5, !dbg !400 %139 = extractelement <8 x float> %132, i64 6, !dbg !400 %140 = extractelement <8 x float> %132, i64 7, !dbg !400 %141 = call <8 x float> @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32(<8 x float> %131, <8 x i16> %98, <8 x i32> %109, i32 10, i32 10, i32 8, i32 8, i1 false) #3, !dbg !400 %142 = extractelement <8 x float> %141, i64 0, !dbg !400 %143 = extractelement <8 x float> %141, i64 1, !dbg !400 %144 = extractelement <8 x float> %141, i64 2, !dbg !400 %145 = extractelement <8 x float> %141, i64 3, !dbg !400 %146 = extractelement <8 x float> %141, i64 4, !dbg !400 %147 = extractelement <8 x float> %141, i64 5, !dbg !400 %148 = extractelement <8 x float> %141, i64 6, !dbg !400 %149 = extractelement <8 x float> %141, i64 7, !dbg !400 %150 = add i32 %82, 64, !dbg !396 %151 = icmp slt i32 %150, %5, !dbg !396 br i1 %151, label %63, label %._crit_edge, !dbg !396 ._crit_edge: ; preds = %63, %._crit_edge53 %.pre-phi61 = phi i32 [ %.pre60, %._crit_edge53 ], [ %57, %63 ], !dbg !397 %.pre-phi59 = phi i32 [ %.pre58, %._crit_edge53 ], [ %60, %63 ], !dbg !397 %.pre-phi57 = phi i32 [ %.pre56, %._crit_edge53 ], [ %59, %63 ], !dbg !397 %.lcssa21 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %133, %63 ], !dbg !400 %.lcssa20 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %134, %63 ], !dbg !400 %.lcssa19 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %135, %63 ], !dbg !400 %.lcssa18 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %136, %63 ], !dbg !400 %.lcssa17 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %137, %63 ], !dbg !400 %.lcssa16 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %138, %63 ], !dbg !400 %.lcssa15 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %139, %63 ], !dbg !400 %.lcssa14 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %140, %63 ], !dbg !400 %.lcssa13 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %142, %63 ], !dbg !400 %.lcssa12 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %143, %63 ], !dbg !400 %.lcssa11 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %144, %63 ], !dbg !400 %.lcssa10 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %145, %63 ], !dbg !400 %.lcssa9 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %146, %63 ], !dbg !400 %.lcssa8 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %147, %63 ], !dbg !400 %.lcssa7 = phi float [ 0.000000e+00, %._crit_edge53 ], [ %148, %63 ], !dbg !400 %.lcssa = phi float [ 0.000000e+00, %._crit_edge53 ], [ %149, %63 ], !dbg !400 %152 = shl i32 %4, 2, !dbg !397 %153 = add i32 %152, -1, !dbg !397 %154 = shl i32 %8, 2, !dbg !397 %155 = add i32 %154, -1, !dbg !397 %156 = insertelement <8 x float> undef, float %.lcssa21, i64 0, !dbg !397 %157 = insertelement <8 x float> %156, float %.lcssa20, i64 1, !dbg !397 %158 = insertelement <8 x float> %157, float %.lcssa19, i64 2, !dbg !397 %159 = insertelement <8 x float> %158, float %.lcssa18, i64 3, !dbg !397 %160 = insertelement <8 x float> %159, float %.lcssa17, i64 4, !dbg !397 %161 = insertelement <8 x float> %160, float %.lcssa16, i64 5, !dbg !397 %162 = insertelement <8 x float> %161, 
float %.lcssa15, i64 6, !dbg !397 %163 = insertelement <8 x float> %162, float %.lcssa14, i64 7, !dbg !397 %164 = bitcast <8 x float> %163 to <8 x i32>, !dbg !397 %165 = ptrtoint i8 addrspace(1)* %2 to i64, !dbg !397 call void @llvm.genx.GenISA.LSC2DBlockWrite.v8i32(i64 %165, i32 %153, i32 %38, i32 %155, i32 %.pre-phi59, i32 %.pre-phi61, i32 32, i32 16, i32 8, i32 1, i1 false, i1 false, i32 0, <8 x i32> %164) #3, !dbg !397 %166 = add nuw nsw i32 %.pre-phi57, 32, !dbg !397 %167 = or i32 %166, %28, !dbg !397 %168 = insertelement <8 x float> undef, float %.lcssa13, i64 0, !dbg !397 %169 = insertelement <8 x float> %168, float %.lcssa12, i64 1, !dbg !397 %170 = insertelement <8 x float> %169, float %.lcssa11, i64 2, !dbg !397 %171 = insertelement <8 x float> %170, float %.lcssa10, i64 3, !dbg !397 %172 = insertelement <8 x float> %171, float %.lcssa9, i64 4, !dbg !397 %173 = insertelement <8 x float> %172, float %.lcssa8, i64 5, !dbg !397 %174 = insertelement <8 x float> %173, float %.lcssa7, i64 6, !dbg !397 %175 = insertelement <8 x float> %174, float %.lcssa, i64 7, !dbg !397 %176 = bitcast <8 x float> %175 to <8 x i32>, !dbg !397 call void @llvm.genx.GenISA.LSC2DBlockWrite.v8i32(i64 %165, i32 %153, i32 %38, i32 %155, i32 %167, i32 %.pre-phi61, i32 32, i32 16, i32 8, i32 1, i1 false, i1 false, i32 0, <8 x i32> %176) #3, !dbg !397 ret void, !dbg !401 } ; Function Desc: declare void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) ; Function Desc: ; Output: ; Arg 0: ; Arg 1: ; Arg 2: ; Arg 3: ; Arg 4: ; Arg 5: ; Arg 6: ; Arg 7: ; Arg 8: ; Arg 9: ; Arg 10: ; Arg 11: declare <8 x i16> @llvm.genx.GenISA.LSC2DBlockRead.v8i16(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) ; Function Desc: ; Output: ; Arg 0: ; Arg 1: ; Arg 2: ; Arg 3: ; Arg 4: ; Arg 5: ; Arg 6: ; Arg 7: ; Arg 8: ; Arg 9: ; Arg 10: ; Arg 11: declare <8 x i32> @llvm.genx.GenISA.LSC2DBlockRead.v8i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) ; Function Desc: ; Output: ; Arg 0: ; Arg 1: ; Arg 2: ; Arg 3: ; Arg 4: ; Arg 5: ; Arg 6: ; Arg 7: declare <8 x float> @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32(<8 x float>, <8 x i16>, <8 x i32>, i32, i32, i32, i32, i1) ; Function Desc: ; Output: ; Arg 0: ; Arg 1: ; Arg 2: ; Arg 3: ; Arg 4: ; Arg 5: ; Arg 6: ; Arg 7: ; Arg 8: ; Arg 9: ; Arg 10: ; Arg 11: ; Arg 12: declare void @llvm.genx.GenISA.LSC2DBlockWrite.v8i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32, <8 x i32>) ; Function Attrs: convergent mustprogress nofree nounwind readnone willreturn declare spir_func i32 @__builtin_IB_get_group_id(i32 noundef) local_unnamed_addr #1 ; Function Attrs: convergent mustprogress nofree nounwind readnone willreturn declare spir_func i32 @__builtin_IB_get_local_id_x() local_unnamed_addr #1 ; Function Desc: ; Output: ; Function Attrs: nounwind readnone declare void @llvm.genx.GenISA.CatchAllDebugLine() #2 attributes #0 = { convergent nounwind } attributes #1 = { convergent mustprogress nofree nounwind readnone willreturn "frame-pointer"="none" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } attributes #2 = { nounwind readnone } attributes #3 = { nounwind } !llvm.module.flags = !{!0, !1, !2} !llvm.dbg.cu = !{!3} !spirv.MemoryModel = !{!5} !spirv.Source = !{!6} !spirv.Generator = !{!7} !igc.functions = !{!8} !IGCMetadata = !{!25} !opencl.ocl.version = !{!368, !368, !368, !368, !368} !opencl.spir.version = !{!368, !368, !368, !368, !368} !llvm.ident = !{!369, !369, !369, !369, 
!369} !0 = !{i32 7, !"Dwarf Version", i32 0} !1 = !{i32 2, !"Debug Info Version", i32 3} !2 = !{i32 1, !"wchar_size", i32 4} !3 = distinct !DICompileUnit(language: DW_LANG_OpenCL, file: !4, producer: "triton", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug) !4 = !DIFile(filename: "10-experimental-tma-store-matrix-multiplication.py", directory: "/home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/./tutorials") !5 = !{i32 2, i32 2} !6 = !{i32 4, i32 100000} !7 = !{i16 6, i16 14} !8 = !{void (half addrspace(1)*, half addrspace(1)*, i8 addrspace(1)*, i32, i32, i32, i32, i32, i32, <8 x i32>, <8 x i32>, i16, i16, i16, i8*, i32, i32, i32)* @matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c, !9} !9 = !{!10, !11, !24} !10 = !{!"function_type", i32 0} !11 = !{!"implicit_arg_desc", !12, !13, !14, !15, !16, !17, !18, !20, !22} !12 = !{i32 0} !13 = !{i32 1} !14 = !{i32 7} !15 = !{i32 8} !16 = !{i32 9} !17 = !{i32 12} !18 = !{i32 14, !19} !19 = !{!"explicit_arg_num", i32 0} !20 = !{i32 14, !21} !21 = !{!"explicit_arg_num", i32 1} !22 = !{i32 14, !23} !23 = !{!"explicit_arg_num", i32 2} !24 = !{!"sub_group_size", i32 16} !25 = !{!"ModuleMD", !26, !27, !97, !235, !265, !281, !298, !308, !310, !311, !324, !325, !326, !327, !331, !332, !333, !334, !335, !336, !337, !338, !339, !340, !341, !342, !343, !344, !346, !350, !351, !352, !353, !354, !355, !356, !357, !358, !163, !359, !360, !361, !363, !366, !367} !26 = !{!"isPrecise", i1 false} !27 = !{!"compOpt", !28, !29, !30, !31, !32, !33, !34, !35, !36, !37, !38, !39, !40, !41, !42, !43, !44, !45, !46, !47, !48, !49, !50, !51, !52, !53, !54, !55, !56, !57, !58, !59, !60, !61, !62, !63, !64, !65, !66, !67, !68, !69, !70, !71, !72, !73, !74, !75, !76, !77, !78, !79, !80, !81, !82, !83, !84, !85, !86, !87, !88, !89, !90, !91, !92, !93, !94, !95, !96} !28 = !{!"DenormsAreZero", i1 false} !29 = !{!"CorrectlyRoundedDivSqrt", i1 false} !30 = !{!"OptDisable", i1 false} !31 = !{!"MadEnable", i1 false} !32 = !{!"NoSignedZeros", i1 false} !33 = !{!"NoNaNs", i1 false} !34 = !{!"FloatRoundingMode", i32 0} !35 = !{!"FloatCvtIntRoundingMode", i32 3} !36 = !{!"LoadCacheDefault", i32 4} !37 = !{!"StoreCacheDefault", i32 2} !38 = !{!"VISAPreSchedRPThreshold", i32 0} !39 = !{!"SetLoopUnrollThreshold", i32 0} !40 = !{!"UnsafeMathOptimizations", i1 false} !41 = !{!"disableCustomUnsafeOpts", i1 false} !42 = !{!"disableReducePow", i1 false} !43 = !{!"FiniteMathOnly", i1 false} !44 = !{!"FastRelaxedMath", i1 false} !45 = !{!"DashGSpecified", i1 false} !46 = !{!"FastCompilation", i1 false} !47 = !{!"UseScratchSpacePrivateMemory", i1 true} !48 = !{!"RelaxedBuiltins", i1 false} !49 = !{!"SubgroupIndependentForwardProgressRequired", i1 true} !50 = !{!"GreaterThan2GBBufferRequired", i1 true} !51 = !{!"GreaterThan4GBBufferRequired", i1 true} !52 = !{!"DisableA64WA", i1 false} !53 = !{!"ForceEnableA64WA", i1 false} !54 = !{!"PushConstantsEnable", i1 true} !55 = !{!"HasPositivePointerOffset", i1 false} !56 = !{!"HasBufferOffsetArg", i1 true} !57 = !{!"BufferOffsetArgOptional", i1 true} !58 = !{!"replaceGlobalOffsetsByZero", i1 false} !59 = !{!"forcePixelShaderSIMDMode", i32 0} !60 = !{!"pixelShaderDoNotAbortOnSpill", i1 false} !61 = !{!"UniformWGS", i1 false} !62 = !{!"disableVertexComponentPacking", i1 false} !63 = !{!"disablePartialVertexComponentPacking", i1 false} !64 = !{!"PreferBindlessImages", i1 false} !65 = !{!"UseBindlessMode", i1 false} !66 = !{!"UseLegacyBindlessMode", i1 true} !67 = !{!"disableMathRefactoring", i1 false} !68 = !{!"atomicBranch", i1 false} 
!69 = !{!"ForceInt32DivRemEmu", i1 false} !70 = !{!"ForceInt32DivRemEmuSP", i1 false} !71 = !{!"DisableFastestSingleCSSIMD", i1 false} !72 = !{!"DisableFastestLinearScan", i1 false} !73 = !{!"UseStatelessforPrivateMemory", i1 false} !74 = !{!"EnableTakeGlobalAddress", i1 false} !75 = !{!"IsLibraryCompilation", i1 false} !76 = !{!"FastVISACompile", i1 false} !77 = !{!"MatchSinCosPi", i1 false} !78 = !{!"ExcludeIRFromZEBinary", i1 false} !79 = !{!"EmitZeBinVISASections", i1 false} !80 = !{!"FP64GenEmulationEnabled", i1 false} !81 = !{!"allowDisableRematforCS", i1 false} !82 = !{!"DisableIncSpillCostAllAddrTaken", i1 false} !83 = !{!"DisableCPSOmaskWA", i1 false} !84 = !{!"DisableFastestGopt", i1 false} !85 = !{!"WaForceHalfPromotionComputeShader", i1 false} !86 = !{!"WaForceHalfPromotionPixelVertexShader", i1 false} !87 = !{!"DisableConstantCoalescing", i1 false} !88 = !{!"EnableUndefAlphaOutputAsRed", i1 true} !89 = !{!"WaEnableALTModeVisaWA", i1 false} !90 = !{!"NewSpillCostFunction", i1 false} !91 = !{!"ForceLargeGRFNum4RQ", i1 false} !92 = !{!"DisableEUFusion", i1 false} !93 = !{!"DisableFDivToFMulInvOpt", i1 false} !94 = !{!"initializePhiSampleSourceWA", i1 false} !95 = !{!"WaDisableSubspanUseNoMaskForCB", i1 false} !96 = !{!"FastestS1Options", i32 0} !97 = !{!"FuncMD", !98, !99} !98 = !{!"FuncMDMap[0]", void (half addrspace(1)*, half addrspace(1)*, i8 addrspace(1)*, i32, i32, i32, i32, i32, i32, <8 x i32>, <8 x i32>, i16, i16, i16, i8*, i32, i32, i32)* @matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c} !99 = !{!"FuncMDValue[0]", !100, !101, !105, !106, !107, !128, !155, !156, !157, !158, !159, !160, !161, !162, !163, !164, !165, !166, !167, !168, !169, !170, !180, !190, !200, !210, !220, !230, !231} !100 = !{!"localOffsets"} !101 = !{!"workGroupWalkOrder", !102, !103, !104} !102 = !{!"dim0", i32 0} !103 = !{!"dim1", i32 0} !104 = !{!"dim2", i32 0} !105 = !{!"funcArgs"} !106 = !{!"functionType", !"KernelFunction"} !107 = !{!"rtInfo", !108, !109, !110, !111, !112, !113, !114, !115, !116, !117, !118, !119, !120, !121, !122, !123, !127} !108 = !{!"callableShaderType", !"NumberOfCallableShaderTypes"} !109 = !{!"isContinuation", i1 false} !110 = !{!"hasTraceRayPayload", i1 false} !111 = !{!"hasHitAttributes", i1 false} !112 = !{!"hasCallableData", i1 false} !113 = !{!"ShaderStackSize", i32 0} !114 = !{!"ShaderHash", i64 0} !115 = !{!"ShaderName", !""} !116 = !{!"ParentName", !""} !117 = !{!"SlotNum", i1* null} !118 = !{!"NOSSize", i32 0} !119 = !{!"globalRootSignatureSize", i32 0} !120 = !{!"Entries"} !121 = !{!"SpillUnions"} !122 = !{!"CustomHitAttrSizeInBytes", i32 0} !123 = !{!"Types", !124, !125, !126} !124 = !{!"FrameStartTys"} !125 = !{!"ArgumentTys"} !126 = !{!"FullFrameTys"} !127 = !{!"Aliases"} !128 = !{!"resAllocMD", !129, !130, !131, !132, !154} !129 = !{!"uavsNumType", i32 0} !130 = !{!"srvsNumType", i32 0} !131 = !{!"samplersNumType", i32 0} !132 = !{!"argAllocMDList", !133, !137, !138, !139, !140, !141, !142, !143, !144, !145, !146, !147, !148, !149, !150, !151, !152, !153} !133 = !{!"argAllocMDListVec[0]", !134, !135, !136} !134 = !{!"type", i32 0} !135 = !{!"extensionType", i32 -1} !136 = !{!"indexType", i32 -1} !137 = !{!"argAllocMDListVec[1]", !134, !135, !136} !138 = !{!"argAllocMDListVec[2]", !134, !135, !136} !139 = !{!"argAllocMDListVec[3]", !134, !135, !136} !140 = !{!"argAllocMDListVec[4]", !134, !135, !136} !141 = !{!"argAllocMDListVec[5]", !134, !135, !136} !142 = !{!"argAllocMDListVec[6]", !134, !135, !136} !143 = !{!"argAllocMDListVec[7]", !134, !135, !136} !144 = 
!{!"argAllocMDListVec[8]", !134, !135, !136} !145 = !{!"argAllocMDListVec[9]", !134, !135, !136} !146 = !{!"argAllocMDListVec[10]", !134, !135, !136} !147 = !{!"argAllocMDListVec[11]", !134, !135, !136} !148 = !{!"argAllocMDListVec[12]", !134, !135, !136} !149 = !{!"argAllocMDListVec[13]", !134, !135, !136} !150 = !{!"argAllocMDListVec[14]", !134, !135, !136} !151 = !{!"argAllocMDListVec[15]", !134, !135, !136} !152 = !{!"argAllocMDListVec[16]", !134, !135, !136} !153 = !{!"argAllocMDListVec[17]", !134, !135, !136} !154 = !{!"inlineSamplersMD"} !155 = !{!"maxByteOffsets"} !156 = !{!"IsInitializer", i1 false} !157 = !{!"IsFinalizer", i1 false} !158 = !{!"CompiledSubGroupsNumber", i32 0} !159 = !{!"hasInlineVmeSamplers", i1 false} !160 = !{!"localSize", i32 0} !161 = !{!"localIDPresent", i1 true} !162 = !{!"groupIDPresent", i1 true} !163 = !{!"privateMemoryPerWI", i32 0} !164 = !{!"globalIDPresent", i1 false} !165 = !{!"hasSyncRTCalls", i1 false} !166 = !{!"hasNonKernelArgLoad", i1 false} !167 = !{!"hasNonKernelArgStore", i1 false} !168 = !{!"hasNonKernelArgAtomic", i1 false} !169 = !{!"UserAnnotations"} !170 = !{!"m_OpenCLArgAddressSpaces", !171, !172, !173, !174, !175, !176, !177, !178, !179} !171 = !{!"m_OpenCLArgAddressSpacesVec[0]", i32 1} !172 = !{!"m_OpenCLArgAddressSpacesVec[1]", i32 1} !173 = !{!"m_OpenCLArgAddressSpacesVec[2]", i32 1} !174 = !{!"m_OpenCLArgAddressSpacesVec[3]", i32 0} !175 = !{!"m_OpenCLArgAddressSpacesVec[4]", i32 0} !176 = !{!"m_OpenCLArgAddressSpacesVec[5]", i32 0} !177 = !{!"m_OpenCLArgAddressSpacesVec[6]", i32 0} !178 = !{!"m_OpenCLArgAddressSpacesVec[7]", i32 0} !179 = !{!"m_OpenCLArgAddressSpacesVec[8]", i32 0} !180 = !{!"m_OpenCLArgAccessQualifiers", !181, !182, !183, !184, !185, !186, !187, !188, !189} !181 = !{!"m_OpenCLArgAccessQualifiersVec[0]", !"none"} !182 = !{!"m_OpenCLArgAccessQualifiersVec[1]", !"none"} !183 = !{!"m_OpenCLArgAccessQualifiersVec[2]", !"none"} !184 = !{!"m_OpenCLArgAccessQualifiersVec[3]", !"none"} !185 = !{!"m_OpenCLArgAccessQualifiersVec[4]", !"none"} !186 = !{!"m_OpenCLArgAccessQualifiersVec[5]", !"none"} !187 = !{!"m_OpenCLArgAccessQualifiersVec[6]", !"none"} !188 = !{!"m_OpenCLArgAccessQualifiersVec[7]", !"none"} !189 = !{!"m_OpenCLArgAccessQualifiersVec[8]", !"none"} !190 = !{!"m_OpenCLArgTypes", !191, !192, !193, !194, !195, !196, !197, !198, !199} !191 = !{!"m_OpenCLArgTypesVec[0]", !"half*"} !192 = !{!"m_OpenCLArgTypesVec[1]", !"half*"} !193 = !{!"m_OpenCLArgTypesVec[2]", !"char*"} !194 = !{!"m_OpenCLArgTypesVec[3]", !"int"} !195 = !{!"m_OpenCLArgTypesVec[4]", !"int"} !196 = !{!"m_OpenCLArgTypesVec[5]", !"int"} !197 = !{!"m_OpenCLArgTypesVec[6]", !"int"} !198 = !{!"m_OpenCLArgTypesVec[7]", !"int"} !199 = !{!"m_OpenCLArgTypesVec[8]", !"int"} !200 = !{!"m_OpenCLArgBaseTypes", !201, !202, !203, !204, !205, !206, !207, !208, !209} !201 = !{!"m_OpenCLArgBaseTypesVec[0]", !"half*"} !202 = !{!"m_OpenCLArgBaseTypesVec[1]", !"half*"} !203 = !{!"m_OpenCLArgBaseTypesVec[2]", !"char*"} !204 = !{!"m_OpenCLArgBaseTypesVec[3]", !"int"} !205 = !{!"m_OpenCLArgBaseTypesVec[4]", !"int"} !206 = !{!"m_OpenCLArgBaseTypesVec[5]", !"int"} !207 = !{!"m_OpenCLArgBaseTypesVec[6]", !"int"} !208 = !{!"m_OpenCLArgBaseTypesVec[7]", !"int"} !209 = !{!"m_OpenCLArgBaseTypesVec[8]", !"int"} !210 = !{!"m_OpenCLArgTypeQualifiers", !211, !212, !213, !214, !215, !216, !217, !218, !219} !211 = !{!"m_OpenCLArgTypeQualifiersVec[0]", !""} !212 = !{!"m_OpenCLArgTypeQualifiersVec[1]", !""} !213 = !{!"m_OpenCLArgTypeQualifiersVec[2]", !""} !214 = 
!{!"m_OpenCLArgTypeQualifiersVec[3]", !""} !215 = !{!"m_OpenCLArgTypeQualifiersVec[4]", !""} !216 = !{!"m_OpenCLArgTypeQualifiersVec[5]", !""} !217 = !{!"m_OpenCLArgTypeQualifiersVec[6]", !""} !218 = !{!"m_OpenCLArgTypeQualifiersVec[7]", !""} !219 = !{!"m_OpenCLArgTypeQualifiersVec[8]", !""} !220 = !{!"m_OpenCLArgNames", !221, !222, !223, !224, !225, !226, !227, !228, !229} !221 = !{!"m_OpenCLArgNamesVec[0]", !""} !222 = !{!"m_OpenCLArgNamesVec[1]", !""} !223 = !{!"m_OpenCLArgNamesVec[2]", !""} !224 = !{!"m_OpenCLArgNamesVec[3]", !""} !225 = !{!"m_OpenCLArgNamesVec[4]", !""} !226 = !{!"m_OpenCLArgNamesVec[5]", !""} !227 = !{!"m_OpenCLArgNamesVec[6]", !""} !228 = !{!"m_OpenCLArgNamesVec[7]", !""} !229 = !{!"m_OpenCLArgNamesVec[8]", !""} !230 = !{!"m_OpenCLArgScalarAsPointers"} !231 = !{!"m_OptsToDisablePerFunc", !232, !233, !234} !232 = !{!"m_OptsToDisablePerFuncSet[0]", !"IGC-AddressArithmeticSinking"} !233 = !{!"m_OptsToDisablePerFuncSet[1]", !"IGC-AllowSimd32Slicing"} !234 = !{!"m_OptsToDisablePerFuncSet[2]", !"IGC-SinkLoadOpt"} !235 = !{!"pushInfo", !236, !237, !238, !241, !242, !243, !244, !245, !246, !247, !248, !261, !262, !263, !264} !236 = !{!"pushableAddresses"} !237 = !{!"bindlessPushInfo"} !238 = !{!"dynamicBufferInfo", !239, !240} !239 = !{!"firstIndex", i32 0} !240 = !{!"numOffsets", i32 0} !241 = !{!"MaxNumberOfPushedBuffers", i32 0} !242 = !{!"inlineConstantBufferSlot", i32 -1} !243 = !{!"inlineConstantBufferOffset", i32 -1} !244 = !{!"inlineConstantBufferGRFOffset", i32 -1} !245 = !{!"constants"} !246 = !{!"inputs"} !247 = !{!"constantReg"} !248 = !{!"simplePushInfoArr", !249, !258, !259, !260} !249 = !{!"simplePushInfoArrVec[0]", !250, !251, !252, !253, !254, !255, !256, !257} !250 = !{!"cbIdx", i32 0} !251 = !{!"pushableAddressGrfOffset", i32 -1} !252 = !{!"pushableOffsetGrfOffset", i32 -1} !253 = !{!"offset", i32 0} !254 = !{!"size", i32 0} !255 = !{!"isStateless", i1 false} !256 = !{!"isBindless", i1 false} !257 = !{!"simplePushLoads"} !258 = !{!"simplePushInfoArrVec[1]", !250, !251, !252, !253, !254, !255, !256, !257} !259 = !{!"simplePushInfoArrVec[2]", !250, !251, !252, !253, !254, !255, !256, !257} !260 = !{!"simplePushInfoArrVec[3]", !250, !251, !252, !253, !254, !255, !256, !257} !261 = !{!"simplePushBufferUsed", i32 0} !262 = !{!"pushAnalysisWIInfos"} !263 = !{!"inlineRTGlobalPtrOffset", i32 0} !264 = !{!"rtSyncSurfPtrOffset", i32 0} !265 = !{!"psInfo", !266, !267, !268, !269, !270, !271, !272, !273, !274, !275, !276, !277, !278, !279, !280} !266 = !{!"BlendStateDisabledMask", i8 0} !267 = !{!"SkipSrc0Alpha", i1 false} !268 = !{!"DualSourceBlendingDisabled", i1 false} !269 = !{!"ForceEnableSimd32", i1 false} !270 = !{!"outputDepth", i1 false} !271 = !{!"outputStencil", i1 false} !272 = !{!"outputMask", i1 false} !273 = !{!"blendToFillEnabled", i1 false} !274 = !{!"forceEarlyZ", i1 false} !275 = !{!"hasVersionedLoop", i1 false} !276 = !{!"forceSingleSourceRTWAfterDualSourceRTW", i1 false} !277 = !{!"NumSamples", i8 0} !278 = !{!"blendOptimizationMode"} !279 = !{!"colorOutputMask"} !280 = !{!"WaDisableVRS", i1 false} !281 = !{!"csInfo", !282, !283, !284, !285, !286, !38, !39, !287, !288, !289, !290, !291, !292, !293, !294, !68, !295, !296, !297} !282 = !{!"maxWorkGroupSize", i32 0} !283 = !{!"waveSize", i32 0} !284 = !{!"ComputeShaderSecondCompile"} !285 = !{!"forcedSIMDSize", i8 0} !286 = !{!"forceTotalGRFNum", i32 0} !287 = !{!"allowLowerSimd", i1 false} !288 = !{!"disableSimd32Slicing", i1 false} !289 = !{!"disableSplitOnSpill", i1 false} !290 = 
!{!"forcedVISAPreRAScheduler", i1 false} !291 = !{!"disableLocalIdOrderOptimizations", i1 false} !292 = !{!"disableDispatchAlongY", i1 false} !293 = !{!"neededThreadIdLayout", i1* null} !294 = !{!"forceTileYWalk", i1 false} !295 = !{!"walkOrderEnabled", i1 false} !296 = !{!"walkOrderOverride", i32 0} !297 = !{!"ResForHfPacking"} !298 = !{!"msInfo", !299, !300, !301, !302, !303, !304, !305, !306, !307} !299 = !{!"PrimitiveTopology", i32 3} !300 = !{!"MaxNumOfPrimitives", i32 0} !301 = !{!"MaxNumOfVertices", i32 0} !302 = !{!"MaxNumOfPerPrimitiveOutputs", i32 0} !303 = !{!"MaxNumOfPerVertexOutputs", i32 0} !304 = !{!"WorkGroupSize", i32 0} !305 = !{!"WorkGroupMemorySizeInBytes", i32 0} !306 = !{!"IndexFormat", i32 6} !307 = !{!"SubgroupSize", i32 0} !308 = !{!"taskInfo", !309, !304, !305, !307} !309 = !{!"MaxNumOfOutputs", i32 0} !310 = !{!"NBarrierCnt", i32 0} !311 = !{!"rtInfo", !312, !313, !314, !315, !316, !317, !318, !319, !320, !321, !322, !323} !312 = !{!"RayQueryAllocSizeInBytes", i32 0} !313 = !{!"NumContinuations", i32 0} !314 = !{!"RTAsyncStackAddrspace", i32 -1} !315 = !{!"RTAsyncStackSurfaceStateOffset", i1* null} !316 = !{!"SWHotZoneAddrspace", i32 -1} !317 = !{!"SWHotZoneSurfaceStateOffset", i1* null} !318 = !{!"SWStackAddrspace", i32 -1} !319 = !{!"SWStackSurfaceStateOffset", i1* null} !320 = !{!"RTSyncStackAddrspace", i32 -1} !321 = !{!"RTSyncStackSurfaceStateOffset", i1* null} !322 = !{!"doSyncDispatchRays", i1 false} !323 = !{!"MemStyle", !"Xe"} !324 = !{!"CurUniqueIndirectIdx", i32 0} !325 = !{!"inlineDynTextures"} !326 = !{!"inlineResInfoData"} !327 = !{!"immConstant", !328, !329, !330} !328 = !{!"data"} !329 = !{!"sizes"} !330 = !{!"zeroIdxs"} !331 = !{!"stringConstants"} !332 = !{!"inlineConstantBuffers"} !333 = !{!"inlineGlobalBuffers"} !334 = !{!"GlobalPointerProgramBinaryInfos"} !335 = !{!"ConstantPointerProgramBinaryInfos"} !336 = !{!"GlobalBufferAddressRelocInfo"} !337 = !{!"ConstantBufferAddressRelocInfo"} !338 = !{!"forceLscCacheList"} !339 = !{!"SrvMap"} !340 = !{!"RasterizerOrderedByteAddressBuffer"} !341 = !{!"RasterizerOrderedViews"} !342 = !{!"MinNOSPushConstantSize", i32 0} !343 = !{!"inlineProgramScopeOffsets"} !344 = !{!"shaderData", !345} !345 = !{!"numReplicas", i32 0} !346 = !{!"URBInfo", !347, !348, !349} !347 = !{!"has64BVertexHeaderInput", i1 false} !348 = !{!"has64BVertexHeaderOutput", i1 false} !349 = !{!"hasVertexHeader", i1 true} !350 = !{!"UseBindlessImage", i1 false} !351 = !{!"enableRangeReduce", i1 false} !352 = !{!"allowMatchMadOptimizationforVS", i1 false} !353 = !{!"disableMatchMadOptimizationForCS", i1 false} !354 = !{!"disableMemOptforNegativeOffsetLoads", i1 false} !355 = !{!"enableThreeWayLoadSpiltOpt", i1 false} !356 = !{!"statefulResourcesNotAliased", i1 false} !357 = !{!"disableMixMode", i1 false} !358 = !{!"genericAccessesResolved", i1 false} !359 = !{!"PrivateMemoryPerFG"} !360 = !{!"m_OptsToDisable"} !361 = !{!"capabilities", !362} !362 = !{!"globalVariableDecorationsINTEL", i1 false} !363 = !{!"m_ShaderResourceViewMcsMask", !364, !365} !364 = !{!"m_ShaderResourceViewMcsMaskVec[0]", i64 0} !365 = !{!"m_ShaderResourceViewMcsMaskVec[1]", i64 0} !366 = !{!"computedDepthMode", i32 0} !367 = !{!"isHDCFastClearShader", i1 false} !368 = !{i32 2, i32 0} !369 = !{!"clang version 14.0.5"} !370 = distinct !DISubprogram(name: "matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c", linkageName: "matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c", scope: null, file: !4, line: 55, type: !371, scopeLine: 55, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: 
!3, templateParams: !373, retainedNodes: !373) !371 = !DISubroutineType(types: !372) !372 = !{null} !373 = !{} !374 = !DILocation(line: 42, column: 22, scope: !375, inlinedAt: !377) !375 = !DILexicalBlockFile(scope: !370, file: !376, discriminator: 0) !376 = !DIFile(filename: "standard.py", directory: "/home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/language") !377 = distinct !DILocation(line: 64, scope: !370) !378 = !DILocation(line: 42, column: 28, scope: !375, inlinedAt: !377) !379 = !DILocation(line: 42, column: 22, scope: !375, inlinedAt: !380) !380 = distinct !DILocation(line: 65, scope: !370) !381 = !DILocation(line: 42, column: 28, scope: !375, inlinedAt: !380) !382 = !DILocation(line: 66, column: 38, scope: !370) !383 = !DILocation(line: 67, column: 22, scope: !370) !384 = !DILocation(line: 68, column: 29, scope: !370) !385 = !DILocation(line: 69, column: 35, scope: !370) !386 = !DILocation(line: 136, column: 26, scope: !375, inlinedAt: !387) !387 = distinct !DILocation(line: 69, scope: !370) !388 = !DILocation(line: 70, column: 33, scope: !370) !389 = !DILocation(line: 70, column: 27, scope: !370) !390 = !DILocation(line: 71, column: 19, scope: !370) !391 = !DILocation(line: 71, column: 40, scope: !370) !392 = !DILocation(line: 72, column: 29, scope: !370) !393 = !DILocation(line: 73, column: 29, scope: !370) !394 = !DILocation(line: 82, column: 20, scope: !370) !395 = !DILocation(line: 83, column: 20, scope: !370) !396 = !DILocation(line: 81, column: 25, scope: !370) !397 = !DILocation(line: 92, column: 26, scope: !370) !398 = !DILocation(line: 85, column: 44, scope: !370) !399 = !DILocation(line: 86, column: 44, scope: !370) !400 = !DILocation(line: 84, column: 33, scope: !370) !401 = !DILocation(line: 92, column: 4, scope: !370) ```
Disassemble of the simple GEMM kernel ``` //.kernel matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c //.platform PVCXT //.thread_config numGRF=128, numAcc=4, numSWSB=16 //.options_string "-emitCrossThreadOffR0Reloc " //.full_options "-emitLocation -enableCoalesceScalarMoves -hasRNEandDenorm -noStitchExternFunc -emitCrossThreadOffR0Reloc -linker 63 -preserver0 -abortOnSpill 4 -enableBundleCR 3 -boundsChecking -presched-rp 100 -nodpsendreorder -SBIDDepLoc -output -binary -dumpcommonisa -dumpcombinedcisa -dumpvisa -printHexFloatInAsm -noverifyCISA -enableHalfLSC -partialInt64 -generateDebugInfo " //.instCount 389 //.RA type GRAPH_COLORING_FF_BC_RA //.git-hash 80055716ef5675b711e3f3198c61b0ff7dc7208d //.declare BuiltInR0 (0) rf=r size=64 type=ud align=32 words (r0.0) IsBuiltin //.declare (1) rf=r size=64 type=ud alias=BuiltInR0+0 align=32 words (r0.0) IsBuiltin //.declare BuiltinA0 (2) rf=a size=4 type=ud align=1 words (a0.0) IsBuiltin //.declare BuiltinA0Dot2 (3) rf=a size=4 type=ud align=1 words (a0.2) IsBuiltin //.declare %null (9) rf=r size=4 type=ud align=32 words //.declare %local_id_x (12) rf=r size=4 type=ud align=2 words (r1.10) //.declare %local_id_y (13) rf=r size=4 type=ud align=2 words (r1.11) //.declare %local_size_x (14) rf=r size=4 type=ud align=2 words (r1.6) //.declare %local_size_y (15) rf=r size=4 type=ud align=2 words (r1.7) //.declare %group_id_x (16) rf=r size=4 type=ud align=2 words (r0.1) //.declare %group_id_y (17) rf=r size=4 type=ud align=2 words (r0.6) //.declare %group_id_z (18) rf=r size=4 type=ud align=2 words (r0.7) //.declare %group_count_x (19) rf=r size=4 type=ud align=2 words (r1.8) //.declare %group_count_y (20) rf=r size=4 type=ud align=2 words (r1.9) //.declare %tsc (21) rf=r size=20 type=ud align=2 words //.declare %arg (22) rf=r size=0 type=ud align=32 words (r26.0) //.declare %retval (23) rf=r size=0 type=ud align=32 words (r26.0) Output //.declare %sp (24) rf=r size=8 type=uq align=32 words (r125.3) //.declare %fp (25) rf=r size=8 type=uq align=32 words (r125.2) //.declare %sr0 (26) rf=r size=16 type=ud align=2 words //.declare %cr0 (27) rf=r size=12 type=ud align=2 words //.declare %ce0 (28) rf=r size=4 type=ud align=2 words //.declare %dbg0 (29) rf=r size=8 type=ud align=2 words //.declare implBufPtr (31) rf=r size=8 type=uq align=32 words (r126.0) //.declare localIdBufPtr (32) rf=r size=8 type=uq align=32 words (r126.3) //.declare %msg0 (33) rf=r size=12 type=ud align=2 words //.declare V0033 (41) rf=r size=64 type=d alias=+0 align=32 words (r0.0) //.declare V0034 (42) rf=r size=8 type=uq align=4 words (r4.4) //.declare V0035 (43) rf=r size=8 type=uq align=4 words (r4.5) //.declare V0036 (44) rf=r size=8 type=uq align=4 words (r4.6) //.declare V0037 (45) rf=r size=4 type=d align=2 words (r4.14) //.declare V0038 (46) rf=r size=4 type=d align=2 words (r4.15) //.declare V0039 (47) rf=r size=4 type=d align=2 words (r5.0) //.declare V0040 (48) rf=r size=4 type=d align=2 words (r5.1) //.declare V0041 (49) rf=r size=4 type=d align=2 words (r5.2) //.declare V0042 (50) rf=r size=4 type=d align=2 words (r5.3) //.declare V0044 (52) rf=r size=32 type=d alias=+0 align=32 words (r0.0) //.declare V0045 (53) rf=r size=32 type=d align=16 words (r4.0) //.declare V0046 (54) rf=r size=32 type=w align=16 words (r1.0) //.declare V0047 (55) rf=r size=32 type=w align=16 words (r2.0) //.declare V0048 (56) rf=r size=32 type=w align=16 words (r3.0) //.declare V0049 (57) rf=r size=8 type=uq align=4 words (r5.2) //.declare V0053 (61) rf=r size=4 type=d align=2 words (r1.14) 
//.declare P01 (62) rf=f16 size=2 type=uw align=1 words (f1.0) //.declare V0054 (63) rf=r size=4 type=d align=2 words (r1.8) //.declare P02 (64) rf=f16 size=2 type=uw align=1 words (f0.1) //.declare V0055 (65) rf=r size=4 type=d align=2 words (r1.15) //.declare V0056 (66) rf=r size=4 type=d align=2 words (r1.8) //.declare V0057 (67) rf=r size=4 type=d align=2 words (r2.7) //.declare P03 (68) rf=f16 size=2 type=uw align=1 words (f2.1) //.declare V0058 (69) rf=r size=4 type=d align=2 words (r2.6) //.declare V0059 (70) rf=r size=4 type=d align=2 words (r1.12) //.declare V0060 (71) rf=r size=4 type=d align=2 words (r2.2) //.declare V0061 (72) rf=r size=4 type=d align=2 words (r1.8) //.declare V0062 (73) rf=r size=4 type=d align=2 words (r1.9) //.declare V0063 (74) rf=r size=4 type=d align=2 words (r1.8) //.declare V0064 (75) rf=r size=4 type=d align=2 words (r1.13) //.declare V0065 (76) rf=r size=4 type=f align=2 words (r1.8) //.declare V0066 (77) rf=r size=4 type=ud alias=V0062+0 align=2 words (r1.9) //.declare V0067 (78) rf=r size=4 type=f align=2 words (r1.10) //.declare V0068 (79) rf=r size=8 type=df align=4 words (r1.4) //.declare V0069 (80) rf=r size=8 type=df align=4 words (r2.0) //.declare V0070 (81) rf=r size=8 type=df align=4 words (r1.5) //.declare V0071 (82) rf=r size=8 type=df align=4 words (r2.2) //.declare V0072 (83) rf=r size=8 type=df align=4 words (r1.4) //.declare V0073 (84) rf=r size=4 type=ud alias=V0064+0 align=2 words (r1.13) //.declare V0074 (85) rf=r size=8 type=df align=4 words (r2.0) //.declare V0075 (86) rf=r size=8 type=df align=4 words (r1.4) //.declare V0076 (87) rf=r size=4 type=d align=4 words (r1.10) //.declare V0077 (88) rf=r size=4 type=ud alias=V0076+0 align=2 words (r1.10) //.declare V0078 (89) rf=r size=4 type=d align=2 words (r1.8) //.declare V0079 (90) rf=r size=4 type=d align=2 words (r1.8) //.declare V0080 (91) rf=r size=4 type=d align=2 words (r1.8) //.declare V0081 (92) rf=r size=4 type=d align=2 words (r1.8) //.declare V0082 (93) rf=r size=4 type=d alias=+0 align=2 words (r1.12) //.declare P04 (94) rf=f16 size=2 type=uw align=1 words (f2.0) //.declare V0083 (95) rf=r size=4 type=d align=2 words (r2.5) //.declare V0084 (96) rf=r size=4 type=d align=2 words (r1.15) //.declare V0085 (97) rf=r size=4 type=d align=2 words (r2.0) //.declare V0086 (98) rf=r size=4 type=d align=2 words (r1.8) //.declare V0087 (99) rf=r size=4 type=d align=2 words (r2.4) //.declare V0088 (100) rf=r size=4 type=f align=2 words (r1.8) //.declare V0089 (101) rf=r size=4 type=ud alias=V0085+0 align=2 words (r2.0) //.declare V0090 (102) rf=r size=4 type=f align=2 words (r2.2) //.declare V0091 (103) rf=r size=8 type=df align=4 words (r1.4) //.declare V0092 (104) rf=r size=8 type=df align=4 words (r2.0) //.declare V0093 (105) rf=r size=8 type=df align=4 words (r1.5) //.declare V0094 (106) rf=r size=8 type=df align=4 words (r2.1) //.declare V0095 (107) rf=r size=8 type=df align=4 words (r5.2) //.declare V0096 (108) rf=r size=4 type=ud alias=V0087+0 align=2 words (r2.4) //.declare V0097 (109) rf=r size=8 type=df align=4 words (r2.0) //.declare V0098 (110) rf=r size=8 type=df align=4 words (r5.2) //.declare V0099 (111) rf=r size=4 type=d align=4 words (r2.2) //.declare V0100 (112) rf=r size=4 type=ud alias=V0099+0 align=2 words (r2.2) //.declare V0101 (113) rf=r size=4 type=d align=32 words (r2.0) //.declare V0102 (114) rf=r size=4 type=d align=2 words (r1.8) //.declare V0103 (115) rf=r size=4 type=d align=2 words (r2.5) //.declare V0104 (116) rf=r size=4 type=d alias=+4 align=2 
words (r1.13) //.declare V0105 (117) rf=r size=4 type=d align=2 words (r5.4) //.declare V0106 (118) rf=r size=4 type=d align=2 words (r1.15) //.declare V0107 (119) rf=r size=4 type=d align=2 words (r1.8) //.declare V0108 (120) rf=r size=4 type=d align=2 words (r1.14) //.declare V0109 (121) rf=r size=4 type=d align=2 words (r1.8) //.declare V0110 (122) rf=r size=4 type=d align=2 words (r2.4) //.declare V0111 (123) rf=r size=4 type=f align=2 words (r1.8) //.declare V0112 (124) rf=r size=4 type=ud alias=V0108+0 align=2 words (r1.14) //.declare V0113 (125) rf=r size=4 type=f align=2 words (r2.2) //.declare V0114 (126) rf=r size=8 type=df align=4 words (r1.4) //.declare V0115 (127) rf=r size=8 type=df align=4 words (r2.0) //.declare V0116 (128) rf=r size=8 type=df align=4 words (r1.5) //.declare V0117 (129) rf=r size=8 type=df align=4 words (r2.1) //.declare V0118 (130) rf=r size=8 type=df align=4 words (r5.2) //.declare V0119 (131) rf=r size=4 type=ud alias=V0110+0 align=2 words (r2.4) //.declare V0120 (132) rf=r size=8 type=df align=4 words (r2.0) //.declare V0121 (133) rf=r size=8 type=df align=4 words (r5.2) //.declare V0122 (134) rf=r size=4 type=d align=4 words (r5.4) //.declare V0123 (135) rf=r size=4 type=ud alias=V0122+0 align=2 words (r5.4) //.declare V0124 (136) rf=r size=4 type=d align=32 words (r2.0) //.declare V0125 (137) rf=r size=4 type=d align=2 words (r1.8) //.declare V0126 (138) rf=r size=4 type=d align=2 words (r117.2) //.declare V0127 (139) rf=r size=4 type=d alias=+0 align=2 words (r2.0) //.declare V0128 (140) rf=r size=4 type=d alias=+4 align=2 words (r2.1) //.declare V0129 (141) rf=r size=4 type=d align=2 words (r1.8) //.declare V0130 (142) rf=r size=4 type=d align=2 words (r1.9) //.declare V0131 (143) rf=r size=4 type=d align=2 words (r1.8) //.declare V0132 (144) rf=r size=4 type=d align=2 words (r1.14) //.declare V0133 (145) rf=r size=4 type=f align=2 words (r1.8) //.declare V0134 (146) rf=r size=4 type=ud alias=V0130+0 align=2 words (r1.9) //.declare V0135 (147) rf=r size=4 type=f align=2 words (r1.10) //.declare V0136 (148) rf=r size=8 type=df align=4 words (r1.4) //.declare V0137 (149) rf=r size=8 type=df align=4 words (r2.1) //.declare V0138 (150) rf=r size=8 type=df align=4 words (r1.5) //.declare V0139 (151) rf=r size=8 type=df align=4 words (r1.6) //.declare V0140 (152) rf=r size=8 type=df align=4 words (r1.4) //.declare V0141 (153) rf=r size=4 type=ud alias=V0132+0 align=2 words (r1.14) //.declare V0142 (154) rf=r size=8 type=df align=4 words (r2.1) //.declare V0143 (155) rf=r size=8 type=df align=4 words (r1.4) //.declare V0144 (156) rf=r size=4 type=d align=4 words (r1.10) //.declare V0145 (157) rf=r size=4 type=ud alias=V0144+0 align=2 words (r1.10) //.declare V0146 (158) rf=r size=4 type=d align=2 words (r1.8) //.declare V0147 (159) rf=r size=4 type=d align=2 words (r1.8) //.declare V0148 (160) rf=r size=4 type=d align=2 words (r1.8) //.declare V0149 (161) rf=r size=4 type=d align=2 words (r6.0) //.declare V0150 (162) rf=r size=64 type=d align=32 words (r9.0) //.declare V0151 (163) rf=r size=32 type=uw alias=V0046+0 align=32 words (r1.0) //.declare V0152 (164) rf=r size=64 type=d align=32 words (r119.0) //.declare V0153 (165) rf=r size=64 type=d align=32 words (r2.0) //.declare V0154 (166) rf=r size=4 type=d align=2 words (r2.0) //.declare V0155 (167) rf=r size=64 type=ud alias=V0150+0 align=32 words (r9.0) //.declare V0156 (168) rf=r size=64 type=d align=32 words (r1.0) //.declare V0157 (169) rf=r size=4 type=d align=2 words (r5.4) //.declare V0158 (170) 
rf=r size=4 type=d alias=+4 align=2 words (r117.1) //.declare P05 (171) rf=f16 size=2 type=uw align=1 words (f0.0) //.declare V0159 (172) rf=r size=64 type=f align=32 words (r118.0) //.declare V0160 (173) rf=r size=64 type=f align=32 words (r111.0) //.declare V0161 (174) rf=r size=64 type=f align=32 words (r110.0) //.declare V0162 (175) rf=r size=64 type=f align=32 words (r109.0) //.declare V0163 (176) rf=r size=64 type=f align=32 words (r108.0) //.declare V0164 (177) rf=r size=64 type=f align=32 words (r107.0) //.declare V0165 (178) rf=r size=64 type=f align=32 words (r106.0) //.declare V0166 (179) rf=r size=64 type=f align=32 words (r105.0) //.declare V0167 (180) rf=r size=64 type=f align=32 words (r48.0) //.declare V0168 (181) rf=r size=64 type=f align=32 words (r116.0) //.declare V0169 (182) rf=r size=64 type=f align=32 words (r115.0) //.declare V0170 (183) rf=r size=64 type=f align=32 words (r114.0) //.declare V0171 (184) rf=r size=64 type=f align=32 words (r113.0) //.declare V0172 (185) rf=r size=64 type=f align=32 words (r112.0) //.declare V0173 (186) rf=r size=64 type=f align=32 words (r46.0) //.declare V0174 (187) rf=r size=64 type=f align=32 words (r3.0) //.declare V0175 (188) rf=r size=4 type=d align=2 words (r5.4) //.declare V0176 (189) rf=r size=4 type=d alias=+0 align=2 words (r117.0) //.declare V0177 (190) rf=r size=4 type=d align=2 words (r5.1) //.declare V0178 (191) rf=r size=4 type=d align=2 words (r5.6) //.declare V0179 (192) rf=r size=4 type=d align=2 words (r5.1) //.declare V0180 (193) rf=r size=4 type=d alias=+0 align=2 words (r5.8) //.declare V0181 (194) rf=r size=4 type=d alias=+4 align=2 words (r5.9) //.declare V0182 (195) rf=r size=4 type=d align=2 words (r5.1) //.declare V0183 (196) rf=r size=4 type=d align=2 words (r5.1) //.declare V0184 (197) rf=r size=8 type=q alias=V0034+0 align=32 words (r4.4) //.declare V0185 (198) rf=r size=8 type=q alias=V0035+0 align=32 words (r4.5) //.declare V0186 (199) rf=r size=64 type=d align=32 words (r47.0) //.declare V0187 (200) rf=r size=4 type=d alias=+0 align=2 words (r5.4) //.declare V0188 (201) rf=r size=4 type=d alias=+4 align=2 words (r5.5) //.declare V0189 (202) rf=r size=4 type=d align=2 words (r5.2) //.declare V0190 (203) rf=r size=256 type=w align=32 words (r61.0) //.declare (204) rf=r size=64 type=ud align=32 words (r6.0) //.declare (205) rf=r size=64 type=uq alias=+0 align=32 words (r6.0) //.declare V0191 (206) rf=r size=4 type=d align=2 words (r5.7) //.declare V0192 (207) rf=r size=256 type=w align=32 words (r53.0) //.declare (208) rf=r size=64 type=ud align=32 words (r6.0) //.declare (209) rf=r size=64 type=uq alias=+0 align=32 words (r6.0) //.declare V0193 (210) rf=r size=4 type=d align=2 words (r5.7) //.declare V0194 (211) rf=r size=256 type=w align=32 words (r57.0) //.declare (212) rf=r size=64 type=ud align=32 words (r6.0) //.declare (213) rf=r size=64 type=uq alias=+0 align=32 words (r6.0) //.declare V0195 (214) rf=r size=4 type=d align=2 words (r5.7) //.declare V0196 (215) rf=r size=256 type=w align=32 words (r49.0) //.declare (216) rf=r size=64 type=ud align=32 words (r6.0) //.declare (217) rf=r size=64 type=uq alias=+0 align=32 words (r6.0) //.declare V0197 (218) rf=r size=512 type=d align=32 words (r89.0) //.declare (219) rf=r size=64 type=ud align=32 words (r6.0) //.declare (220) rf=r size=64 type=uq alias=+0 align=32 words (r6.0) //.declare V0198 (221) rf=r size=4 type=d align=2 words (r5.7) //.declare V0199 (222) rf=r size=512 type=d align=32 words (r97.0) //.declare (223) rf=r size=64 type=ud align=32 
words (r6.0) //.declare (224) rf=r size=64 type=uq alias=+0 align=32 words (r6.0) //.declare V0200 (225) rf=r size=4 type=d align=2 words (r5.10) //.declare V0201 (226) rf=r size=512 type=d align=32 words (r81.0) //.declare (227) rf=r size=64 type=ud align=32 words (r6.0) //.declare (228) rf=r size=64 type=uq alias=+0 align=32 words (r6.0) //.declare V0202 (229) rf=r size=4 type=d align=2 words (r5.11) //.declare V0203 (230) rf=r size=512 type=d align=32 words (r73.0) //.declare (231) rf=r size=64 type=ud align=32 words (r6.0) //.declare (232) rf=r size=64 type=uq alias=+0 align=32 words (r6.0) //.declare V0204 (233) rf=r size=512 type=d align=32 words (r65.0) //.declare (234) rf=r size=64 type=ud align=32 words (r6.0) //.declare (235) rf=r size=64 type=uq alias=+0 align=32 words (r6.0) //.declare V0205 (236) rf=r size=512 type=d align=32 words (r30.0) //.declare (237) rf=r size=64 type=ud align=32 words (r6.0) //.declare (238) rf=r size=64 type=uq alias=+0 align=32 words (r6.0) //.declare V0206 (239) rf=r size=512 type=d align=32 words (r22.0) //.declare (240) rf=r size=64 type=ud align=32 words (r6.0) //.declare (241) rf=r size=64 type=uq alias=+0 align=32 words (r6.0) //.declare V0207 (242) rf=r size=512 type=d align=32 words (r14.0) //.declare (243) rf=r size=64 type=ud align=32 words (r6.0) //.declare (244) rf=r size=64 type=uq alias=+0 align=32 words (r6.0) //.declare V0208 (245) rf=r size=512 type=f align=32 words (r6.0) //.declare V0209 (246) rf=r size=512 type=f align=32 words (r38.0) //.declare V0210 (247) rf=r size=512 type=f align=32 words (r6.0) //.declare V0211 (248) rf=r size=256 type=ud alias=V0190+0 align=32 words (r61.0) //.declare V0212 (249) rf=r size=512 type=f align=32 words (r38.0) //.declare V0213 (250) rf=r size=256 type=ud alias=V0192+0 align=32 words (r53.0) //.declare V0214 (251) rf=r size=256 type=ud alias=V0194+0 align=32 words (r57.0) //.declare V0215 (252) rf=r size=256 type=ud alias=V0196+0 align=32 words (r49.0) //.declare P06 (253) rf=f16 size=2 type=uw align=1 words (f1.1) //.declare V0216 (254) rf=r size=4 type=d align=2 words (r4.0) //.declare V0217 (255) rf=r size=4 type=d align=2 words (r4.1) //.declare V0218 (256) rf=r size=4 type=d align=2 words (r4.0) //.declare V0219 (257) rf=r size=4 type=d align=2 words (r4.0) //.declare V0220 (258) rf=r size=512 type=f align=32 words (r5.0) //.declare V0221 (259) rf=r size=512 type=d alias=V0220+0 align=32 words (r5.0) //.declare V0222 (260) rf=r size=8 type=q alias=V0036+0 align=32 words (r4.6) //.declare V0223 (261) rf=r size=64 type=d align=32 words (r13.0) //.declare (262) rf=r size=64 type=ud align=32 words (r5.0) //.declare (263) rf=r size=64 type=uq alias=+0 align=32 words (r5.0) //.declare V0224 (264) rf=r size=64 type=d align=32 words (r2.0) //.declare V0225 (265) rf=r size=64 type=d align=32 words (r14.0) //.declare V0226 (266) rf=r size=512 type=f align=32 words (r5.0) //.declare V0227 (267) rf=r size=512 type=d alias=V0226+0 align=32 words (r5.0) //.declare (269) rf=r size=64 type=ud align=32 words (r2.0) //.declare (270) rf=r size=64 type=uq alias=+0 align=32 words (r2.0) //.declare (271) rf=r size=64 type=ud align=32 words (r112.0) //.declare (272) rf=r size=8 type=d align=8 words (r2.0) //.declare (273) rf=r size=8 type=d align=8 words (r1.12) //.declare (274) rf=r size=8 type=d align=8 words (r5.4) //.declare (275) rf=r size=8 type=d align=8 words (r117.0) //.declare (276) rf=r size=8 type=d align=8 words (r5.8) //.declare (277) rf=r size=8 type=df align=4 words (r1.4) //.declare (278) rf=r 
size=8 type=df align=4 words (r1.4) //.declare (279) rf=r size=4 type=d align=2 words (r2.0) //.declare (280) rf=r size=8 type=df align=4 words (r1.4) //.declare (281) rf=r size=8 type=df align=4 words (r1.4) //.declare r0 (400) rf=r size=64 type=ud align=32 words (r0.0) //.declare rtmp (401) rf=r size=64 type=ud align=32 words (r127.0) //.declare (402) rf=r size=128 type=ud align=32 words (r1.0) //.declare (403) rf=r size=4 type=ud align=2 words (r126.0) //.declare (404) rf=r size=64 type=ud align=32 words (r3.0) //.declare (405) rf=r size=64 type=ud align=32 words (r4.0) //.declare (406) rf=r size=4 type=ud align=2 words (r126.0) //.declare (407) rf=r size=32 type=ud align=2 words (r5.0) // .inputs // +----------+----------+--------+----------+------------------+ // | id | type | bytes | at | from | // +----------+----------+--------+----------+------------------+ // | V0046 | :w x 16 | 0x20 | r1 | pti[tid]+0x0 | // | V0047 | :w x 16 | 0x20 | r2 | pti[tid]+0x40 | // | V0048 | :w x 16 | 0x20 | r3 | pti[tid]+0x80 | // | V0045 | :d x 8 | 0x20 | r4 | cti+0x0 | // | V0034 | :uq | 0x8 | r4+0x20 | cti+0x20 | // | V0035 | :uq | 0x8 | r4+0x28 | cti+0x28 | // | V0036 | :uq | 0x8 | r4+0x30 | cti+0x30 | // | V0037 | :d | 0x4 | r4+0x38 | cti+0x38 | // | V0038 | :d | 0x4 | r4+0x3C | cti+0x3C | // | V0039 | :d | 0x4 | r5 | cti+0x40 | // | V0040 | :d | 0x4 | r5+0x4 | cti+0x44 | // | V0041 | :d | 0x4 | r5+0x8 | cti+0x48 | // | V0042 | :d | 0x4 | r5+0xC | cti+0x4C | // | V0049 | :uq | 0x8 | r5+0x10 | cti+0x50 | // +----------+----------+--------+----------+------------------+ // B000: Preds:{}, Succs:{B001} per_thread_prolog: (W) mov (16|M0) r127.0<1>:ud 0x0:ud // ALU pipe: int; (W) and (1|M0) r127.2<1>:ud r0.0<0;1,0>:ud 0xFFFFFFC0:ud // ALU pipe: int; (W) and (1|M0) r127.0<1>:uw r0.4<0;1,0>:uw 0xFF:uw // ALU pipe: int; (W) add (1|M0) r127.2<1>:ud r127.2<0;1,0>:ud 0x60:ud {I@2} // ALU pipe: int; (W) add (1|M0) r127.2<1>:ud r127.2<0;1,0>:ud 0x0:ud {I@1} // R_SYM_ADDR_32: __INTEL_PATCH_CROSS_THREAD_OFFSET_OFF_R0; ALU pipe: int; (W) mad (1|M0) r127.0<1>:ud r127.2<0;0>:ud r127.0<0;0>:uw 0xC0:uw {I@1} // ALU pipe: int; // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/./tutorials/10-experimental-tma-store-matrix-multiplication.py // Line 92: tl.store(c_block_ptr, accumulator) (W) load.ugm.d32x32t.a32.ca.ca (1|M0) r1:2 bti[255][r127:1] {A@1,$0} // ex_desc:0xFF000000; desc:0x6228E500 // (W) add (1|M0) r126.0<1>:ud r127.0<0;1,0>:ud 0x80:uw // ALU pipe: int; // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/./tutorials/10-experimental-tma-store-matrix-multiplication.py // Line 92: tl.store(c_block_ptr, accumulator) (W) load.ugm.d32x16t.a32.ca.ca (1|M0) r3:1 bti[255][r126:1] {A@1,$1} // ex_desc:0xFF000000; desc:0x6218D500 // nop // nop // nop // // B001: Preds:{B000}, Succs:{B002} // cross_thread_prolog: (W) and (1|M0) r127.0<1>:ud r0.0<0;1,0>:ud 0xFFFFFFC0:ud {$0.src} // ALU pipe: int; (W) add (1|M0) r127.0<1>:ud r127.0<0;1,0>:ud 0x0:ud {I@1} // R_SYM_ADDR_32: __INTEL_PATCH_CROSS_THREAD_OFFSET_OFF_R0; ALU pipe: int; // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/./tutorials/10-experimental-tma-store-matrix-multiplication.py // Line 92: tl.store(c_block_ptr, accumulator) (W) load.ugm.d32x16t.a32.ca.ca (1|M0) r4:1 bti[255][r127:1] {I@1,$2} // ex_desc:0xFF000000; desc:0x6218D500 // (W) add (1|M0) r126.0<1>:ud r127.0<0;1,0>:ud 0x40:uw {$1.src} // ALU pipe: int; (W) load.ugm.d32x8t.a32.ca.ca (1|M0) r5:1 bti[255][r126:1] {I@1,$3} // 
ex_desc:0xFF000000; desc:0x6218C500 // // B002: Preds:{B001}, Succs:{B003, B004} // _main: (W) or (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x4C0:uw {Compacted,A@1} // $1 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/language/standard.py // Line 42: return (x + div - 1) // div sync.nop null {Compacted,A@1} // $4 sync.nop null {Compacted,$2.dst} // $4 (W) add (1|M0) r1.14<1>:d r4.15<0;1,0>:d 127:w {A@1,$0.dst} // ALU pipe: int; $4 (W) cmp (16|M0) (lt)f1.0 null<1>:d r1.14<0;1,0>:d 0:w {I@1} // ALU pipe: int; $5 (W&~f1.0) jmpi _0_025 // ALU pipe: int; $6 // B003: Preds:{B002}, Succs:{B004} _0_026: (W) add (1|M0) r1.14<1>:d r4.15<0;1,0>:d 254:w // ALU pipe: int; $8 // B004: Preds:{B003, B002}, Succs:{B005, B006} _0_025: (W) add (1|M0) r1.8<1>:d r4.14<0;1,0>:d 127:w // ALU pipe: int; $10 (W) cmp (16|M0) (lt)f0.1 null<1>:d r1.8<0;1,0>:d 0:w {I@1} // ALU pipe: int; $11 (W&~f0.1) jmpi _0_027 // ALU pipe: int; $12 // B005: Preds:{B004}, Succs:{B006} _0_028: (W) add (1|M0) r1.8<1>:d r4.14<0;1,0>:d 254:w // ALU pipe: int; $14 // B006: Preds:{B005, B004}, Succs:{B007, B008} _0_027: (W) asr (1|M0) r1.15<1>:d r1.8<0;1,0>:d 7:w {I@1} // ALU pipe: int; $16 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/./tutorials/10-experimental-tma-store-matrix-multiplication.py // Line 66: num_pid_in_group = GROUP_SIZE_M * num_pid_n (W) asr (1|M0) r1.8<1>:d r1.14<0;1,0>:d 4:w // ALU pipe: int; $19 (W) and (1|M0) r2.7<1>:d r1.8<0;1,0>:d -8:w {I@1} // ALU pipe: int; $20 // Line 67: group_id = pid // num_pid_in_group (W) cmp (16|M0) (eq)f2.1 null<1>:d r2.7<0;1,0>:d 0:w {I@1} // ALU pipe: int; $22 (W&~f2.1) jmpi _0_029 // ALU pipe: int; $23 // B007: Preds:{B006}, Succs:{B009} _0_030: (W) mov (1|M0) r2.6<1>:d -8:w // ALU pipe: int; $25 (W) jmpi _0_031 // $26 // B008: Preds:{B006}, Succs:{B009} _0_029: (W) asr (1|M0) r1.12<1>:d r1.14<0;1,0>:d 31:w // ALU pipe: int; $28 (W) asr (1|M0) r2.2<1>:d r0.1<0;1,0>:d 31:w // ALU pipe: int; $29 (W) add (1|M0) r1.8<1>:d r1.12<0;1,0>:d r2.7<0;1,0>:d {I@2} // ALU pipe: int; $30 (W) xor (1|M0) r1.9<1>:d r1.8<0;1,0>:d r1.12<0;1,0>:d {I@1} // ALU pipe: int; $31 (W) add (1|M0) r1.8<1>:d r2.2<0;1,0>:d r0.1<0;1,0>:d // ALU pipe: int; $32 (W) xor (1|M0) r1.13<1>:d r1.8<0;1,0>:d r2.2<0;1,0>:d {I@1} // ALU pipe: int; $33 (W) mov (1|M0) r1.8<1>:f r1.9<0;1,0>:ud {I@1} // ALU pipe: float; $34 (W) math.inv (1|M0) r1.10<1>:f r1.8<0;1,0>:f {F@1} // ALU pipe: math; $35 (W) mov (1|M0) r1.4<1>:df r1.9<0;1,0>:ud {M@1} // ALU pipe: long; $36 (W) mov (1|M0) r1.5<1>:df r1.10<0;1,0>:f // ALU pipe: long; $38 (W) mov (1|M0) r2.0<1>:df -r1.4<0;1,0>:df {L@2} // ALU pipe: long; $37 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $39 (W) mov (1|M0) r1.4<1>:df 0x3FF0000000004000:df {A@1} // ALU pipe: long; $40 (W) mad (1|M0) r2.2<1>:df r1.4<0;0>:df r1.5<0;0>:df r2.0<0>:df {L@1} // ALU pipe: long; $40 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $41 (W) mov (1|M0) r1.4<1>:df r1.13<0;1,0>:ud {A@1} // ALU pipe: long; $42 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $43 (W) mul (1|M0) r2.0<1>:df r1.5<0;1,0>:df r1.4<0;1,0>:df {Compacted,A@1} // ALU pipe: long; $44 (W) mad (1|M0) r1.4<1>:df r2.0<0;0>:df r2.2<0;0>:df r2.0<0>:df {L@1} // ALU pipe: long; $45 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $46 (W) mov (1|M0) r1.10<2>:ud r1.4<0;1,0>:df {A@1} // ALU pipe: int; $47 (W) xor (1|M0) r1.8<1>:d r1.12<0;1,0>:d r2.2<0;1,0>:d // ALU pipe: int; $48 (W) add (1|M0) r1.8<1>:d 
r1.8<0;1,0>:d r1.10<0;1,0>:d {I@1} // ALU pipe: int; $49 (W) bfn.(s0^s1^s2) (1|M0) r1.8<1>:ud r1.8<0;0>:ud r1.12<0;0>:ud r2.2<0>:ud {I@1} // ALU pipe: int; $50 (W) shl (1|M0) r2.6<1>:d r1.8<0;1,0>:d 3:w {I@1} // ALU pipe: int; $51 // B009: Preds:{B008, B007}, Succs:{B010, B011} _0_031: // Line 69: group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) (W) add (1|M0) r1.8<1>:d r1.15<0;1,0>:d -r2.6<0;1,0>:d {I@1} // ALU pipe: int; $54 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/language/standard.py // Line 136: return math.min(x, y, propagate_nan) (W) sel (1|M0) (lt)f0.0 r1.12<1>:d r1.8<0;1,0>:d 8:w {I@1} // ALU pipe: int; $57 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/./tutorials/10-experimental-tma-store-matrix-multiplication.py // Line 70: pid_m = first_pid_m + (pid % group_size_m) (W) cmp (16|M0) (eq)f2.0 null<1>:d r1.12<0;1,0>:d 0:w {I@1} // ALU pipe: int; $60 (W&~f2.0) jmpi _0_032 // ALU pipe: int; $61 // B010: Preds:{B009}, Succs:{B012} _0_033: (W) mov (1|M0) r2.5<1>:d -1:w // ALU pipe: int; $63 (W) jmpi _0_034 // $64 // B011: Preds:{B009}, Succs:{B012} _0_032: (W) asr (1|M0) r1.15<1>:d r0.1<0;1,0>:d 31:w // ALU pipe: int; $66 (W) mov (1|M0) r2.0<1>:d (abs)r1.12<0;1,0>:d // ALU pipe: int; $67 (W) add (1|M0) r1.8<1>:d r1.15<0;1,0>:d r0.1<0;1,0>:d {I@2} // ALU pipe: int; $68 (W) xor (1|M0) r2.4<1>:d r1.8<0;1,0>:d r1.15<0;1,0>:d {I@1} // ALU pipe: int; $69 (W) mov (1|M0) r1.8<1>:f r2.0<0;1,0>:ud {I@1} // ALU pipe: float; $70 (W) math.inv (1|M0) r2.2<1>:f r1.8<0;1,0>:f {F@1} // ALU pipe: math; $71 (W) mov (1|M0) r1.4<1>:df r2.0<0;1,0>:ud {M@1} // ALU pipe: long; $72 (W) mov (1|M0) r1.5<1>:df r2.2<0;1,0>:f // ALU pipe: long; $74 (W) mov (1|M0) r2.0<1>:df -r1.4<0;1,0>:df {L@2} // ALU pipe: long; $73 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $75 (W) mov (1|M0) r1.4<1>:df 0x3FF0000000004000:df {A@1} // ALU pipe: long; $76 (W) mad (1|M0) r2.1<1>:df r1.4<0;0>:df r1.5<0;0>:df r2.0<0>:df {L@1} // ALU pipe: long; $76 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $77 (W) mov (1|M0) r5.2<1>:df r2.4<0;1,0>:ud {A@1,$3.dst} // ALU pipe: long; $78 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $79 (W) mul (1|M0) r2.0<1>:df r1.5<0;1,0>:df r5.2<0;1,0>:df {Compacted,A@1} // ALU pipe: long; $80 (W) mad (1|M0) r5.2<1>:df r2.0<0;0>:df r2.1<0;0>:df r2.0<0>:df {L@1} // ALU pipe: long; $81 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $82 (W) mov (1|M0) r2.2<2>:ud r5.2<0;1,0>:df {A@1} // ALU pipe: int; $83 (W) mov (1|M0) r2.0<1>:d (abs)r1.12<0;1,0>:d // ALU pipe: int; $84 (W) mul (1|M0) acc0.0<1>:d r2.0<0;1,0>:d r2.4<0;1,0>:uw {Compacted,I@1} // ALU pipe: int; $84 (W) macl (1|M0) r2.0<1>:d r2.0<0;1,0>:d r2.2<0;1,0>:d {Compacted} // ALU pipe: int; $85 (W) add3 (1|M0) r1.8<1>:d r2.4<0;0>:d r1.15<0;0>:d -r2.0<0>:d {I@1} // ALU pipe: int; $85 (W) xor (1|M0) r2.5<1>:d r1.8<0;1,0>:d r1.15<0;1,0>:d {I@1} // ALU pipe: int; $86 // B012: Preds:{B011, B010}, Succs:{B013, B014} _0_034: (W) add (1|M0) r2.5<1>:d r2.6<0;1,0>:d r2.5<0;1,0>:d {I@1} // ALU pipe: int; $88 // Line 71: pid_n = (pid % num_pid_in_group) // group_size_m (W&~f2.1) jmpi _0_035 // ALU pipe: int; $90 // B013: Preds:{B012}, Succs:{B015} _0_036: (W) mov (1|M0) r1.13<1>:d -1:w // ALU pipe: int; $92 (W) jmpi _0_037 // $93 // B014: Preds:{B012}, Succs:{B015} _0_035: (W) asr (1|M0) r5.4<1>:d r1.14<0;1,0>:d 31:w {$3.dst} // ALU pipe: int; $95 (W) asr (1|M0) r1.15<1>:d r0.1<0;1,0>:d 31:w // ALU 
pipe: int; $96 (W) add (1|M0) r1.8<1>:d r5.4<0;1,0>:d r2.7<0;1,0>:d {I@2} // ALU pipe: int; $97 (W) xor (1|M0) r1.14<1>:d r1.8<0;1,0>:d r5.4<0;1,0>:d {I@1} // ALU pipe: int; $98 (W) add (1|M0) r1.8<1>:d r1.15<0;1,0>:d r0.1<0;1,0>:d // ALU pipe: int; $99 (W) xor (1|M0) r2.4<1>:d r1.8<0;1,0>:d r1.15<0;1,0>:d {I@1} // ALU pipe: int; $100 (W) mov (1|M0) r1.8<1>:f r1.14<0;1,0>:ud {I@1} // ALU pipe: float; $101 (W) math.inv (1|M0) r2.2<1>:f r1.8<0;1,0>:f {F@1} // ALU pipe: math; $102 (W) mov (1|M0) r1.4<1>:df r1.14<0;1,0>:ud {M@1} // ALU pipe: long; $103 (W) mov (1|M0) r1.5<1>:df r2.2<0;1,0>:f // ALU pipe: long; $105 (W) mov (1|M0) r2.0<1>:df -r1.4<0;1,0>:df {L@2} // ALU pipe: long; $104 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $106 (W) mov (1|M0) r1.4<1>:df 0x3FF0000000004000:df {A@1} // ALU pipe: long; $107 (W) mad (1|M0) r2.1<1>:df r1.4<0;0>:df r1.5<0;0>:df r2.0<0>:df {L@1} // ALU pipe: long; $107 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $108 (W) mov (1|M0) r5.2<1>:df r2.4<0;1,0>:ud {A@1} // ALU pipe: long; $109 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $110 (W) mul (1|M0) r2.0<1>:df r1.5<0;1,0>:df r5.2<0;1,0>:df {Compacted,A@1} // ALU pipe: long; $111 (W) mad (1|M0) r5.2<1>:df r2.0<0;0>:df r2.1<0;0>:df r2.0<0>:df {L@1} // ALU pipe: long; $112 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $113 (W) mov (1|M0) r5.4<2>:ud r5.2<0;1,0>:df {A@1} // ALU pipe: int; $114 (W) mul (1|M0) acc0.0<1>:d r1.14<0;1,0>:d r5.8<0;1,0>:uw {I@1} // ALU pipe: int; $115 (W) macl (1|M0) r2.0<1>:d r1.14<0;1,0>:d r5.4<0;1,0>:d // ALU pipe: int; $116 (W) add3 (1|M0) r1.8<1>:d r2.4<0;0>:d r1.15<0;0>:d -r2.0<0>:d {I@1} // ALU pipe: int; $116 (W) xor (1|M0) r1.13<1>:d r1.8<0;1,0>:d r1.15<0;1,0>:d {I@1} // ALU pipe: int; $117 // B015: Preds:{B014, B013}, Succs:{B016, B017} _0_037: (W&~f2.0) jmpi _0_038 // ALU pipe: int; $119 // B016: Preds:{B015}, Succs:{B018} _0_039: (W) mov (1|M0) r117.2<1>:d -128:w {Compacted} // ALU pipe: int; $121 (W) jmpi _0_040 // $122 // B017: Preds:{B015}, Succs:{B018} _0_038: (W) asr (2|M0) r2.0<1>:d r1.12<1;1,0>:d 31:w {Compacted,I@4} // ALU pipe: int; $124 (W) add (1|M0) r1.8<1>:d r2.0<0;1,0>:d r1.12<0;1,0>:d {Compacted,I@1} // ALU pipe: int; $126 (W) xor (1|M0) r1.9<1>:d r1.8<0;1,0>:d r2.0<0;1,0>:d {I@1} // ALU pipe: int; $127 (W) add (1|M0) r1.8<1>:d r2.1<0;1,0>:d r1.13<0;1,0>:d // ALU pipe: int; $128 (W) xor (1|M0) r1.14<1>:d r1.8<0;1,0>:d r2.1<0;1,0>:d {I@1} // ALU pipe: int; $129 (W) mov (1|M0) r1.8<1>:f r1.9<0;1,0>:ud {I@1} // ALU pipe: float; $130 (W) math.inv (1|M0) r1.10<1>:f r1.8<0;1,0>:f {F@1} // ALU pipe: math; $131 (W) mov (1|M0) r1.4<1>:df r1.9<0;1,0>:ud {M@1} // ALU pipe: long; $132 (W) mov (1|M0) r1.5<1>:df r1.10<0;1,0>:f // ALU pipe: long; $134 (W) mov (1|M0) r2.1<1>:df -r1.4<0;1,0>:df {L@2} // ALU pipe: long; $133 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $135 (W) mov (1|M0) r1.4<1>:df 0x3FF0000000004000:df {A@1} // ALU pipe: long; $136 (W) mad (1|M0) r1.6<1>:df r1.4<0;0>:df r1.5<0;0>:df r2.1<0>:df {L@1} // ALU pipe: long; $136 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $137 (W) mov (1|M0) r1.4<1>:df r1.14<0;1,0>:ud {A@1} // ALU pipe: long; $138 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $139 (W) mul (1|M0) r2.1<1>:df r1.5<0;1,0>:df r1.4<0;1,0>:df {A@1} // ALU pipe: long; $140 (W) mad (1|M0) r1.4<1>:df r2.1<0;0>:df r1.6<0;0>:df r2.1<0>:df {L@1} // ALU pipe: long; $141 (W) xor (1|M0) 
cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $142 (W) mov (1|M0) r1.10<2>:ud r1.4<0;1,0>:df {A@1} // ALU pipe: int; $143 (W) xor (1|M0) r1.8<1>:d r2.0<0;1,0>:d r2.1<0;1,0>:d {Compacted} // ALU pipe: int; $144 (W) add (1|M0) r1.8<1>:d r1.8<0;1,0>:d r1.10<0;1,0>:d {I@1} // ALU pipe: int; $145 (W) bfn.(s0^s1^s2) (1|M0) r1.8<1>:ud r1.8<0;0>:ud r2.0<0;0>:ud r2.1<0>:ud {I@1} // ALU pipe: int; $146 (W) shl (1|M0) r117.2<1>:d r1.8<0;1,0>:d 7:w {I@1} // ALU pipe: int; $147 // B018: Preds:{B017, B016}, Succs:{B019, B020} _0_040: // Line 82: a = tl.load(a_tile_ptr) mov (16|M0) r9.0<1>:d r1.0<1;1,0>:uw // ALU pipe: int; $152 // Line 92: tl.store(c_block_ptr, accumulator) (W) mov (1|M0) r2.0<1>:d 48:w {Compacted} // ALU pipe: int; $155 // Line 81: for k in range(0, K, BLOCK_SIZE_K): (W) cmp (16|M0) (gt)f0.0 null<1>:d r5.0<0;1,0>:d 0:w {$3.dst} // ALU pipe: int; $163 // Line 72: block_offset_m = pid_m * BLOCK_SIZE_M (W) shl (1|M0) r6.0<1>:d r2.5<0;1,0>:d 7:w // ALU pipe: int; $150 // Line 92: tl.store(c_block_ptr, accumulator) (W) mov (1|M0) r5.4<1>:d 120:w {Compacted} // ALU pipe: int; $158 and (16|M0) r119.0<1>:d r9.0<1;1,0>:d 48:w {Compacted,I@5} // ALU pipe: int; $154 bfn.(s0&s1|s2) (16|M0) r2.0<1>:ud r9.0<1;0>:ud r2.0<0;0>:ud r117.2<0>:ud {I@5} // ALU pipe: int; $156 shr (16|M0) r9.0<1>:d r9.0<1;1,0>:ud 3:w // ALU pipe: int; $157 bfn.(s0&s1|s2) (16|M0) r1.0<1>:ud r9.0<1;0>:ud r5.4<0;0>:ud r6.0<0>:ud {I@1} // ALU pipe: int; $159 // Line 82: a = tl.load(a_tile_ptr) (W) add (1|M0) r117.1<1>:d r4.14<0;1,0>:d -1:w // ALU pipe: int; $161 // Line 81: for k in range(0, K, BLOCK_SIZE_K): (W&f0.0) jmpi _0_041 // ALU pipe: int; $164 // B019: Preds:{B018}, Succs:{B023} _0_042: mov (16|M0) r118.0<1>:ud 0x0:ud {Compacted} // ALU pipe: int; $166 mov (16|M0) r111.0<1>:ud 0x0:ud {Compacted} // ALU pipe: int; $167 mov (16|M0) r110.0<1>:ud 0x0:ud {Compacted} // ALU pipe: int; $168 mov (16|M0) r109.0<1>:ud 0x0:ud {Compacted} // ALU pipe: int; $169 mov (16|M0) r108.0<1>:ud 0x0:ud {Compacted} // ALU pipe: int; $170 mov (16|M0) r107.0<1>:ud 0x0:ud {Compacted} // ALU pipe: int; $171 mov (16|M0) r106.0<1>:ud 0x0:ud {Compacted} // ALU pipe: int; $172 mov (16|M0) r105.0<1>:ud 0x0:ud {Compacted} // ALU pipe: int; $173 mov (16|M0) r48.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $174 mov (16|M0) r116.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $175 mov (16|M0) r115.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $176 mov (16|M0) r114.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $177 mov (16|M0) r113.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $178 mov (16|M0) r112.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $179 mov (16|M0) r46.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $180 mov (16|M0) r3.0<1>:f 0x0:f {Compacted,$1.dst} // ALU pipe: float; $181 (W) jmpi _0_043 // $182 // B020: Preds:{B018}, Succs:{B021} _0_041: // Line 82: a = tl.load(a_tile_ptr) (W) shl (1|M0) r5.1<1>:d r5.1<0;1,0>:d 1:w // ALU pipe: int; $187 (W) shl (1|M0) r5.4<1>:d r5.0<0;1,0>:d 1:w {Compacted} // ALU pipe: int; $185 (W) add (1|M0) r5.6<1>:d r5.1<0;1,0>:d -1:w {I@2} // ALU pipe: int; $188 // Line 83: b = tl.load(b_tile_ptr) (W) shl (1|M0) r5.1<1>:d r4.15<0;1,0>:d 1:w // ALU pipe: int; $190 or (16|M0) r47.0<1>:d r2.0<1;1,0>:d 64:w {Compacted} // ALU pipe: int; $195 (W) add (1|M0) r5.8<1>:d r5.1<0;1,0>:d -1:w {I@2} // ALU pipe: int; $191 (W) shl (1|M0) r5.1<1>:d r5.2<0;1,0>:d 1:w // ALU pipe: int; $193 // Line 81: for k in range(0, K, BLOCK_SIZE_K): sync.nop null {Compacted,F@1} // $199 mov (16|M0) r3.0<1>:ud 0x0:ud {Compacted,$1.dst} // ALU 
pipe: int; $199 mov (16|M0) r46.0<1>:ud 0x0:ud {Compacted} // ALU pipe: int; $200 mov (16|M0) r112.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $201 mov (16|M0) r113.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $202 mov (16|M0) r114.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $203 mov (16|M0) r115.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $204 mov (16|M0) r116.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $205 mov (16|M0) r48.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $206 mov (16|M0) r105.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $207 mov (16|M0) r106.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $208 mov (16|M0) r107.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $209 mov (16|M0) r108.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $210 mov (16|M0) r109.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $211 mov (16|M0) r110.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $212 mov (16|M0) r111.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $213 mov (16|M0) r118.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $214 // Line 83: b = tl.load(b_tile_ptr) (W) add (1|M0) r5.9<1>:d r5.0<0;1,0>:d -1:w // ALU pipe: int; $192 // Line 82: a = tl.load(a_tile_ptr) (W) add (1|M0) r117.0<1>:d r5.4<0;1,0>:d -1:w {Compacted} // ALU pipe: int; $186 // Line 81: for k in range(0, K, BLOCK_SIZE_K): (W) mov (1|M0) r5.2<1>:d 0:w {Compacted} // ALU pipe: int; $215 // Line 83: b = tl.load(b_tile_ptr) (W) add (1|M0) r5.1<1>:d r5.1<0;1,0>:d -1:w {I@6} // ALU pipe: int; $194 // Line 81: for k in range(0, K, BLOCK_SIZE_K): (W) mov (2|M0) r5.4<1>:d 0:w {Compacted} // ALU pipe: int; $197 // B021: Preds:{B022, B020}, Succs:{B022, B023} _0_044: // Line 82: a = tl.load(a_tile_ptr) (W) mov (1|M0) r6.0<1>:uq r4.4<0;1,0>:q // ALU pipe: int; $218 (W) mov (2|M0) r6.2<1>:ud r117.0<1;1,0>:d {I@5} // ALU pipe: int; $218 (W) mov (1|M0) r6.4<1>:ud r5.6<0;1,0>:d // ALU pipe: int; $218 (W) mov (1|M0) r6.5<1>:f r5.5<0;1,0>:f {I@4} // ALU pipe: float; $218 (W) mov (1|M0) r6.6<1>:ud r1.0<0;1,0>:d // ALU pipe: int; $218 (W) mov (1|M0) r6.7<1>:ud 0x70F:uw // ALU pipe: int; $218 (W) or (1|M0) r5.7<1>:d r5.5<0;1,0>:d 16:w // ALU pipe: int; $219 load_block2d.ugm.d16.a64 (16|M0) r61:4 [r6:1] {A@1,$4} // ex_desc:0x0; desc:0x2400203 // $218 (W) mov (1|M0) r6.0<1>:uq r4.4<0;1,0>:q {$4.src} // ALU pipe: int; $220 (W) mov (1|M0) r6.2<1>:ud r117.0<0;1,0>:d // ALU pipe: int; $220 (W) mov (1|M0) r6.3<1>:ud r117.1<0;1,0>:d // ALU pipe: int; $220 (W) mov (1|M0) r6.4<1>:ud r5.6<0;1,0>:d // ALU pipe: int; $220 (W) mov (1|M0) r6.5<1>:ud r5.7<0;1,0>:d {I@5} // ALU pipe: int; $220 (W) mov (1|M0) r6.6<1>:ud r1.0<0;1,0>:d // ALU pipe: int; $220 (W) mov (1|M0) r6.7<1>:ud 0x70F:uw // ALU pipe: int; $220 (W) or (1|M0) r5.7<1>:d r5.5<0;1,0>:d 32:w // ALU pipe: int; $221 load_block2d.ugm.d16.a64 (16|M0) r53:4 [r6:1] {A@2,$5} // ex_desc:0x0; desc:0x2400203 // $220 (W) mov (1|M0) r6.0<1>:uq r4.4<0;1,0>:q {$5.src} // ALU pipe: int; $222 (W) mov (1|M0) r6.2<1>:ud r117.0<0;1,0>:d // ALU pipe: int; $222 (W) mov (1|M0) r6.3<1>:ud r117.1<0;1,0>:d // ALU pipe: int; $222 (W) mov (1|M0) r6.4<1>:ud r5.6<0;1,0>:d // ALU pipe: int; $222 (W) mov (1|M0) r6.5<1>:ud r5.7<0;1,0>:d {I@5} // ALU pipe: int; $222 (W) mov (1|M0) r6.6<1>:ud r1.0<0;1,0>:d // ALU pipe: int; $222 (W) mov (1|M0) r6.7<1>:ud 0x70F:uw // ALU pipe: int; $222 (W) or (1|M0) r5.7<1>:d r5.5<0;1,0>:d 48:w // ALU pipe: int; $223 load_block2d.ugm.d16.a64 (16|M0) r57:4 [r6:1] {A@2,$6} // ex_desc:0x0; desc:0x2400203 // $222 (W) mov (1|M0) r6.0<1>:uq r4.4<0;1,0>:q {$6.src} // ALU pipe: int; $224 (W) mov (1|M0) r6.2<1>:ud 
r117.0<0;1,0>:d // ALU pipe: int; $224 (W) mov (1|M0) r6.3<1>:ud r117.1<0;1,0>:d // ALU pipe: int; $224 (W) mov (1|M0) r6.4<1>:ud r5.6<0;1,0>:d // ALU pipe: int; $224 (W) mov (1|M0) r6.5<1>:ud r5.7<0;1,0>:d {I@5} // ALU pipe: int; $224 (W) mov (1|M0) r6.6<1>:ud r1.0<0;1,0>:d // ALU pipe: int; $224 (W) mov (1|M0) r6.7<1>:ud 0x70F:uw // ALU pipe: int; $224 // Line 83: b = tl.load(b_tile_ptr) (W) or (1|M0) r5.7<1>:d r5.4<0;1,0>:d 16:w // ALU pipe: int; $227 // Line 82: a = tl.load(a_tile_ptr) load_block2d.ugm.d16.a64 (16|M0) r49:4 [r6:1] {A@2,$7} // ex_desc:0x0; desc:0x2400203 // $224 // Line 83: b = tl.load(b_tile_ptr) (W) mov (1|M0) r6.0<1>:uq r4.5<0;1,0>:q {$7.src} // ALU pipe: int; $226 (W) mov (2|M0) r6.2<1>:ud r5.8<1;1,0>:d // ALU pipe: int; $226 (W) mov (1|M0) r6.4<1>:ud r5.1<0;1,0>:d // ALU pipe: int; $226 (W) mov (1|M0) r6.5<1>:ud r2.0<0;1,0>:d // ALU pipe: int; $226 (W) mov (1|M0) r6.6<1>:ud r5.4<0;1,0>:d // ALU pipe: int; $226 (W) mov (1|M0) r6.7<1>:ud 0xF0F:uw // ALU pipe: int; $226 (W) or (1|M0) r5.10<1>:d r5.4<0;1,0>:d 32:w // ALU pipe: int; $229 load_block2d.ugm.d16v.a64 (16|M0) r89:8 [r6:1] {A@2,$8} // ex_desc:0x0; desc:0x2800283 // $226 (W) mov (1|M0) r6.0<1>:uq r4.5<0;1,0>:q {$8.src} // ALU pipe: int; $228 (W) mov (2|M0) r6.2<1>:ud r5.8<1;1,0>:d // ALU pipe: int; $228 (W) mov (1|M0) r6.4<1>:ud r5.1<0;1,0>:d // ALU pipe: int; $228 (W) mov (1|M0) r6.5<1>:ud r2.0<0;1,0>:d // ALU pipe: int; $228 (W) mov (1|M0) r6.6<1>:ud r5.7<0;1,0>:d // ALU pipe: int; $228 (W) mov (1|M0) r6.7<1>:ud 0xF0F:uw // ALU pipe: int; $228 (W) or (1|M0) r5.11<1>:d r5.4<0;1,0>:d 48:w // ALU pipe: int; $231 load_block2d.ugm.d16v.a64 (16|M0) r97:8 [r6:1] {A@2,$9} // ex_desc:0x0; desc:0x2800283 // $228 (W) mov (1|M0) r6.0<1>:uq r4.5<0;1,0>:q {$9.src} // ALU pipe: int; $230 (W) mov (2|M0) r6.2<1>:ud r5.8<1;1,0>:d // ALU pipe: int; $230 (W) mov (1|M0) r6.4<1>:ud r5.1<0;1,0>:d // ALU pipe: int; $230 (W) mov (1|M0) r6.5<1>:ud r2.0<0;1,0>:d // ALU pipe: int; $230 (W) mov (1|M0) r6.6<1>:ud r5.10<0;1,0>:d // ALU pipe: int; $230 (W) mov (1|M0) r6.7<1>:ud 0xF0F:uw // ALU pipe: int; $230 // Line 84: accumulator += tl.dot(a, b) mov (16|M0) r7.0<1>:f r111.0<1;1,0>:f {Compacted} // ALU pipe: float; $239 // Line 83: b = tl.load(b_tile_ptr) load_block2d.ugm.d16v.a64 (16|M0) r81:8 [r6:1] {A@1,$10} // ex_desc:0x0; desc:0x2800283 // $230 (W) mov (1|M0) r6.0<1>:uq r4.5<0;1,0>:q {$10.src} // ALU pipe: int; $232 (W) mov (2|M0) r6.2<1>:ud r5.8<1;1,0>:d // ALU pipe: int; $232 (W) mov (1|M0) r6.4<1>:ud r5.1<0;1,0>:d // ALU pipe: int; $232 (W) mov (1|M0) r6.5<1>:ud r2.0<0;1,0>:d // ALU pipe: int; $232 (W) mov (1|M0) r6.6<1>:ud r5.11<0;1,0>:d // ALU pipe: int; $232 (W) mov (1|M0) r6.7<1>:ud 0xF0F:uw // ALU pipe: int; $232 // Line 84: accumulator += tl.dot(a, b) mov (16|M0) r8.0<1>:f r110.0<1;1,0>:f {Compacted} // ALU pipe: float; $240 // Line 83: b = tl.load(b_tile_ptr) load_block2d.ugm.d16v.a64 (16|M0) r73:8 [r6:1] {A@1,$11} // ex_desc:0x0; desc:0x2800283 // $232 (W) mov (1|M0) r6.0<1>:uq r4.5<0;1,0>:q {$11.src} // ALU pipe: int; $233 (W) mov (2|M0) r6.2<1>:ud r5.8<1;1,0>:d // ALU pipe: int; $233 (W) mov (1|M0) r6.4<1>:ud r5.1<0;1,0>:d // ALU pipe: int; $233 (W) mov (1|M0) r6.5<1>:ud r47.0<0;1,0>:d // ALU pipe: int; $233 (W) mov (1|M0) r6.6<1>:ud r5.4<0;1,0>:d // ALU pipe: int; $233 (W) mov (1|M0) r6.7<1>:ud 0xF0F:uw // ALU pipe: int; $233 // Line 84: accumulator += tl.dot(a, b) mov (16|M0) r9.0<1>:f r109.0<1;1,0>:f {Compacted} // ALU pipe: float; $241 // Line 83: b = tl.load(b_tile_ptr) load_block2d.ugm.d16v.a64 (16|M0) r65:8 
[r6:1] {A@1,$12} // ex_desc:0x0; desc:0x2800283 // $233 (W) mov (1|M0) r6.0<1>:uq r4.5<0;1,0>:q {$12.src} // ALU pipe: int; $234 (W) mov (2|M0) r6.2<1>:ud r5.8<1;1,0>:d // ALU pipe: int; $234 (W) mov (1|M0) r6.4<1>:ud r5.1<0;1,0>:d // ALU pipe: int; $234 (W) mov (1|M0) r6.5<1>:ud r47.0<0;1,0>:d // ALU pipe: int; $234 (W) mov (1|M0) r6.6<1>:ud r5.7<0;1,0>:d // ALU pipe: int; $234 (W) mov (1|M0) r6.7<1>:ud 0xF0F:uw // ALU pipe: int; $234 // Line 84: accumulator += tl.dot(a, b) mov (16|M0) r10.0<1>:f r108.0<1;1,0>:f {Compacted} // ALU pipe: float; $242 // Line 83: b = tl.load(b_tile_ptr) load_block2d.ugm.d16v.a64 (16|M0) r30:8 [r6:1] {A@1,$13} // ex_desc:0x0; desc:0x2800283 // $234 (W) mov (1|M0) r6.0<1>:uq r4.5<0;1,0>:q {$13.src} // ALU pipe: int; $235 (W) mov (2|M0) r6.2<1>:ud r5.8<1;1,0>:d // ALU pipe: int; $235 (W) mov (1|M0) r6.4<1>:ud r5.1<0;1,0>:d // ALU pipe: int; $235 (W) mov (1|M0) r6.5<1>:ud r47.0<0;1,0>:d // ALU pipe: int; $235 (W) mov (1|M0) r6.6<1>:ud r5.10<0;1,0>:d // ALU pipe: int; $235 (W) mov (1|M0) r6.7<1>:ud 0xF0F:uw // ALU pipe: int; $235 // Line 84: accumulator += tl.dot(a, b) mov (16|M0) r11.0<1>:f r107.0<1;1,0>:f {Compacted} // ALU pipe: float; $243 // Line 83: b = tl.load(b_tile_ptr) load_block2d.ugm.d16v.a64 (16|M0) r22:8 [r6:1] {A@1,$14} // ex_desc:0x0; desc:0x2800283 // $235 (W) mov (1|M0) r6.0<1>:uq r4.5<0;1,0>:q {$14.src} // ALU pipe: int; $236 (W) mov (2|M0) r6.2<1>:ud r5.8<1;1,0>:d // ALU pipe: int; $236 (W) mov (1|M0) r6.4<1>:ud r5.1<0;1,0>:d // ALU pipe: int; $236 (W) mov (1|M0) r6.5<1>:ud r47.0<0;1,0>:d // ALU pipe: int; $236 (W) mov (1|M0) r6.6<1>:ud r5.11<0;1,0>:d // ALU pipe: int; $236 (W) mov (1|M0) r6.7<1>:ud 0xF0F:uw // ALU pipe: int; $236 // Line 84: accumulator += tl.dot(a, b) mov (16|M0) r12.0<1>:f r106.0<1;1,0>:f {Compacted} // ALU pipe: float; $244 // Line 83: b = tl.load(b_tile_ptr) load_block2d.ugm.d16v.a64 (16|M0) r14:8 [r6:1] {A@1,$15} // ex_desc:0x0; desc:0x2800283 // $236 // Line 84: accumulator += tl.dot(a, b) mov (16|M0) r13.0<1>:f r105.0<1;1,0>:f {Compacted} // ALU pipe: float; $245 mov (16|M0) r38.0<1>:f r48.0<1;1,0>:f {Compacted} // ALU pipe: float; $246 mov (16|M0) r39.0<1>:f r116.0<1;1,0>:f {Compacted} // ALU pipe: float; $247 mov (16|M0) r40.0<1>:f r115.0<1;1,0>:f {Compacted} // ALU pipe: float; $248 mov (16|M0) r41.0<1>:f r114.0<1;1,0>:f {Compacted} // ALU pipe: float; $249 mov (16|M0) r42.0<1>:f r113.0<1;1,0>:f {Compacted} // ALU pipe: float; $250 mov (16|M0) r43.0<1>:f r112.0<1;1,0>:f {Compacted} // ALU pipe: float; $251 mov (16|M0) r44.0<1>:f r46.0<1;1,0>:f {Compacted} // ALU pipe: float; $252 mov (16|M0) r45.0<1>:f r3.0<1;1,0>:f {Compacted} // ALU pipe: float; $253 mov (16|M0) r6.0<1>:f r118.0<1;1,0>:f {Compacted,$15.src} // ALU pipe: float; $238 sync.nop null {Compacted,F@1} // $255 sync.allwr ($8,$12) // $255 dpas.8x8 (16|M0) r38:f r38:f r65:hf r61.0:hf {Atomic,Compacted,$4.dst} // $255 dpas.8x8 (16|M0) r6:f r6:f r89:hf r61.0:hf {Compacted,$4} // $254 // Line 81: for k in range(0, K, BLOCK_SIZE_K): (W) add (1|M0) r5.2<1>:d r5.2<0;1,0>:d 64:w {Compacted} // ALU pipe: int; $279 (W) cmp (16|M0) (lt)f1.1 null<1>:d r5.2<0;1,0>:d r5.0<0;1,0>:d {I@1} // ALU pipe: int; $280 // Line 84: accumulator += tl.dot(a, b) sync.allwr ($4,$9,$13) // $257 dpas.8x8 (16|M0) r38:f r38:f r30:hf r53.0:hf {Atomic,Compacted,$5.dst} // $257 dpas.8x8 (16|M0) r6:f r6:f r97:hf r53.0:hf {Compacted,$5} // $256 sync.allwr ($5,$10,$14) // $258 dpas.8x8 (16|M0) r6:f r6:f r81:hf r57.0:hf {Atomic,Compacted,$6.dst} // $258 dpas.8x8 (16|M0) r38:f r38:f r22:hf 
r57.0:hf {Compacted,$6} // $259 sync.allwr ($6,$11,$15) // $260 dpas.8x8 (16|M0) r6:f r6:f r73:hf r49.0:hf {Atomic,Compacted,$7.dst} // $260 dpas.8x8 (16|M0) r38:f r38:f r14:hf r49.0:hf {Compacted,$7} // $269 mov (16|M0) r118.0<1>:f r6.0<1;1,0>:f {Compacted,$7.dst} // ALU pipe: float; $261 mov (16|M0) r111.0<1>:f r7.0<1;1,0>:f {Compacted} // ALU pipe: float; $262 mov (16|M0) r110.0<1>:f r8.0<1;1,0>:f {Compacted} // ALU pipe: float; $263 mov (16|M0) r109.0<1>:f r9.0<1;1,0>:f {Compacted} // ALU pipe: float; $264 mov (16|M0) r108.0<1>:f r10.0<1;1,0>:f {Compacted} // ALU pipe: float; $265 mov (16|M0) r107.0<1>:f r11.0<1;1,0>:f {Compacted} // ALU pipe: float; $266 mov (16|M0) r106.0<1>:f r12.0<1;1,0>:f {Compacted} // ALU pipe: float; $267 mov (16|M0) r105.0<1>:f r13.0<1;1,0>:f {Compacted} // ALU pipe: float; $268 mov (16|M0) r48.0<1>:f r38.0<1;1,0>:f {Compacted} // ALU pipe: float; $270 mov (16|M0) r116.0<1>:f r39.0<1;1,0>:f {Compacted} // ALU pipe: float; $271 mov (16|M0) r115.0<1>:f r40.0<1;1,0>:f {Compacted} // ALU pipe: float; $272 mov (16|M0) r114.0<1>:f r41.0<1;1,0>:f {Compacted} // ALU pipe: float; $273 mov (16|M0) r113.0<1>:f r42.0<1;1,0>:f {Compacted} // ALU pipe: float; $274 mov (16|M0) r112.0<1>:f r43.0<1;1,0>:f {Compacted} // ALU pipe: float; $275 mov (16|M0) r46.0<1>:f r44.0<1;1,0>:f {Compacted} // ALU pipe: float; $276 mov (16|M0) r3.0<1>:f r45.0<1;1,0>:f {Compacted} // ALU pipe: float; $277 // Line 81: for k in range(0, K, BLOCK_SIZE_K): (W&~f1.1) jmpi _0_043 // ALU pipe: int; $281 // B022: Preds:{B021}, Succs:{B021} _0_045: // Line 85: a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K]) (W) add (1|M0) r5.5<1>:d r5.5<0;1,0>:d 64:w // ALU pipe: int; $284 // Line 86: b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0]) (W) add (1|M0) r5.4<1>:d r5.4<0;1,0>:d 64:w {Compacted} // ALU pipe: int; $286 // Line 81: for k in range(0, K, BLOCK_SIZE_K): (W) jmpi _0_044 // $288 // B023: Preds:{B021, B019}, Succs:{} _0_043: // Line 92: tl.store(c_block_ptr, accumulator) (W) shl (1|M0) r4.0<1>:d r4.15<0;1,0>:d 2:w // ALU pipe: int; $291 mov (16|M0) r13.0<1>:d 0:w {Compacted,F@7} // ALU pipe: int; $303 (W) add (1|M0) r4.1<1>:d r4.0<0;1,0>:d -1:w {Compacted,I@2} // ALU pipe: int; $292 (W) shl (1|M0) r4.0<1>:d r5.3<0;1,0>:d 2:w {Compacted} // ALU pipe: int; $293 mov (16|M0) r5.0<1>:f r118.0<1;1,0>:f {Compacted,I@1} // ALU pipe: float; $295 (W) mov (1|M0) r5.5<1>:ud r2.0<0;1,0>:d {F@1} // ALU pipe: int; $304 add (16|M0) acc0.0<1>:d r119.0<1;1,0>:d 32:w {Compacted} // ALU pipe: int; $305 (W) add (1|M0) r4.0<1>:d r4.0<0;1,0>:d -1:w {Compacted} // ALU pipe: int; $294 or (16|M0) r14.0<1>:d acc0.0<1;1,0>:d r117.2<0;1,0>:d // ALU pipe: int; $306 (W) mov (1|M0) r5.0<1>:uq r4.6<0;1,0>:q // ALU pipe: int; $304 (W) mov (1|M0) r5.2<1>:ud r4.1<0;1,0>:d // ALU pipe: int; $304 (W) mov (1|M0) r5.3<1>:ud r117.1<0;1,0>:d // ALU pipe: int; $304 (W) mov (1|M0) r5.6<1>:ud r1.0<0;1,0>:d // ALU pipe: int; $304 (W) mov (1|M0) r5.7<1>:ud 0x70F:uw // ALU pipe: int; $304 (W) mov (1|M0) r5.4<1>:ud r4.0<0;1,0>:d {I@7} // ALU pipe: int; $304 (W) mov (1|M0) r2.0<1>:uq r4.6<0;1,0>:q // ALU pipe: int; $316 (W) mov (1|M0) r2.2<1>:ud r4.1<0;1,0>:d // ALU pipe: int; $316 (W) mov (1|M0) r2.3<1>:ud r117.1<0;1,0>:d // ALU pipe: int; $316 (W) mov (1|M0) r2.4<1>:ud r4.0<0;1,0>:d // ALU pipe: int; $316 (W) mov (1|M0) r2.5<1>:ud r14.0<0;1,0>:d {I@7} // ALU pipe: int; $316 (W) mov (1|M0) r2.6<1>:ud r1.0<0;1,0>:d // ALU pipe: int; $316 (W) mov (1|M0) r2.7<1>:ud 0x70F:uw // ALU pipe: int; $316 mov (16|M0) r10.0<1>:f r107.0<1;1,0>:f 
{Compacted} // ALU pipe: float; $300 mov (16|M0) r10.0<1>:f r112.0<1;1,0>:f {Compacted} // ALU pipe: float; $312 (W) mov (16|M0) r112.0<1>:f r0.0<1;1,0>:f {Compacted} // ALU pipe: float; $317 store_block2d.ugm.d32.a64 (16|M0) [r5:1] r13:8 {I@7,$0} // ex_desc:0x0; desc:0x2000407 // $304 store_block2d.ugm.d32.a64 (16|M0) [r2:1] r13:8 {A@1,$1} // ex_desc:0x0; desc:0x2000407 // $316 mov (16|M0) r6.0<1>:f r111.0<1;1,0>:f {Compacted} // ALU pipe: float; $296 mov (16|M0) r7.0<1>:f r110.0<1;1,0>:f {Compacted} // ALU pipe: float; $297 mov (16|M0) r8.0<1>:f r109.0<1;1,0>:f {Compacted} // ALU pipe: float; $298 mov (16|M0) r9.0<1>:f r108.0<1;1,0>:f {Compacted} // ALU pipe: float; $299 mov (16|M0) r11.0<1>:f r106.0<1;1,0>:f {Compacted} // ALU pipe: float; $301 mov (16|M0) r12.0<1>:f r105.0<1;1,0>:f {Compacted} // ALU pipe: float; $302 mov (16|M0) r5.0<1>:f r48.0<1;1,0>:f {Compacted,$0.src} // ALU pipe: float; $307 mov (16|M0) r6.0<1>:f r116.0<1;1,0>:f {Compacted} // ALU pipe: float; $308 mov (16|M0) r7.0<1>:f r115.0<1;1,0>:f {Compacted} // ALU pipe: float; $309 mov (16|M0) r8.0<1>:f r114.0<1;1,0>:f {Compacted} // ALU pipe: float; $310 mov (16|M0) r9.0<1>:f r113.0<1;1,0>:f {Compacted} // ALU pipe: float; $311 mov (16|M0) r11.0<1>:f r46.0<1;1,0>:f {Compacted} // ALU pipe: float; $313 mov (16|M0) r12.0<1>:f r3.0<1;1,0>:f {Compacted} // ALU pipe: float; $314 (W) send.gtwy (1|M0) null r112 null:0 0x0 0x02000010 {EOT} // wr:1+0, rd:0; end of thread // $317 L5136: nop // $317 //.BankConflicts: 0 //.ByteRMWs: 0 // //.numALUInst: 345 //.accSubDef: 1 //.accSubUse: 1 //.accSubCandidateDef: 1 //.accSubCandidateUse: 1 // // //.singlePipeAtOneDistNum: 60 //.allAtOneDistNum: 44 //.syncInstCount: 4 //.tokenReuseCount: 0 //.AfterWriteTokenDepCount: 23 //.AfterReadTokenDepCount: 15 ```

From the example kernel, we can see a lot of redundant move instructions:

  1. Some are used to repeatedly compose the payload of the 2D block load/store send messages.
  2. Some come from insertelement/extractelement pairs that could be optimized away.
chengjunlu commented 7 months ago

I used a new PoC branch: chengjun/llvm-target-dpas-ver2.

The 08-experimental-block-pointer.py tutorial is used to test the performance.
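For reference, here is a minimal sketch of the block-pointer GEMM loop that the source-line comments in the disassembly above refer to (`tl.load` on block pointers, `tl.dot`, `tl.advance`, `tl.store`). This is not the exact tutorial code: the group-ordering logic and boundary checks are omitted, and the parameter names and block sizes are illustrative assumptions.

```python
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Block pointers describe 2D tiles; on the XPU backend they can lower to 2D block loads/stores.
    a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                   offsets=(pid_m * BLOCK_M, 0),
                                   block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                   offsets=(0, pid_n * BLOCK_N),
                                   block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
    accumulator = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_tile_ptr)                              # 2D block load of an A tile
        b = tl.load(b_tile_ptr)                              # 2D block load of a B tile
        accumulator += tl.dot(a, b)                          # maps to DPAS
        a_tile_ptr = tl.advance(a_tile_ptr, (0, BLOCK_K))
        b_tile_ptr = tl.advance(b_tile_ptr, (BLOCK_K, 0))
    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                    offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                    block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tl.store(c_block_ptr, accumulator)                       # 2D block store of the result tile
```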

matmul-performance:
        M       N       K      oneDNN  Triton native  Triton 2D load  Triton 2D load/store
0  2048.0   512.0   512.0    3.140586       2.251800        2.832453              2.855802
1  2048.0  1024.0  1024.0   11.766426       5.636545        9.421757              9.648848
2  2048.0  2048.0  2048.0   35.201560       8.855548       23.471529             24.476085
3  2048.0  4096.0  4096.0   91.632610      10.826774       36.535730             37.672249
4  4096.0  4096.0  4096.0  130.627861      11.654737       41.180180             42.552650
5  2048.0  8192.0  8192.0  160.887734      11.540294       41.483934             42.576222
6  4096.0  8192.0  8192.0  175.409304       8.006108       43.217000             44.266945
7  8192.0  8192.0  8192.0  173.861867       6.047589       44.402908             45.352001

The oneDNN upper bound is ~175 TFLOPS, while the SIMT Triton kernel reaches ~45 TFLOPS, i.e. only about 25% of oneDNN performance.
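For context on the units: the numbers in these tables are TFLOPS, computed in the usual Triton matmul benchmark style as 2·M·N·K floating-point operations divided by the measured kernel time. A minimal sketch of that computation, assuming timing with `triton.testing.do_bench` (the exact measurement harness used for these runs may differ):

```python
import triton


def matmul_tflops(M, N, K, ms):
    # A GEMM does 2*M*N*K floating-point operations; `ms` is the kernel time in milliseconds.
    return 2 * M * N * K * 1e-12 / (ms * 1e-3)


# Usage (`run_matmul` is a placeholder for whatever launches the GEMM under test):
# ms = triton.testing.do_bench(lambda: run_matmul(a, b))
# print(f"{matmul_tflops(M, N, K, ms):.2f} TFLOPS")
```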

etiotto commented 7 months ago

For 08-experimental-block-pointer.py with a 4K x 4K x 4K fp16 GEMM, with Triton lowering to SPIR-V (vc-intrinsics), the max perf is as below:

  1. 295 TFLOPS with all the optimizations (2D load, 2D store, 2D prefetch, etc.); for comparison, the MLIR team reports 300 TFLOPS and XeTLA 310 TFLOPS.
  2. 150 TFLOPS without prefetch.

The machine info is PVC (Max 1550): EUCount = 512, ThreadCount = 4096, SliceCount = 1, SubSliceCount = 64.

Hi @Dewei-Wang-sh @chengjunlu, can you guys please measure your respective experiments on the same GPU, so it is easy to compare? We need:

Dewei-Wang-sh commented 7 months ago

On PVC Max 1550 (EUCount = 512, ThreadCount = 4096, SliceCount = 1, SubSliceCount = 64):

  1. oneDNN: 195 TFLOPS
  2. SIMD approach with DPAS + 2D block load + 2D block store: 150 TFLOPS
  3. SIMD approach with DPAS + 2D block load + 2D block store + prefetch: 290 TFLOPS

chengjunlu commented 7 months ago

The SIMT Triton GEMM reaches ~85% of the oneDNN kernel performance in the 4K x 4K x 4K case.

matmul-performance:
        M       N       K      oneDNN  Triton 2D load/store
0  2048.0   512.0   512.0    3.129673              3.024580
1  2048.0  1024.0  1024.0   11.766426             10.957663
2  2048.0  2048.0  2048.0   35.496352             35.339673
3  2048.0  4096.0  4096.0   89.707556             79.336740
4  4096.0  4096.0  4096.0  126.444561            107.118973
5  2048.0  8192.0  8192.0  149.954023            108.694400
6  4096.0  8192.0  8192.0  176.490089            122.644700
7  8192.0  8192.0  8192.0  173.973367            129.720290

The code branches:

  1. GENX LLVM: https://github.com/chengjunlu/llvm/tree/chengjun/genx
  2. Triton: https://github.com/intel/intel-xpu-backend-for-triton/tree/chengjun/llvm-target-dpas-ver2

Things that still need to be done:

  1. The correctness issue reported in this task couldn't be reproduced with the IGC build we used for the best performance. Hopefully IGC fixes the issue in a future release; I will report it to IGC with a JIRA ticket if we hit it again on the main branch.
  2. We use the following signatures for the 2D load/store in the PoC; we request that IGC at least provide a corresponding interface in OpenCL:
  %88 = call <64 x i16> @llvm.genx.GenISA.LSC2DBlockRead.v64i16(i64 %52, i32 %35, i32 %36, i32 %38, i32 %75, i32 %51, i32 16, i32 16, i32 32, i32 2, i1 false, i1 false, i32 0) #3, !dbg !483
  %89 = call <32 x i32> @llvm.genx.GenISA.LSC2DBlockRead.v32i32(i64 %55, i32 %46, i32 %47, i32 %49, i32 %54, i32 %76, i32 16, i32 16, i32 32, i32 2, i1 false, i1 true, i32 0) #3, !dbg !486