Open grypp opened 1 year ago
CUTLASS has 5 GEMM flavors for H100:

| GEMM flavor | Required features |
|---|---|
| `KernelMultistage` | 1 |
| `KernelTma` | 1, 3 |
| `KernelTmaWarpSpecialized` | 1, 3, 4, 5 |
| `KernelTmaWarpSpecializedCooperative` | 1, 3, 4, 5, 7 (persistent threads, different codegen) |
| `KernelTmaWarpSpecializedPingpong` | 1, 3, 4, 5, 7 (different codegen) |
Feature 2 is required at every step to avoid shared memory bank conflicts, so we need it for every GEMM flavor.
### IREE Codegen
`KernelTma` will likely give decent performance for a limited set of shapes, so it can be a good start for implementing codegen.
`KernelMultistage` won't give good performance because it leverages neither TMA nor CTA clusters. However, it can also be a good start for codegen since it requires only Feature 1.
To cover all shapes, we need `KernelTmaWarpSpecialized`, `KernelTmaWarpSpecializedCooperative`, and `KernelTmaWarpSpecializedPingpong`.
Created #14126 to track Feature 3: Supporting TMA
I would like to expand on this project by breaking it down into clear and actionable steps.
### Feature 1: Implementing `wgmma.mma_async`

To leverage the asynchronous nature of `wgmma.mma_async`, this warpgroup-level GEMM instruction must be grouped, committed, and waited upon. It can perform larger GEMM operations such as 64x128x16, which distinguishes it from `mma.sync`, a synchronous warp-level GEMM instruction. `wgmma.mma_async` can also read its operands directly from shared memory. This step marks the initial progress towards targeting the new generation of tensor cores. See here for more details about the instruction.
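As a rough sketch (not IREE's eventual codegen), the fence/commit/wait structure the compiler must emit around `wgmma.mma_async` looks like this in inline PTX. The small m64n8k16 shape is used so the accumulator fits in four registers per thread; `descA`/`descB` are the 64-bit shared-memory matrix descriptors discussed under Feature 2:

```cuda
#include <cstdint>

// Sketch only (requires sm_90a): issue one wgmma.mma_async and wait for it.
// scale-d is a predicate (here 1: accumulate into d), followed by the
// immediate scale-a, scale-b, trans-a, trans-b operands.
__device__ void wgmma_m64n8k16_f16(float &d0, float &d1, float &d2, float &d3,
                                   uint64_t descA, uint64_t descB) {
  asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
  asm volatile(
      "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %6, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 "
      "{%0, %1, %2, %3}, %4, %5, p, 1, 1, 0, 0;\n"
      "}\n"
      : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
      : "l"(descA), "l"(descB), "r"(1));
  // Batch the outstanding wgmma ops into a group, then wait until 0 groups
  // are pending; pipelined codegen would instead wait with N > 0 pending.
  asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
  asm volatile("wgmma.wait_group.sync.aligned 0;\n" ::: "memory");
}
```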
### Feature 2: Swizzling modes of `wgmma.mma_async`

`wgmma.mma_async` supports four swizzling modes, including the option of no swizzling. Supporting all of these swizzling modes is essential for performance. See here for the 3 swizzling modes.
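The swizzle mode is selected per operand through the shared-memory matrix descriptor that `wgmma.mma_async` consumes. A sketch of packing that descriptor, following the matrix-descriptor field layout in the PTX ISA (the helper and enum names below are ours, not CUTLASS's):

```cuda
#include <cstdint>

// Sketch: pack a wgmma shared-memory matrix descriptor. The four swizzle
// modes occupy bits 62-63; the address/offset fields are stored >> 4.
enum SwizzleMode : uint64_t {
  kNoSwizzle   = 0,  // no swizzling
  kSwizzle128B = 1,
  kSwizzle64B  = 2,
  kSwizzle32B  = 3,
};

__device__ uint64_t makeSmemDescriptor(uint32_t smemAddr,        // shared addr
                                       uint32_t leadingByteOff,  // LBO
                                       uint32_t strideByteOff,   // SBO
                                       SwizzleMode mode) {
  uint64_t desc = 0;
  desc |= (uint64_t)((smemAddr >> 4) & 0x3FFF);              // bits 0-13
  desc |= (uint64_t)((leadingByteOff >> 4) & 0x3FFF) << 16;  // bits 16-29
  desc |= (uint64_t)((strideByteOff >> 4) & 0x3FFF) << 32;   // bits 32-45
  desc |= (uint64_t)mode << 62;                              // bits 62-63
  return desc;
}
```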
### Feature 3: Supporting TMA

TMA (Tensor Memory Accelerator) is a memory engine that asynchronously copies a tile of a given layout from global memory to shared memory, which saves a significant number of registers. In this step, our focus will be on supporting TMA in the IREE+MLIR framework and establishing its connection with `wgmma.mma_async`. The IREE host needs to build the TMA descriptor and send it to the kernel; in the kernel, a single thread can submit the TMA request.
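To make the split concrete, here is a sketch under illustrative assumptions (64x64 f16 tile, row-major MxN matrix): the host builds a `CUtensorMap` with the driver API, and in the kernel a single thread issues the bulk tensor copy, with completion signaled on an mbarrier that every thread then waits on.

```cuda
#include <cuda.h>
#include <cstdint>

// Host-side sketch: the IREE runtime builds the TMA descriptor and passes
// it to the kernel by value. Shapes and the swizzle choice are illustrative.
CUtensorMap buildTmaDescriptor(CUdeviceptr globalPtr, uint64_t M, uint64_t N) {
  CUtensorMap map;
  uint64_t dims[2]        = {N, M};                 // fastest-varying first
  uint64_t strides[1]     = {N * sizeof(uint16_t)}; // byte stride of dim 1
  uint32_t box[2]         = {64, 64};               // tile copied per request
  uint32_t elemStrides[2] = {1, 1};
  cuTensorMapEncodeTiled(&map, CU_TENSOR_MAP_DATA_TYPE_FLOAT16, /*rank=*/2,
                         (void *)globalPtr, dims, strides, box, elemStrides,
                         CU_TENSOR_MAP_INTERLEAVE_NONE,
                         CU_TENSOR_MAP_SWIZZLE_128B,
                         CU_TENSOR_MAP_L2_PROMOTION_NONE,
                         CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  return map;
}

// Device-side sketch (sm_90a): thread 0 issues the copy; all threads wait.
__global__ void loadTile(const __grid_constant__ CUtensorMap map) {
  __shared__ alignas(128) uint16_t tile[64 * 64];
  __shared__ alignas(8) uint64_t mbar;
  uint32_t mbarAddr = (uint32_t)__cvta_generic_to_shared(&mbar);
  uint32_t tileAddr = (uint32_t)__cvta_generic_to_shared(tile);
  if (threadIdx.x == 0) {
    asm volatile("mbarrier.init.shared.b64 [%0], 1;" ::"r"(mbarAddr));
    asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
                 ::"r"(mbarAddr), "r"((uint32_t)sizeof(tile)));
    asm volatile(
        "cp.async.bulk.tensor.2d.shared::cluster.global"
        ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];"
        ::"r"(tileAddr), "l"(&map), "r"(0), "r"(0), "r"(mbarAddr));
  }
  // Every thread spins on the barrier's phase 0 until the bytes land.
  asm volatile(
      "{\n.reg .pred p;\n"
      "LAB_WAIT:\n"
      "mbarrier.try_wait.parity.shared.b64 p, [%0], 0;\n"
      "@!p bra LAB_WAIT;\n}\n" ::"r"(mbarAddr));
  // ... tile[] is now valid in shared memory ...
}
```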
Supporting arrive/wait barriers: TMA requires arrive/wait barriers (aka `mbarrier`). We don't have them in MLIR. See here for examples.

### Feature 4: CTA Clustering
CTA Clustering is a notable feature of the H100 architecture: it allows grouping Cooperative Thread Arrays (CTAs) into clusters whose CTAs can access each other's shared memory. This step is crucial to enable warp specialization. See here for examples.
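A minimal sketch of the feature (kernel body is illustrative, not IREE's generated code): a 2-CTA cluster where each CTA reads its neighbor's shared memory through the cooperative-groups cluster API.

```cuda
#include <cooperative_groups.h>

// Sketch (sm_90): 2-CTA cluster with distributed shared memory access.
__global__ void __cluster_dims__(2, 1, 1) clusterKernel(int *out) {
  namespace cg = cooperative_groups;
  cg::cluster_group cluster = cg::this_cluster();
  __shared__ int stage[128];
  stage[threadIdx.x] = threadIdx.x + cluster.block_rank();
  cluster.sync();  // all CTAs in the cluster have written their stage[]
  // Map our shared-memory pointer into the peer CTA's address space.
  unsigned peerRank = (cluster.block_rank() + 1) % cluster.num_blocks();
  int *peerStage = cluster.map_shared_rank(stage, peerRank);
  out[blockIdx.x * blockDim.x + threadIdx.x] = peerStage[threadIdx.x];
  cluster.sync();  // keep shared memory alive until peers finish reading
}
```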
### Feature 5: Warp specialization
Splitting warps into producers and consumers: the producer fetches data via TMA, and the consumer performs the GEMM. See here for examples.
There are multiple ways of doing codegen here. CUTLASS has three warp-specialized kernels, and they require different codegen.
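A sketch of the producer/consumer split. The barrier operations in the comments stand in for `mbarrier` arrive/wait sequences; `setmaxnreg` rebalances the register file between the two roles (the 40/232 values mirror common CUTLASS settings, not a fixed rule):

```cuda
#include <cstdint>

constexpr int kNumKTiles = 8;  // illustrative k-loop trip count

// Sketch (sm_90a): one warpgroup produces (TMA loads), the other consumes
// (wgmma). Buffer hand-off would go through full/empty mbarriers.
__global__ void warpSpecializedGemm() {
  const int warpgroup = threadIdx.x / 128;  // 128 threads = 1 warpgroup
  if (warpgroup == 0) {
    // Producer warpgroup: mostly issues TMA copies, needs few registers.
    asm volatile("setmaxnreg.dec.sync.aligned.u32 40;");
    for (int k = 0; k < kNumKTiles; ++k) {
      // wait until buffer k % depth is empty, then issue TMA into it;
      // the TMA completion itself arrives on that buffer's "full" barrier
    }
  } else {
    // Consumer warpgroup: runs wgmma, needs many registers for accumulators.
    asm volatile("setmaxnreg.inc.sync.aligned.u32 232;");
    for (int k = 0; k < kNumKTiles; ++k) {
      // wait until buffer k % depth is full, run wgmma from it,
      // then arrive on its "empty" barrier to release it to the producer
    }
  }
}
```

The Cooperative and Pingpong flavors differ in how consumer warpgroups share this loop (two warpgroups cooperating on one output tile vs. alternating tiles), which is why they need different codegen.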
### Feature 7: Optimized epilogue via TMA
By leveraging TMA, we can also enhance the performance of the epilogue: the output tile is staged in shared memory and written back to global memory by the TMA engine rather than by individual threads.
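As a sketch, the store direction mirrors the load: after the accumulator tile has been written to shared memory, a single thread streams it out with a bulk tensor store, so the epilogue needs no per-thread global addressing. `map` is the output tensor's TMA descriptor and `(x, y)` the tile's coordinates (names are ours):

```cuda
#include <cuda.h>
#include <cstdint>

// Sketch (sm_90a): TMA store of one shared-memory tile to global memory.
// smemTile is the tile's shared-memory address (from __cvta_generic_to_shared).
__device__ void epilogueStore(const CUtensorMap *map, uint32_t smemTile,
                              int x, int y) {
  // Make the generic-proxy writes to shared memory visible to the async proxy.
  asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
  if (threadIdx.x == 0) {
    asm volatile(
        "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group"
        " [%0, {%1, %2}], [%3];"
        ::"l"(map), "r"(x), "r"(y), "r"(smemTile));
    asm volatile("cp.async.bulk.commit_group;");
    asm volatile("cp.async.bulk.wait_group.read 0;");  // smem reusable after this
  }
}
```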