NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[BUG] H100 Kernel with 3 WGs and 384 threads has warp_count set to (4, 1, 1) in cutlass_library #1895

Closed manishucsd closed 3 weeks ago

manishucsd commented 3 weeks ago

An H100 warp-specialized kernel with 3 WGs and 384 threads has warp_count set to (4, 1, 1) in cutlass_library. For H100 kernels, the number of WGs/warps/threads and their layout on the copy tile/math tile are implicitly fixed by the algorithm. The warp_count in generator.py should not (and does not) do anything. Correct?

However, the below kernel produces the following:

./tools/profiler/cutlass_profiler --m=128 --n=128 --k=8192 --kernels=cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_tnn_align8_warpspecialized_pingpong_epi_nosmem --verification-enabled=false --swizzle_size=1 --output=data.csv --dist=uniform,min:-2.3,max:2.3,scale:-1

Thread 0: warp_group_idx = 0, warp_group_role = 0, producer_warp_role = 0
Thread 1: warp_group_idx = 0, warp_group_role = 0, producer_warp_role = 0
Thread 2: warp_group_idx = 0, warp_group_role = 0, producer_warp_role = 0
...
Thread 381: warp_group_idx = 2, warp_group_role = 2, producer_warp_role = 3
Thread 382: warp_group_idx = 2, warp_group_role = 2, producer_warp_role = 3
Thread 383: warp_group_idx = 2, warp_group_role = 2, producer_warp_role = 3
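
The role dump above is consistent with the fixed layout: the 384 threads decompose into three 128-thread warp groups of 4 warps each. Below is a minimal sketch of that index arithmetic (illustrative only; not the kernel's actual role-assignment code, whose role enums are schedule-specific):

    THREADS_PER_WARP = 32
    THREADS_PER_WARP_GROUP = 128  # 4 warps per warp group

    def decompose(tid: int):
        """Map a thread index in [0, 384) to its warp group and warp within it."""
        warp_group_idx = tid // THREADS_PER_WARP_GROUP                # 0, 1, or 2
        warp_idx_in_group = (tid % THREADS_PER_WARP_GROUP) // THREADS_PER_WARP
        return warp_group_idx, warp_idx_in_group

    # Example: decompose(383) == (2, 3), matching thread 383 above
    # (warp_group_idx = 2).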

=============================
  Problem ID: 1
        Provider: CUTLASS
   OperationKind: gemm
       Operation: cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_void_f16_128x128x64_1x2x1_0_tnn_align8_warpspecialized_pingpong_epi_nosmem
          Status: Success
    Verification: OFF
     Disposition: Not verified
       Arguments: --gemm_kind=universal --m=128 --n=128 --k=8192 --A=f16:row --B=f16:column --C=f16:column --D=f16:column  \
                  --alpha=1 --beta=0 --split_k_mode=serial --split_k_slices=1 --batch_count=1 --raster_order=heuristic  \
                  --swizzle_size=1 --op_class=tensorop --accum=f32 --cta_m=128 --cta_n=128 --cta_k=64 --cluster_m=1 --cluster_n=2  \
                  --cluster_k=1 --stages=7 --warps_m=4 --warps_n=1 --warps_k=1 --inst_m=64 --inst_n=128 --inst_k=16 --min_cc=90  \
                  --max_cc=90
           Bytes: 4227072  bytes
           FLOPs: 268468224  flops
           FLOPs/Byte: 63
         Runtime: 0.456127  ms
          Memory: 8.63086 GiB/s
            Math: 588.582 GFLOP/s

See --warps_m=4 --warps_n=1 --warps_k=1 in the output above, and also in --output=data.csv. This has created some confusion while going through the CUTLASS performance results, and it can potentially create issues for code that consumes these data.csv files.

We are requesting a fix for this. Thank you in advance for looking into it.

cc: embg

thakkarV commented 3 weeks ago

Please see: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/device/gemm_universal_adapter.h#L148

Warp shape is not a concept we use in CUTLASS 3.x. 2.x and 3.x share the CUTLASS profiler, so we do our best to set the "warp shape" to something reasonable; however, it seems we have fallen victim to Hyrum's law here. Do you really need this warp shape? It is meaningless in 3.x and should not be relied upon for anything.

manishucsd commented 3 weeks ago

Can you please explain how this is a victim of Hyrum's law? For example: Hyrum's law claims X, which is broken here.

> Do you really need this warp shape?

Thanks for the link and code pointer (gemm_universal_adapter.h#L148). We should not be setting warp_count for WS-pingpong kernels, WS-cooperative kernels, or any kernel where the warp count is implicitly set by the algorithm. However, we expect the following behavior:

  1. The warp_count (m, n, k) should read what the kernel is actually using. Precisely, what is set in the code (see the sketch after this list):
    static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32);
    static constexpr int WarpsInMmaM = 4;
    static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM);
    using WarpCount = cutlass::gemm::GemmShape<WarpsInMmaM, WarpsInMmaN, 1>; // This is what the cutlass_library warp_count should be set to. The equations used here should be used somewhere in `generator.py` to set `tile_description.warp_count`.
  2. When someone tries to run --warps_m=4 --warps_n=1,2 --warps_k=1 with WS-pingpong, the profiler should run nothing; there is no kernel like that.
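
A minimal sketch, in Python, of how the adapter's equations above could be mirrored when emitting tile descriptions; the function name and the `tiled_mma_threads` parameter are hypothetical, introduced only for illustration:

    def warp_count_from_tiled_mma_threads(tiled_mma_threads: int):
        # Mirrors the C++ above: WarpsInMma = max(4, size(TiledMma) / 32)
        warps_in_mma = max(4, tiled_mma_threads // 32)
        warps_in_mma_m = 4
        # cute::ceil_div(WarpsInMma, WarpsInMmaM)
        warps_in_mma_n = -(-warps_in_mma // warps_in_mma_m)
        return (warps_in_mma_m, warps_in_mma_n, 1)

    # Pingpong:    128-thread TiledMma -> (4, 1, 1)
    # Cooperative: 256-thread TiledMma -> (4, 2, 1)
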
thakkarV commented 3 weeks ago
  1. Yes, and the profiler reporting the adapter's best approximation of warp count is the desired behavior here.
  2. Why is that? For pingpong-schedule kernels, the tiled MMA has 128 threads in it, which should map to a "warp shape" of 4x1x1, a shape that is covered by the --warps_m=4 --warps_n=1,2 --warps_k=1 parameters. For cooperative, the warp shape would be 4x2x1 (still inaccurate, but at least the size matches).
manishucsd commented 3 weeks ago
  1. The correspondence between what the kernel sets as WarpCount in the adapter and what we see in the library generated from the CUTLASS kernel seems correct.
  2. I now understand the source of the confusion that triggered the series of discussions we are trying to resolve here. The problem is in the definition of warps_m,n,k. It is not the number of warps for the kernel, and it is not the number of math warps (both pingpong and cooperative have 8 math warps per CTA, i.e., 2 warpgroups); it is the warp count per unique output tile.


For Pingpong, warp_count for each 128x128 tile is 4x1x1.


For Cooperative, the warp_count for a 128x128 output tile is 8x1x1, but it is reported as 4x2x1.
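
To make the distinction concrete, a small arithmetic sketch (assuming, as described above, two math warp groups per CTA for both schedules, with pingpong working on two output tiles concurrently and cooperative on one):

    MATH_WARPS_PER_CTA = 8  # 2 consumer warp groups x 4 warps, both schedules

    def warps_per_output_tile(concurrent_output_tiles_per_cta: int) -> int:
        return MATH_WARPS_PER_CTA // concurrent_output_tiles_per_cta

    # Pingpong:    8 // 2 = 4 -> 4x1x1 per 128x128 output tile
    # Cooperative: 8 // 1 = 8 -> 8x1x1 per 128x128 output tile
    #              (reported by the profiler as 4x2x1)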

The discussion here resolves my confusion; resolving it in the code with a proper definition is P2 and not super essential. Closing the issue. Thanks for the discussion. Feel free to add more thoughts. I have bookmarked this discussion :).

thakkarV commented 3 weeks ago

Yep! Your understanding is correct. We simply set it to the number of warps in the tiled MMA. But let me emphasize again that this warp shape should not be used for anything in 3.x API kernels, nor should it be relied upon in general. We just happen to have done a decent enough job for Hopper kernels in this case to make some sense out of them, but as you point out, it's still not perfectly accurate for cooperative kernels.