whitneywhtsang closed this issue 7 months ago.
Generating a Triton GEMM kernel with 2D loads and DPAS has been finished. The code branches:
- GENX LLVM: https://github.com/chengjunlu/llvm/tree/chengjun/genx
- Triton: https://github.com/intel/intel-xpu-backend-for-triton/tree/chengjun/llvm-target-dpas
The performance is not good.
matmul-performance:
M N K TRANS_A TRANS_B oneDNN Triton
0 2048 512 512 False False 3.076229 2.619895
1 2048 1024 1024 False False 11.766426 7.293279
2 2048 2048 2048 False False 38.950051 12.688430
3 2048 4096 4096 False False 112.458204 15.564876
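The numbers in the oneDNN and Triton columns are TFLOPS. As a sanity check on the units, a small helper (a sketch assuming the usual GEMM convention of 2*M*N*K floating-point operations per matmul, which is what the Triton matmul benchmarks report) converts a measured runtime in milliseconds into TFLOPS:

```python
def gemm_tflops(M, N, K, ms):
    """Convert a GEMM runtime in milliseconds to TFLOPS.

    An M x N x K GEMM performs 2*M*N*K floating-point operations
    (one multiply and one add per inner-product term).
    """
    flops = 2.0 * M * N * K
    seconds = ms * 1e-3
    return flops / seconds * 1e-12
```

For example, the 2048 x 512 x 512 case running in 1 ms would be ~1.07 TFLOPS, so the ~3 TFLOPS oneDNN number corresponds to roughly a third of a millisecond per matmul.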
I will add the prefetch ops and test it again.
@chengjunlu have you used 2D stores or only 2D loads in this experiment?
> @chengjunlu have you used 2D stores or only 2D loads in this experiment?
Only 2D loads for now. I will continue to enable the 2D store and 2D prefetching.
I have enabled the 2D prefetch op. But it seems the final GEN binary doesn't have the 2D prefetching ops as expected.
There are prefetching GenISA intrinsics in the optimized LLVM IR, from which the scalar compiler generates the code:
```llvm
// Prefetching in the pre-epilogue.
%50 = ptrtoint half addrspace(1)* %43 to i64, !dbg !392
call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %50, i32 %45, i32 %46, i32 1023, i32 %48, i32 %49, i32 32, i32 3, i32 15, i32 1, i1 false, i1 false, i32 4) #4, !dbg !392
%51 = sext i32 %26 to i64, !dbg !393
%52 = getelementptr half, half addrspace(1)* %1, i64 %51, !dbg !393
%53 = shl i32 %4, 1, !dbg !393
%54 = add i32 %53, -1, !dbg !393
%55 = add i32 %5, -1, !dbg !393
%56 = ptrtoint half addrspace(1)* %52 to i64, !dbg !393
call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %56, i32 %54, i32 %55, i32 1023, i32 %48, i32 %49, i32 32, i32 3, i32 15, i32 1, i1 false, i1 false, i32 4) #4, !dbg !393
%57 = icmp sgt i32 %5, 0, !dbg !394
br i1 %57, label %.lr.ph, label %._crit_edge38, !dbg !394
._crit_edge38: ; preds = %cond-add-join41
%.pre = and i32 %27, 56, !dbg !395
br label %121, !dbg !394
...
// Prefetching in the loop body.
%79 = getelementptr half, half addrspace(1)* %43, i64 %78, !dbg !392
%80 = ptrtoint half addrspace(1)* %79 to i64, !dbg !392
call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %80, i32 %45, i32 %46, i32 1023, i32 %48, i32 %49, i32 32, i32 3, i32 15, i32 1, i1 false, i1 false, i32 4) #4, !dbg !392
%81 = shl i32 %77, 9, !dbg !393
%82 = sext i32 %81 to i64, !dbg !393
%83 = getelementptr half, half addrspace(1)* %52, i64 %82, !dbg !393
%84 = ptrtoint half addrspace(1)* %83 to i64, !dbg !393
call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %84, i32 %54, i32 %55, i32 1023, i32 %48, i32 %49, i32 32, i32 3, i32 15, i32 1, i1 false, i1 false, i32 4) #4, !dbg !393
%85 = call <8 x i16> @llvm.genx.GenISA.LSC2DBlockRead.v8i16(i64 %60, i32 %45, i32 %46, i32 1023, i32 %66, i32 %59, i32 16, i32 16, i32 8, i32 1, i1 false, i1 false, i32 0) #4, !dbg !392
%86 = or i32 %66, 16, !dbg !392
```
There are no prefetching ops in the final Gen assembly, from the first instruction of the tt.load to the first 2D load:
```asm
// Line 86: b = tl.load(b_tile_ptr)
(W) mov (1|M0) r4.0<1>:d 56:w {Compacted} // ALU pipe: int; $191
(W) mov (1|M0) r6.0<1>:d 48:w {Compacted} // ALU pipe: int; $193
// Line 85: a = tl.load(a_tile_ptr)
(W) shl (1|M0) r1.8<1>:d r5.2<0;1,0>:d 1:w // ALU pipe: int; $184
// Line 86: b = tl.load(b_tile_ptr)
(W) shl (1|M0) r2.0<1>:d r5.1<0;1,0>:d 1:w {Compacted} // ALU pipe: int; $188
// Line 84: for k in range(0, K, BLOCK_SIZE_K):
mov (16|M0) r94.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $198
mov (16|M0) r116.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $199
mov (16|M0) r95.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $200
mov (16|M0) r115.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $201
mov (16|M0) r96.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $202
mov (16|M0) r114.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $203
mov (16|M0) r97.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $204
mov (16|M0) r113.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $205
// Line 86: b = tl.load(b_tile_ptr)
bfn.(s0&s1|s2) (16|M0) r93.0<1>:ud r124.0<1;0>:ud r4.0<0;0>:ud r125.1<0>:ud {I@4} // ALU pipe: int; $192
bfn.(s0&s1|s2) (16|M0) r117.0<1>:ud r51.0<1;0>:ud r6.0<0;0>:ud r125.0<0>:ud {I@4} // ALU pipe: int; $194
// Line 85: a = tl.load(a_tile_ptr)
(W) add (1|M0) r126.9<1>:d r5.0<0;1,0>:d -1:w // ALU pipe: int; $186
// Line 86: b = tl.load(b_tile_ptr)
(W) add (1|M0) r50.13<1>:d r5.2<0;1,0>:d -1:w // ALU pipe: int; $190
// Line 84: for k in range(0, K, BLOCK_SIZE_K):
(W) mov (2|M0) r50.8<1>:d 0:w {Compacted} // ALU pipe: int; $196
(W) mov (1|M0) r50.4<1>:d 0:w {Compacted} // ALU pipe: int; $206
// Line 85: a = tl.load(a_tile_ptr)
(W) add (1|M0) r126.8<1>:d r1.8<0;1,0>:d -1:w {I@7} // ALU pipe: int; $185
// Line 86: b = tl.load(b_tile_ptr)
(W) add (1|M0) r50.12<1>:d r2.0<0;1,0>:d -1:w {I@7} // ALU pipe: int; $189
// B021: Preds:{B022, B020}, Succs:{B022, B023}
_0_047:
// Line 85: a = tl.load(a_tile_ptr)
(W) mov (1|M0) r2.0<1>:uq r4.5<0;1,0>:q // ALU pipe: int; $209
(W) mov (2|M0) r2.2<1>:ud r126.8<1;1,0>:d {I@3} // ALU pipe: int; $209
(W) mov (1|M0) r2.4<1>:ud 1023:w // ALU pipe: int; $209
(W) mov (1|M0) r2.5<1>:ud r50.9<0;1,0>:d // ALU pipe: int; $209
(W) mov (1|M0) r2.6<1>:ud r93.0<0;1,0>:d // ALU pipe: int; $209
(W) mov (1|M0) r2.7<1>:ud 0x70F:uw // ALU pipe: int; $209
// Line 86: b = tl.load(b_tile_ptr)
(W) mov (1|M0) r29.0<1>:uq r4.6<0;1,0>:q // ALU pipe: int; $217
(W) mov (2|M0) r29.2<1>:ud r50.12<1;1,0>:d {I@7} // ALU pipe: int; $217
(W) mov (1|M0) r29.4<1>:ud 1023:w // ALU pipe: int; $217
(W) mov (1|M0) r29.5<1>:ud r117.0<0;1,0>:d // ALU pipe: int; $217
(W) mov (1|M0) r29.6<1>:ud r50.8<0;1,0>:d // ALU pipe: int; $217
(W) mov (1|M0) r29.7<1>:ud 0xF0F:uw // ALU pipe: int; $217
// Line 85: a = tl.load(a_tile_ptr)
load_block2d.ugm.d16.a64 (16|M0) r7:4 [r2:1] {I@7,$4} // ex_desc:0x0; desc:0x2400203 // $209
```
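A quick way to confirm whether the prefetches survived into the final binary is to scan the Gen assembly dump mechanically instead of by eye. This is a hypothetical checker, assuming (as the dumps in this thread show) that a 2D block prefetch appears as a `load_block2d` with a `null` destination:

```python
import re

def count_prefetches(asm_text):
    """Count 2D block prefetch instructions in a Gen assembly dump.

    A 2D block prefetch writes nothing to the GRF, so it shows up as
    a load_block2d with a null destination; we also match any explicit
    'prefetch' mnemonic, just in case.
    """
    null_dest_loads = re.findall(
        r"load_block2d\.\S+\s+\(\d+\|M0\)\s+null", asm_text)
    explicit = re.findall(r"\bprefetch\b", asm_text)
    return len(null_dest_loads) + len(explicit)
```

Running this over the dump above returns 0, which matches the observation that only the regular 2D loads made it into the binary.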
I will debug the issue.
The performance doesn't change, because the prefetch ops are not in the final binary.
matmul-performance:
M N K TRANS_A TRANS_B oneDNN Triton
0 2048 512 512 False False 3.024580 2.580859
1 2048 1024 1024 False False 11.607216 7.217307
2 2048 2048 2048 False False 38.595390 12.376777
3 2048 4096 4096 False False 111.028650 15.548922
The attributes of the prefetching GenISA intrinsic:
```llvm
call void @llvm.genx.GenISA.LSC2DBlockPrefetch(i64 %50, i32 %45, i32 %46, i32 1023, i32 %48, i32 %49, i32 32, i32 3, i32 15, i32 1, i1 false, i1 false, i32 4) #4, !dbg !392

attributes #4 = { nounwind }
```
The GenISA intrinsic is created in the MLIR LLVM dialect instead of in LLVM IR. I am not sure whether this could cause the problem.
The prefetching ops are generated as expected after updating IGC to the new version:
```asm
load_block2d.ugm.d32.a64.ca.ca (1|M0) null:0 [r11:1] {I@2,$4} // ex_desc:0x0; desc:0x2080403 // $180
```
But the new IGC has a correctness issue with the Triton kernel: the kernel doesn't return the correct result for the GEMM.
The performance doesn't make sense with the new driver (the oneDNN performance decreases a lot).
matmul-performance:
M N K TRANS_A TRANS_B oneDNN Triton
0 2048 512 512 False False 0.285490 0.323128
1 2048 1024 1024 False False 1.337173 1.167114
2 2048 2048 2048 False False 5.218729 3.638170
3 2048 4096 4096 False False 19.639573 8.156155
After reviewing the disassembled code, there are some points that could be enhanced for performance.
The 2D store/load and prefetching ops have been enabled on the GEMM kernel.
The known points we can work on to improve the performance:
- The caching improvements.
- Use the packed type for the Dot operands layout encoding in TypeConvert to remove the redundant pack/un-pack register movement.
- The send message payload of the load/prefetching op seems to be composed repeatedly. We can improve this. (This is only available in SIMD for now. Need to investigate whether IGC can optimize this in the scalar compiler backend.)
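To illustrate the pack/un-pack cost named above (a hypothetical sketch, not the actual TypeConvert code): the DPAS operands carry two fp16 values per 32-bit register channel, so when the layout encoding uses scalar fp16 elements the compiler must emit explicit pack/unpack moves around every dot; keeping the values as packed i32 avoids them. The packing itself is just two half-precision values per dword:

```python
import struct

def pack_fp16_pair(lo, hi):
    """Pack two fp16 values into one 32-bit integer (lo in bits 0-15),
    the layout a DPAS fp16 operand channel holds."""
    return int.from_bytes(struct.pack("<2e", lo, hi), "little")

def unpack_fp16_pair(dword):
    """Inverse: recover the two fp16 values from a packed dword."""
    return struct.unpack("<2e", dword.to_bytes(4, "little"))
```

If the IR round-trips through `unpack_fp16_pair`/`pack_fp16_pair` equivalents (the `extractelement`/`insertelement` pairs mentioned later in this thread) without computing anything in between, those moves are pure overhead the packed layout would eliminate.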
The performance for now.
matmul-performance:
M N K cuBLAS Triton
0 2048.0 512.0 512.0 3.095257 2.726150
1 2048.0 1024.0 1024.0 11.607216 9.541525
2 2048.0 2048.0 2048.0 34.085901 23.184554
3 2048.0 4096.0 4096.0 88.441355 34.846204
4 2048.0 8192.0 8192.0 157.030987 32.682423
As discussed, please work with @Dewei-Wang-sh to add a column with Triton (SIMD) measurement, and a column with Triton (SIMD) measurement without prefetch, to your example.
The 2D store/load and prefetching ops have been enabled on the GEMM kernel.
The known points we can work on to improve the performance:
- The caching improvements:
  - The IGC driver doesn't generate the prefetching ops in the binary even when we add the GenISA 2D prefetch ops. Need to wait for the IGC driver update.
  - Add the named barrier to synchronize the execution of different threads, to improve data locality and avoid cache competition.
- Use the packed type for the Dot operands layout encoding in TypeConvert to remove the redundant pack/un-pack register movement.
- The send message payload of the load/prefetching op seems to be composed repeatedly. We can improve this. (This is only available in SIMD for now. Need to investigate whether IGC can optimize this in the scalar compiler backend.)
The performance for now.
matmul-performance:
         M       N       K      cuBLAS     Triton
0   2048.0   512.0   512.0    3.095257   2.726150
1   2048.0  1024.0  1024.0   11.607216   9.541525
2   2048.0  2048.0  2048.0   34.085901  23.184554
3   2048.0  4096.0  4096.0   88.441355  34.846204
4   2048.0  8192.0  8192.0  157.030987  32.682423
I see you are comparing to cuBLAS performance. I think we should compare against oneDNN performance (same platform). I'd like to understand the performance for the following scenarios:
1. just using the DPAS instruction
2. using the DPAS instruction + 2D block load
3. using the DPAS instruction + 2D block load and 2D block stores
4. using the DPAS instruction + 2D block load and 2D block stores + prefetch

For 4 you will have to wait for IGC to fix the problem you found (hope they are aware of this problem, correct?)
Add the named barrier to synchronize the execution on different threads to make the data locality better to avoid cache competition.
Can you explain in more detail what you mean by using the "named barrier" and why you think using that operation/instruction would give better performance?
> Add the named barrier to synchronize the execution on different threads to make the data locality better to avoid cache competition.
>
> Can you explain in more detail what you mean by using the "named barrier" and why you think using that operation/instruction would give better performance?
The different physical threads perform the DOT operation on individual GEMM tiles in parallel. The threads may be in different stages (different iterations of the loop) and working on different regions of the GEMM, and we are using multiple threads to prefetch the data cooperatively and implicitly. Without explicit synchronization, the cache may contain unexpected data from different stages.
> I see you are comparing to cuBLAS performance. I think we should compare against oneDNN performance (same platform). I'd like to understand the performance for the following scenarios:
>
> 1. just using the DPAS instruction
> 2. using the DPAS instruction + 2D block load
> 3. using the DPAS instruction + 2D block load and 2D block stores
> 4. using the DPAS instruction + 2D block load and 2D block stores + prefetch
>
> For 4 you will have to wait for IGC to fix the problem you found (hope they are aware of this problem, correct?)
It is the oneDNN performance but the name of the column is cuBLAS by mistake.
Yes. The 4th point depends on the IGC's fix on the prefetching.
For the 08-experimental-block-pointer.py with GEMM size 4k x 4k x 4k, fp16, with Triton lowering to SPIR-V (vc-intrinsics), the max perf is as below:
- 295 tflops with all the optimizations: 2D load, 2D store, 2D prefetch, etc. (MLIR team 300 tflops, XeTLA 310 tflops)
- 150 tflops without prefetch

The machine info is PVC (Max 1550): EUCount = 512, ThreadCount = 4096, SliceCount = 1, SubSliceCount = 64.
I met a correctness issue in the 2D block store.
The first time I used the naive encoding with `tile_height=8`, `tile_width=16` and `v_blocks=1`, as used in the 2D load encoding. But I could not get the correct result in memory.
```mlir
genx.matrix.2Dblockstore %arg2, %654, %90, %656, %661, %662, %642 {elem_size_in_bits = 32 : i32, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false, v_blocks = 1 : i32, vnni_transform = false} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) loc(#loc28)
```
The second time I used the encoding following the GFX spec, with `tile_height=7`, `tile_width=15` and `v_blocks=0`. But it caused the GPU to hang.
```mlir
genx.matrix.2Dblockstore %arg2, %654, %90, %656, %661, %662, %642 {elem_size_in_bits = 32 : i32, tile_height = 7 : i32, tile_width = 15 : i32, transpose = false, v_blocks = 0 : i32, vnni_transform = false} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) loc(#loc28)
```
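If the second attempt is right that the GFX spec wants minus-one encodings, the mapping is trivial but easy to get wrong in several fields at once. A hypothetical helper makes the convention explicit (field names here are illustrative, not the actual message-payload layout):

```python
def encode_block2d_shape(tile_width, tile_height, v_blocks):
    """Encode a 2D block shape as (value - 1) in each field, e.g. a
    16-wide by 8-high tile with one block becomes (15, 7, 0)."""
    assert tile_width >= 1 and tile_height >= 1 and v_blocks >= 1
    return tile_width - 1, tile_height - 1, v_blocks - 1
```

With this convention the two experiments above correspond to passing raw values (`16, 8, 1` — wrong results) versus encoded values (`15, 7, 0` — hang), so at least one of the two conventions is being applied in the wrong place along the lowering path.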
I will create a JIRA to track this 2D block store issue with a small reproducer to IGC. I think it should not impact the performance testing. Right?
After aligning with DeWei, we use the same Triton kernel for the performance bench.
The SIMT kernel runs with 2D block load/store and DPAS.
matmul-performance:
M N K oneDNN Triton
0 2048.0 512.0 512.0 3.147170 2.781717
1 2048.0 1024.0 1024.0 11.766426 8.677456
2 2048.0 2048.0 2048.0 35.496352 17.951568
3 2048.0 4096.0 4096.0 89.484749 26.475945
4 4096.0 4096.0 4096.0 130.273616 29.551482
5 2048.0 8192.0 8192.0 153.008826 34.834623
The best performance is under this configuration:
```python
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2,
              num_warps=64),
```
The performance is not as good as expected: only about 20% of oneDNN on the JupyterHub platform. The expected performance would be 300 tflops x 20% = 60 tflops on the PVC 1550 for now.
We quickly reviewed the kernel with DeWei. We think redundant register movement instructions in the final binary may be causing the performance issue. Those redundant register movements need further cleanup.
From the example kernel, we can see a lot of redundant move instructions: `insertelement` and `extractelement` pairs which may be optimized. I used a new PoC branch: chengjun/llvm-target-dpas-ver2.
The 08-experimental-block-pointer.py is used to test the performance.
matmul-performance:
M N K oneDNN Triton native Triton 2D load Triton 2D load/store
0 2048.0 512.0 512.0 3.140586 2.251800 2.832453 2.855802
1 2048.0 1024.0 1024.0 11.766426 5.636545 9.421757 9.648848
2 2048.0 2048.0 2048.0 35.201560 8.855548 23.471529 24.476085
3 2048.0 4096.0 4096.0 91.632610 10.826774 36.535730 37.672249
4 4096.0 4096.0 4096.0 130.627861 11.654737 41.180180 42.552650
5 2048.0 8192.0 8192.0 160.887734 11.540294 41.483934 42.576222
6 4096.0 8192.0 8192.0 175.409304 8.006108 43.217000 44.266945
7 8192.0 8192.0 8192.0 173.861867 6.047589 44.402908 45.352001
The oneDNN upper bound is ~175 tflops; the SIMT Triton kernel reaches ~45 tflops, only ~25% of the oneDNN performance.
> for the 08-experimental-block-pointer.py with GEMM size 4k x 4k x 4k, fp16, with Triton lowering to SPIR-V (vc-intrinsics), the max perf is as below: 295 tflops with all the optimizations: 2D load, 2D store, 2D prefetch, etc. (MLIR team 300 tflops, XeTLA 310 tflops); 150 tflops without prefetch
> the machine info is PVC (Max 1550) EUCount = 512 ThreadCount = 4096 SliceCount = 1 SubSliceCount = 64
Hi @Dewei-Wang-sh @chengjunlu can you guys measure your respective experiments on the same GPU please. So it is easy to compare. We need:
On PVC Max 1550 (EUCount = 512, ThreadCount = 4096, SliceCount = 1, SubSliceCount = 64):
- oneDNN: 195 tflops
- SIMD approach with (DPAS + 2D block load + 2D block store): 150 tflops
- SIMD approach with (DPAS + 2D block load + 2D block store + prefetch): 290 tflops
The SIMT Triton GEMM performance gets to ~85% of the oneDNN kernel in the 4K x 4K x 4K case.
matmul-performance:
M N K oneDNN Triton 2D load/store
0 2048.0 512.0 512.0 3.129673 3.024580
1 2048.0 1024.0 1024.0 11.766426 10.957663
2 2048.0 2048.0 2048.0 35.496352 35.339673
3 2048.0 4096.0 4096.0 89.707556 79.336740
4 4096.0 4096.0 4096.0 126.444561 107.118973
5 2048.0 8192.0 8192.0 149.954023 108.694400
6 4096.0 8192.0 8192.0 176.490089 122.644700
7 8192.0 8192.0 8192.0 173.973367 129.720290
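The ~85% figure can be checked directly from row 4 of the table above (a trivial helper, just to make the comparison explicit):

```python
def relative_perf(triton_tflops, onednn_tflops):
    """Triton throughput as a fraction of the oneDNN baseline."""
    return triton_tflops / onednn_tflops

# Row 4 (4096 x 4096 x 4096) from the table above.
ratio = relative_perf(107.118973, 126.444561)
```

The same helper applied to the largest case (row 7) gives about 75%, so the gap to oneDNN widens somewhat as the GEMM grows.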
The code branches:
- GENX LLVM: https://github.com/chengjunlu/llvm/tree/chengjun/genx
- Triton: https://github.com/intel/intel-xpu-backend-for-triton/tree/chengjun/llvm-target-dpas-ver2
Things to do:
```llvm
%88 = call <64 x i16> @llvm.genx.GenISA.LSC2DBlockRead.v64i16(i64 %52, i32 %35, i32 %36, i32 %38, i32 %75, i32 %51, i32 16, i32 16, i32 32, i32 2, i1 false, i1 false, i32 0) #3, !dbg !483
%89 = call <32 x i32> @llvm.genx.GenISA.LSC2DBlockRead.v32i32(i64 %55, i32 %46, i32 %47, i32 %49, i32 %54, i32 %76, i32 16, i32 16, i32 32, i32 2, i1 false, i1 true, i32 0) #3, !dbg !486
```
Create an LLVM IR file that uses DPAS, 2D block read/write, and prefetch instructions to get an initial performance estimate for the SIMT path.
One could use the compiler to generate the LLVM IR file, without implementing the analysis portion, and hard-code the transformation portion as needed.