Closed chengjunlu closed 5 months ago
We used warp=16 by default. Will try warp=8, 32 etc. in the future when we do performance tuning.
The issue cannot be reproduced on the latest Triton XPU llvm-target branch with the rolling IGC driver.
Now we can get the correct result from the GEMM kernel with the dpas instruction with threads_per_warp=32.
For the dpas instruction, we simply use half the number of scalars (compared to threads_per_warp=16) for the operand types of A and B and for the return value.
%260 = call <4 x float> @llvm.genx.GenISA.sub.group.dpas.v4f32.v4f32.v4i16.v4i32(<4 x float> %255, <4 x i16> %234, <4 x i32> %251, i32 10, i32 10, i32 8, i32 8, i1 false) #4, !dbg !418
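The halved vector widths in the intrinsic above (`v4f32`/`v4i16`/`v4i32` instead of the 8-element vectors used at sub-group size 16) follow from distributing the same DPAS tile over twice as many lanes. A minimal sketch of that arithmetic (my own illustration, not code from the branch), assuming an 8x16 A tile of i16 elements, a 16x16 fp16 B tile packed two values per i32, and an 8x16 f32 accumulator:

```python
# Hypothetical illustration: the per-lane vector length of a DPAS operand is
# the tile's total element count, divided by the packing factor of the
# storage type, divided by the sub-group size (lanes sharing the tile).
def per_lane_elems(rows, cols, pack_factor, sub_group_size):
    return (rows * cols) // pack_factor // sub_group_size

for sg in (16, 32):
    a = per_lane_elems(8, 16, 1, sg)   # A operand: 8x16 i16 elements
    b = per_lane_elems(16, 16, 2, sg)  # B operand: 16x16 fp16, packed 2-per-i32
    c = per_lane_elems(8, 16, 1, sg)   # accumulator / result: 8x16 f32
    print(f"sub_group_size={sg}: A=<{a} x i16>, B=<{b} x i32>, C=<{c} x float>")
```

With sub-group size 16 this gives 8-element vectors for all three; with 32 it gives the 4-element vectors seen in the call above.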
The kernel size with threads_per_warp=32:
-rw-rw-r-- 1 jovyan users 15332 Mar 27 00:46 _kernel.spv
The kernel size with threads_per_warp=16:
-rw-rw-r-- 1 jovyan users 22624 Mar 27 00:48 _kernel.spv
Here is the LLVM IR and GenISA assembly for the Triton GEMM kernel when threads_per_warp=32.
@chengjunlu Please confirm if the provided OCL builtins are enough to have DPAS working with 32 threads per warp.
I don't think there is such an API in OCL for now. I will create a new issue to track the OCL interface for the sub-group-size=32 DPAS.
For this issue, I think we can close it as the GenISA DPAS works as expected with sub-group-size=32.
The GEMM kernel couldn't produce the correct result with the DPAS op when threads_per_warp is 32.
threads_per_warp=16 works properly on both ATSM and PVC, but threads_per_warp=32 doesn't work as expected.
Needs debugging.