Closed chengjunlu closed 7 months ago
The root cause seems to be in the function getUniqueContigPerThread.
When threadsPerWarp = 32, the DPAS layout is not contiguous in the column dimension. It is strided by 2 because one name now represents two rows.
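To make the stride concrete, here is a minimal sketch of how a lane's values could map to tile coordinates. It assumes an 8x16 tile with one lane per column, and that with 32 lanes the upper 16 lanes take the second row of each name. The helper names dpasCoord and rowStride are hypothetical and are not Triton functions.

```cpp
#include <cassert>
#include <utility>

// Assumed tile shape, for illustration only: 8 rows x 16 columns.
constexpr int kTileRows = 8;
constexpr int kTileCols = 16;

// Hypothetical helper: (row, col) of the valueIdx-th value held by a lane.
std::pair<int, int> dpasCoord(int threadsPerWarp, int lane, int valueIdx) {
  assert(threadsPerWarp % kTileCols == 0);
  assert(lane >= 0 && lane < threadsPerWarp);
  // Rows covered by one SIMT name across the warp: 1 for 16 lanes, 2 for 32.
  int rowsPerName = threadsPerWarp / kTileCols;
  assert(valueIdx >= 0 && valueIdx < kTileRows / rowsPerName);
  int row = valueIdx * rowsPerName + lane / kTileCols;
  int col = lane % kTileCols;
  return {row, col};
}

// Distance in rows between two consecutive values of the same lane.
int rowStride(int threadsPerWarp, int lane) {
  return dpasCoord(threadsPerWarp, lane, 1).first -
         dpasCoord(threadsPerWarp, lane, 0).first;
}
```

Under these assumptions, rowStride(16, lane) is 1 for every lane, while rowStride(32, lane) is 2, which matches the stride described above.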
The same potential issue exists for the TF32 operand A layout when threadsPerWarp=16, and it needs to be fixed. A operand for TF32 with threadsPerWarp=16: the rows with the same color are represented by the same name in SIMT for the 8x8 A matrix, so each value is strided by 2 in the column dimension.
The callstack of the convert layout op lowering for the DPAS layout:
mlir::triton::gpu::intel::DpasEncodingAttr::getSizePerThread Dialect.cpp:111
mlir::triton::gpu::detail::DistributedEncodingTraitInterfaceTraits::Model<mlir::triton::gpu::intel::DpasEncodingAttr>::getSizePerThread TritonGPUAttrInterfaces.h.inc:276
mlir::triton::gpu::DistributedEncodingTrait::getSizePerThread TritonGPUAttrInterfaces.cpp.inc:35
mlir::triton::gpu::getSizePerThread Dialect.cpp:156
mlir::triton::gpu::getContigPerThread Dialect.cpp:186
mlir::triton::gpu::getUniqueContigPerThread Dialect.cpp:206
mlir::triton::gpu::getUniqueContigPerThread Dialect.cpp:198
mlir::triton::getScratchConfigForCvtLayout Allocation.cpp:122
ConvertLayoutOpConversion::lowerDistributedToDistributed ConvertLayoutOpToLLVM.cpp:538
ConvertLayoutOpConversion::matchAndRewrite ConvertLayoutOpToLLVM.cpp:169
mlir::ConvertOpToLLVMPattern::matchAndRewrite Pattern.h:161
We can see that getContigPerThread calls getSizePerThread to get the contiguity information of the layout mapping.
In the case of threads_per_warp=16, the layout really is contiguous along the column, just as return {elemsPerThread, 1} implies.
But it is not true when threads_per_warp=32.
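The mismatch can be sketched by comparing the number of values a lane owns with the length of the run of rows that are actually adjacent. This is a simplified stand-in, not the Triton implementation: it assumes an 8x16 tile, and rowsOwnedByLane, contigFromSizePerThread and actualContig are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// Rows owned by one lane of an assumed 8x16 tile (illustrative mapping only).
std::vector<int> rowsOwnedByLane(int threadsPerWarp, int lane) {
  assert(threadsPerWarp == 16 || threadsPerWarp == 32);
  assert(lane >= 0 && lane < threadsPerWarp);
  int rowsPerName = threadsPerWarp / 16;  // 1 or 2 rows per SIMT name
  std::vector<int> rows;
  for (int row = lane / 16; row < 8; row += rowsPerName)
    rows.push_back(row);
  return rows;
}

// Stand-in for treating the per-thread size as the contiguity.
int contigFromSizePerThread(int threadsPerWarp, int lane) {
  return static_cast<int>(rowsOwnedByLane(threadsPerWarp, lane).size());
}

// Length of the leading run of rows that are truly adjacent.
int actualContig(int threadsPerWarp, int lane) {
  std::vector<int> rows = rowsOwnedByLane(threadsPerWarp, lane);
  int run = 1;
  while (run < static_cast<int>(rows.size()) &&
         rows[run] == rows[run - 1] + 1)
    ++run;
  return run;
}
```

With 16 lanes both functions agree (8 and 8). With 32 lanes the size-based answer is 4 while only 1 row is actually contiguous, so any consumer that vectorizes by the size-based value would read the wrong elements.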
There is a case in tt_dot that uses a broadcast to make a matrix from a vector. The IR is like this:
There is no issue when threadsPerWarp=16. We can simply replicate the name in SIMT with the corresponding coordinate to broadcast it naturally. But with threadsPerWarp=32, a single name in SIMT represents 2 rows, so it is not trivial to reuse the same name in SIMT. (Maybe we need to use a sub-group shuffle to move the value to the upper lanes.) This needs to be fixed if we want to use threadsPerWarp=32 on PVC for DPAS.
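The sub-group shuffle idea can be emulated on the host to show the intended data movement. This is only a sketch under the assumption that the vector being broadcast lives in the lower 16 lanes and that each upper lane should read from lane % 16. broadcastToUpperLanes is a hypothetical name, and no real shuffle intrinsic is called here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Each vector slot stands for one lane's copy of a single SIMT name.
// Emulates a shuffle in which every lane reads the value of lane % cols.
std::vector<int> broadcastToUpperLanes(const std::vector<int>& laneValues,
                                       std::size_t cols) {
  assert(cols > 0 && laneValues.size() % cols == 0);
  std::vector<int> result(laneValues.size());
  for (std::size_t lane = 0; lane < laneValues.size(); ++lane)
    result[lane] = laneValues[lane % cols];  // source lane is in the lower group
  return result;
}
```

With 16 lanes the operation is an identity, which matches the observation that replicating the name is enough. With 32 lanes, lanes 16 to 31 receive the values of lanes 0 to 15, so both rows covered by the name hold the broadcast vector.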