NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST] Why does hopper-mixed-gemm's Bandwidth Utilization only reach ~9% MBU on H100 SXM5? #1794

Open ZZBoom opened 1 week ago

ZZBoom commented 1 week ago

Hello, here are my test logs

# command line:
CUDA_VISIBLE_DEVICES=7 ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=16 --n=6144 --k=2048 --g=128 --mode=1

# Running results:
Running in group scale mode.
Disposition: Passed
Problem Size: 16x6144x2048x1
Avg runtime: 0.0213007 ms
GFLOPS: 18903.3

In this example, I think the total memory accessed is A(fp8) + B(int4) + scale(half) + C(half) + D(half)

So I computed the bandwidth utilization as ~10.8%:

= (m*k*1 + n*k/2 + n*k/g*2 + m*n*2 + m*n*2) / h100_peak_bw(GB/s) / 1e6 / Avg runtime(ms)
= (16*2048 + 6144*2048/2 + 6144*2048/128*2 + 16*6144*2 + 16*6144*2) / 3000 / 1e6 / 0.0213007
≈ 10.8%
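For reference, a minimal C++ sketch of that arithmetic (the per-operand byte counts and the 3000 GB/s H100 SXM5 peak are the assumptions above, not measured values):

#include <cstdio>

int main() {
  // Problem size and group size from the run above
  double m = 16, n = 6144, k = 2048, g = 128;
  // Bytes moved: A (fp8, 1 B/elem) + B (int4, 0.5 B/elem) + scales (half) + C + D (half)
  double bytes = m*k*1 + n*k/2 + (n*k/g)*2 + m*n*2 + m*n*2;
  double runtime_ms = 0.0213007;  // measured average runtime
  double peak_gbps  = 3000.0;     // assumed H100 SXM5 peak bandwidth
  double achieved_gbps = bytes / (runtime_ms * 1e-3) / 1e9;
  printf("achieved %.1f GB/s, utilization %.1f%%\n",
         achieved_gbps, 100.0 * achieved_gbps / peak_gbps);
  return 0;
}
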
IwakuraRein commented 1 week ago

Hi. This GEMM is more compute bound than memory bound. There is a large overhead from type casting (and from the scale multiplication, since you're using mode=1). Also, your M size, 16, is quite small. This example swaps matrix A and matrix B, so M goes to the N dimension. If the input M is smaller than the TileShape N, the GEMM is not saturated, and it's expected to have low TFLOPS and thus low bandwidth utilization.

azhurkevich commented 1 week ago

@ZZBoom please set your tile shape to using TileShape = Shape<_128,_16,_128>;. As @IwakuraRein mentioned, the kernel is partially bound by the conversion logic, so it's not just load -> MMA -> store, not a pure memory-bandwidth-bound case. The memory bandwidth calculation above will therefore not be entirely accurate here.

Please play around with different tile shapes, kernel schedules, and epilogue schedules to see what works best for your use case. Always feel free to reach out for help.
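As a rough illustration, the change amounts to redefining the CTA tile shape typedef in the example; the ClusterShape shown here is only an assumption, not a value from this thread:

#include "cute/tensor.hpp"

using namespace cute;

// CTA tile shape (M, N, K) suggested above. Because the example swaps A and B,
// the problem's n maps onto the tile M dimension.
using TileShape    = Shape<_128, _16, _128>;
// Illustrative cluster shape only; keep whatever the example already defines.
using ClusterShape = Shape<_1, _1, _1>;
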

ZZBoom commented 1 week ago

@IwakuraRein @azhurkevich Thank you for your prompt response!

After adjusting the TileShape to Shape<_128, _16, _128>, performance improved significantly: the runtime dropped by approximately 40%. Here are the updated running logs:

Running in group scale mode.
Disposition: Passed
Problem Size: 16x6144x2048x1
Avg runtime: 0.0135779 ms
GFLOPS: 29655.1

I conducted two GEMM performance tests with the following parameters:

Test 1: GEMM with dimensions m,n,k = 16,2560,6144, using TileShape = 128x16x128, computation time 0.0328 ms.
Test 2: GEMM with dimensions m,n,k = 16,8192,6144, using TileShape = 128x16x128, computation time 0.033 ms.

It was surprising to see that despite the increase in the GEMM 'n' dimension from 2560 to 8192 (3.2 times larger), the computation time remained virtually unchanged.

This observation suggests that the SM occupancy is below the maximum SM count, which is 132 in this case: because A and B are swapped, n maps onto the tile M dimension, so with TileShape M = 128, n = 2560 and n = 8192 launch only 20 and 64 output tiles respectively, both well under 132. To address this, I attempted to adjust the tile shape for 'M' from 128 to 64. However, this change resulted in a compilation error.
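A quick sketch of that tile-count arithmetic (assuming the swapped A/B layout, the 128x16x128 tile above, and 132 SMs on H100 SXM5):

#include <cstdio>

int main() {
  // A and B are swapped, so the problem's n maps onto the tile M dimension.
  int sm_count = 132;              // H100 SXM5
  int tile_m = 128, tile_n = 16;   // TileShape used above
  int m = 16;
  for (int n : {2560, 8192}) {
    int tiles = ((n + tile_m - 1) / tile_m) * ((m + tile_n - 1) / tile_n);
    printf("n=%d -> %d output tiles (%.2f waves on %d SMs)\n",
           n, tiles, (double)tiles / sm_count, sm_count);
  }
  return 0;
}
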

So I need some help, thanks a lot.

IwakuraRein commented 1 week ago

@ZZBoom To use 64 for tile size M, please use cutlass::gemm::KernelTmaWarpSpecializedPingpong for KernelSchedule and cutlass::epilogue::TmaWarpSpecialized for EpilogueSchedule. The cooperative schedule assigns two warp groups to one tile; since the M dimension of the tensor core instruction is always 64, two warp groups require at least 128 in the M dimension. The pingpong schedule assigns one warp group to each tile, so its tile size M can be as low as 64.
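A minimal sketch of those typedefs (the header paths assume the CUTLASS 3.x include layout, and the rest of the example's builder setup is assumed unchanged):

#include "cute/tensor.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"      // kernel schedule tags
#include "cutlass/epilogue/dispatch_policy.hpp"  // epilogue schedule tags

using namespace cute;

// Pingpong: one warp group per output tile, so tile M can drop to 64.
using TileShape        = Shape<_64, _16, _128>;
using KernelSchedule   = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;

These typedefs would then be passed to the example's collective mainloop and epilogue builders in place of the cooperative schedules.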