NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[BUG] Trying to optimize mixed input for kernels #1868

Open NihalPotdar opened 1 week ago

NihalPotdar commented 1 week ago

Describe the bug: I was reading through the CUTLASS mixed-precision kernels, https://github.com/NVIDIA/cutlass/blob/cc3c29a81a140f7b97045718fb88eb0664c37bd7/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp#L552. As written, this check implicitly requires the group size to be at least tile shape K. That should not matter, since the code later accounts for it with a reload factor: https://github.com/NVIDIA/cutlass/blob/cc3c29a81a140f7b97045718fb88eb0664c37bd7/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp#L765.

Should this be reversed, i.e. require tile shape K to be a multiple of the group size? This would benefit GEMMs with a large K dimension. The proposed check would be:

implementable && (args.group_size == K || ((size<2>(TileShape{})) % args.group_size == 0));
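To make the difference concrete, here is a minimal sketch that reduces both checks to plain integers. The function names are hypothetical, tile_k stands in for size<2>(TileShape{}), and the current check is assumed to have the same shape as the proposal with the modulo operands swapped; it is not the collective's actual code.

```cpp
#include <cassert>

// Hypothetical standalone versions of the group-size check.
// tile_k plays the role of size<2>(TileShape{}); group_size of args.group_size.

// Assumed current rule: group_size must be a multiple of tile_k,
// so each K-tile lies inside a single scale group.
inline bool current_check(int K, int tile_k, int group_size) {
  assert(tile_k > 0 && group_size > 0);
  return group_size == K || group_size % tile_k == 0;
}

// Proposed rule: tile_k must be a multiple of group_size,
// so one K-tile may span several scale groups.
inline bool proposed_check(int K, int tile_k, int group_size) {
  assert(tile_k > 0 && group_size > 0);
  return group_size == K || tile_k % group_size == 0;
}
```

With K = 8192 and group_size = 128, tile_k = 64 passes only the current check, while tile_k = 256 passes only the proposed one; tile_k = 128 passes both.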

azhurkevich commented 6 days ago

We simply made it a compile-time parameter for performance. CC @IwakuraRein

IwakuraRein commented 6 days ago

@azhurkevich No, group_size is still a runtime argument at this stage. @NihalPotdar The condition (args.group_size % size<2>(TileShape{})) == 0 is too strong; I agree with you. It forces each row of a matrix A tile to use the same scale value across the tile's whole K extent, but I think it's reasonable to require this only within the K extent of a single wgmma instruction (since SmemLayoutAtomScale's second dimension is 1). Therefore, the condition should be size<2>(TileShape{}) % args.group_size == 0 && args.group_size % gmma_shape_k == 0. For instance, if a tile is 64x128x64, the warpgroup issues four 64x128x16 wgmma instructions in total to compute the tile, so the group size could be as low as 16, IMO. Let me dig into the code and see if this can be fixed.
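The relaxed condition proposed in this comment can be sketched as follows. This is an illustration, not CUTLASS code: the function names are hypothetical, and gmma_k = 16 assumes 16-bit A/B operands (the wgmma K extent for f16/bf16).

```cpp
#include <cassert>

// Hypothetical sketch of the relaxed condition: tile_k must be a multiple
// of group_size, and group_size a multiple of the wgmma K extent (gmma_k),
// so each wgmma along K sees a single scale value per row.
inline bool relaxed_check(int tile_k, int group_size, int gmma_k) {
  assert(tile_k > 0 && group_size > 0 && gmma_k > 0);
  return tile_k % group_size == 0 && group_size % gmma_k == 0;
}

// Number of wgmma instructions issued along K to cover one tile.
inline int wgmmas_per_tile_k(int tile_k, int gmma_k) {
  assert(gmma_k > 0 && tile_k % gmma_k == 0);
  return tile_k / gmma_k;
}
```

For the 64x128x64 example, wgmmas_per_tile_k(64, 16) gives 4, and relaxed_check(64, 16, 16) accepts a group size of 16, while a group size of 8 (not a multiple of 16) or 128 (larger than tile_k) is rejected.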

IwakuraRein commented 6 days ago

@NihalPotdar Oh, sorry, I made a mistake. SmemLayoutAtomScale and ScaleTileShape having a second dimension of 1 means one scale value is broadcast across the whole K dimension of a tile. To make the granularity match the wgmma's K dimension, ScaleTileShape would need to be updated, and that also requires group_size to be a compile-time value, so currently there is no easy fix for this.
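One way to see why a K extent of 1 in ScaleTileShape forces the current requirement is to count how many scale groups one K-tile spans for a given row. A minimal sketch with hypothetical helper names, assuming groups are contiguous runs of group_size elements along K starting at k = 0:

```cpp
#include <cassert>

// Hypothetical illustration: number of distinct scale groups (along K)
// touched by K-tile k_tile_index for a single row. A ScaleTileShape with
// K extent 1 holds one scale per row per tile, so this must be 1.
inline int scale_groups_in_k_tile(int k_tile_index, int tile_k, int group_size) {
  assert(k_tile_index >= 0 && tile_k > 0 && group_size > 0);
  const int k_begin = k_tile_index * tile_k;  // first K coordinate in the tile
  const int k_end = k_begin + tile_k - 1;     // last K coordinate in the tile
  return k_end / group_size - k_begin / group_size + 1;
}

// True if every K-tile over [0, K) maps to exactly one scale group.
inline bool single_scale_per_tile(int K, int tile_k, int group_size) {
  assert(tile_k > 0 && K % tile_k == 0);
  for (int t = 0; t < K / tile_k; ++t) {
    if (scale_groups_in_k_tile(t, tile_k, group_size) != 1) return false;
  }
  return true;
}
```

With tile_k = 64 and group_size = 128 every tile sees one group. With tile_k = 128 and group_size = 64 each tile spans 2 groups, and group_size = 96 with tile_k = 64 also fails because some tiles straddle a group boundary.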

azhurkevich commented 6 days ago

@NihalPotdar Can you please explain in detail whether this is a blocker for you?

NihalPotdar commented 6 days ago

@azhurkevich Yes, it is. I was running some matmuls with uint4 and fp16 data types using this code and benchmarking their performance. However, when the K dimension is large and the other dimensions are smaller (say, m x 4096 x 8192), performance and bandwidth utilization are suboptimal: the peak bandwidth utilization I have observed is close to 30%. My thinking is that allowing tile size K to be larger than the group size might alleviate some of these issues.

azhurkevich commented 6 days ago

@NihalPotdar Can you please provide more detailed data on what you've encountered and what your expectations are? Thank you.

NihalPotdar commented 6 days ago

@azhurkevich Sure. I was working with this example code, with MmaType set to float16 (cutlass::half_t) and QuantType set to uint4 (cutlass::uint4b_t).

For the problem size M=16, N=2560, K=8192, the GEMM is memory bound: it is dominated by how fast we can read from HBM into the SMs. I got the best results with tile_m = 64, tile_n = 16, and tile_k = 64 using the KernelTmaWarpSpecializedMixedInput kernel. The group size I am testing with in this case is 128.
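As a rough check on the memory-bound claim, here is a back-of-the-envelope traffic estimate. It is a hypothetical model, not measured data: it assumes the uint4 operand is the N x K matrix, the fp16 operand is M x K, one fp16 scale per group of 128 elements along K with no zero points, an fp16 output, and that each byte crosses DRAM exactly once.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical DRAM traffic model (bytes) under the assumptions above.
struct Traffic {
  std::int64_t quant_bytes, scale_bytes, act_bytes, out_bytes;
  std::int64_t total() const { return quant_bytes + scale_bytes + act_bytes + out_bytes; }
};

inline Traffic estimate_traffic(std::int64_t M, std::int64_t N, std::int64_t K,
                                std::int64_t group_size) {
  assert(group_size > 0 && K % group_size == 0 && (N * K) % 2 == 0);
  Traffic t{};
  t.quant_bytes = N * K / 2;                 // 4 bits per quantized element
  t.scale_bytes = N * (K / group_size) * 2;  // one fp16 scale per group
  t.act_bytes = M * K * 2;                   // fp16 operand
  t.out_bytes = M * N * 2;                   // fp16 output
  return t;
}

// FLOPs per byte of DRAM traffic, counting each multiply-add as 2 FLOPs.
inline double arithmetic_intensity(std::int64_t M, std::int64_t N, std::int64_t K,
                                   const Traffic& t) {
  return 2.0 * M * N * K / static_cast<double>(t.total());
}
```

Under these assumptions the problem moves about 11.2 MB and performs about 0.67 GFLOP, roughly 60 FLOPs per byte, with the uint4 operand accounting for about 94% of the traffic.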

However, even with this best configuration, ncu profiling shows a maximum DRAM utilization of only ~30%. This is not great, and my hunch is that allowing tile_k > group_size could help, since (I think) this problem size is limited by the number of mainloop (MAC loop) iterations. StreamK will not help in this case because of the overhead that scheduling strategy adds for a small problem size, so parallelizing across K is not an option.
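To make the iteration-count hunch concrete, here is a minimal sketch of how tile_k trades mainloop iterations against bytes loaded per iteration. The helper names are hypothetical; it assumes 4-bit quantized elements and a CTA tile covering 64 rows of the quantized operand.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: number of K-tile (mainloop) iterations per CTA.
inline std::int64_t k_iterations(std::int64_t K, std::int64_t tile_k) {
  assert(tile_k > 0);
  return (K + tile_k - 1) / tile_k;  // ceil(K / tile_k)
}

// Quantized-operand bytes loaded per iteration, at 4 bits per element.
inline std::int64_t quant_bytes_per_iteration(std::int64_t tile_rows, std::int64_t tile_k) {
  assert(tile_rows > 0 && tile_k > 0);
  return tile_rows * tile_k / 2;
}
```

For K = 8192, tile_k = 64 gives 128 iterations of 2 KiB each and tile_k = 256 gives 32 iterations of 8 KiB each; the per-CTA total stays 256 KiB either way, so a larger tile_k changes how the data is chunked, not how much is read.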

So being able to set tile_k > group_size would be great!

azhurkevich commented 6 days ago

@NihalPotdar Most likely you are relying on the default SwapAB, where M maps to the second mode of the tile shape (which is why you can use 16) and N maps to the first. M and N determine parallelization across CTAs. With your problem shape you are launching only 40 CTAs on a 132-SM GPU, hence the low utilization you are seeing. You are correct that StreamK adds some latency. However, StreamK is quite well positioned for your problem shape; I typically see some benefit, and it gets better for K > 8k. Also, please try SplitK, which has less overhead than StreamK. This should help you.
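The CTA-count argument can be checked with a little arithmetic. A minimal sketch, assuming a data-parallel schedule with one CTA per output tile, the SwapAB mapping described above (first tile mode over N, second over M), and that split-K multiplies the CTA count by the number of splits; the function names are hypothetical:

```cpp
#include <cassert>

inline int ceil_div(int a, int b) {
  assert(b > 0);
  return (a + b - 1) / b;
}

// Hypothetical CTA count for a data-parallel schedule with SwapAB:
// tile_mode0 tiles N, tile_mode1 tiles M.
inline int data_parallel_ctas(int M, int N, int tile_mode0, int tile_mode1) {
  return ceil_div(N, tile_mode0) * ceil_div(M, tile_mode1);
}

// With split-K, each output tile's K range is shared by `splits` CTAs.
inline int split_k_ctas(int M, int N, int tile_mode0, int tile_mode1, int splits) {
  assert(splits >= 1);
  return data_parallel_ctas(M, N, tile_mode0, tile_mode1) * splits;
}
```

For M = 16, N = 2560 with a 64x16 (x64) tile this gives 40 CTAs. Three splits give 120, close to one full wave on 132 SMs, while four splits give 160 and spill into a second wave. Split-K also adds a reduction of partial results, which this estimate ignores.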

IwakuraRein commented 5 days ago

@NihalPotdar Sorry, I don't quite understand why making tile_k > group_size would help in this case. When tile_k > group_size, the scale matrix becomes larger, so you're loading more data from device memory. Also, inside a tile you're no longer broadcasting one scale value across the whole K dimension, so you're loading more data from shared memory as well. As for the computation, the number of scaling multiplications is always TileM x TileK (since this kernel applies the scaling before the MMA), which is independent of the group size. IMHO the only reason a small group size is needed is to reduce the accuracy loss from quantization; it won't improve the latency of the kernel.
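The quantities in this argument can be sketched as follows. The helper names are hypothetical, the scale layout (one scale per row of an N x K quantized operand per group_size elements along K) is an assumption, and scales_per_tile assumes tile_k is a multiple of group_size whenever it exceeds it.

```cpp
#include <cassert>
#include <cstdint>

// Scale entries in DRAM: shrinking group_size grows this count.
inline std::int64_t scale_entries(std::int64_t N, std::int64_t K, std::int64_t group_size) {
  assert(group_size > 0 && K % group_size == 0);
  return N * (K / group_size);
}

// Scale values one tile reads from shared memory: one per row when the
// whole K-tile shares a group (broadcast), else one per row per group.
inline std::int64_t scales_per_tile(std::int64_t tile_m, std::int64_t tile_k,
                                    std::int64_t group_size) {
  assert(tile_k <= group_size || tile_k % group_size == 0);
  return tile_m * (tile_k > group_size ? tile_k / group_size : 1);
}

// Scaling multiplications per tile when A is scaled before the MMA:
// one per element of the A tile, so group_size does not appear.
inline std::int64_t scale_mults_per_tile(std::int64_t tile_m, std::int64_t tile_k) {
  return tile_m * tile_k;
}
```

For N = 2560 and K = 8192, dropping group_size from 128 to 32 grows the scale matrix from 163,840 to 655,360 entries, and a 64-row tile with tile_k = 256 and group_size = 128 reads 128 scales instead of 64, while scale_mults_per_tile(64, 64) stays 4096 regardless of group size.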