manishucsd closed this issue 3 weeks ago.
Please see: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/device/gemm_universal_adapter.h#L148
Warp shape is not a concept we use in CUTLASS 3.x. 2.x and 3.x share the CUTLASS profiler, so we do our best to set the "warp shape" to something reasonable. However, it seems we have fallen victim to Hyrum's law here. Do you really need this warp shape? It is meaningless in 3.x and should not be relied upon for anything.
Can you please explain how this is a case of Hyrum's law? E.g., Hyrum's law claims X, which is broken here.
> Do you really need this warp shape?
Thanks for the link and code pointer (`gemm_universal_adapter.h#L148`). We should not be setting the warp count for WS-pingpong, WS-cooperative, or any other kernel where the warp count is implicitly set by the algorithm. However, we expect the following behavior:
```cpp
static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32);
static constexpr int WarpsInMmaM = 4;
static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM);
// This is what cutlass_library's warp_count should be set to. The same equations
// should be used somewhere in `generator.py` to set `tile_description.warp_count`.
using WarpCount = cutlass::gemm::GemmShape<WarpsInMmaM, WarpsInMmaN, 1>;
```
With WS-pingpong, the profiler should run nothing for `--warps_m=4 --warps_n=1,2 --warps_k=1`; there is no kernel with those parameters. For cooperative, the warp shape would be 4x2x1 (still inaccurate, but at least the size matches).

The `WarpCount` in the adapter and what we see in the library generated from the CUTLASS kernel seem correct. Note that `warp_m,n,k` is not the number of warps for the kernel, nor is it the number of math warps (both pingpong and cooperative have 8 math warps per CTA, i.e., 2 warpgroups); it is the warp count per unique output tile. For Pingpong, the warp count for each 128x128 tile is 4x1x1. For Cooperative, the warp count for a 128x128 output tile is 8x1x1, but it is reported as 4x2x1.
The discussion here resolves my confusion. Resolving it in the code with a proper definition is P2 and not essential. Closing the issue. Thanks for the discussion, and feel free to add more thoughts. I have bookmarked this discussion :).
Yep! Your understanding is correct. We simply set it to the number of warps in the tiled MMA. But let me emphasize again that this warp shape should not be used for anything in 3.x API kernels, nor should it be relied upon in general. We just happen to have done a decent enough job for Hopper kernels in this case to make some sense of them, but as you point out, it is still not perfectly accurate for cooperative kernels.
An H100 warp-specialized kernel with 3 WGs and 384 threads has `warp_count` set to (4, 1, 1). For H100 kernels, the number of WGs/warps/threads and their layout on the copy tile/math tile are implicitly fixed by the algorithm. The `warp_count` in `generator.py` should not (and does not) do anything. Correct? However, the kernel below produces the following:
See: `--warps_m=4 --warps_n=1 --warps_k=1`, which is also reported in the file written by `--output=data.csv`. This has created some confusion while going through the CUTLASS performance results and could create issues for code that consumes these data.csv files. We are requesting a fix for this. Thank you in advance for looking into it.
cc: embg