iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

Code generation features required for GEMM on Hopper GPUs #14106

Open grypp opened 1 year ago

grypp commented 1 year ago

I would like to expand on this project by breaking it down into clear and actionable steps.

Feature 1: Implementing wgmma.mma_async

wgmma.mma_async is a warpgroup-level (4-warp) GEMM instruction. To leverage its asynchronous nature, issued instructions must be grouped, committed, and waited upon. It supports larger GEMM shapes such as 64x128x16 and can read its operands directly from shared memory. This differs from mma.sync, which is a synchronous warp-level GEMM instruction. This step marks the initial progress towards targeting the new generation of tensor cores.

See here for more details about the instruction
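To make the grouping concrete, here is a minimal CUDA sketch with inline PTX (not IREE output), assuming the smallest m64n8k16 shape so that only four f32 accumulators are needed, and assuming the 64-bit shared-memory matrix descriptors `desc_a`/`desc_b` are built elsewhere:

```cpp
#include <cstdint>

// Sketch of the async wgmma sequence: fence, issue, commit the group, wait.
// desc_a / desc_b are shared-memory matrix descriptors built elsewhere.
__device__ void wgmma_m64n8k16_f16(float& d0, float& d1, float& d2, float& d3,
                                   uint64_t desc_a, uint64_t desc_b) {
  asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
  asm volatile(
      "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %6, 0;\n"  // scale-d = 1: accumulate into d0..d3
      "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 "
      "{%0, %1, %2, %3}, %4, %5, p, 1, 1, 0, 0;\n"
      "}\n"
      : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
      : "l"(desc_a), "l"(desc_b), "r"(1));
  asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
  asm volatile("wgmma.wait_group.sync.aligned 0;\n" ::: "memory");
}
```

The codegen-relevant part is the structure: several wgmma.mma_async instructions can be batched between the fence and the commit, and the wait_group argument controls how many outstanding groups may remain pending, which is what allows compute to overlap with copies.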

Feature 2: Swizzling modes of wgmma.mma_async

wgmma.mma_async supports four swizzling modes, including the option of no swizzling. Supporting all of these swizzling modes is essential for performance.

See here for the three swizzling modes.
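For reference, these modes correspond one-to-one to the `CUtensorMapSwizzle` values in the CUDA driver API that are used when encoding TMA descriptors (the wgmma shared-memory descriptor carries a matching layout field on the device side):

```cpp
#include <cuda.h>

// The four shared-memory swizzling modes as exposed by the CUDA driver API.
// The widest (128B) mode is generally the best choice for GEMM tiles, since
// it spreads accesses across all shared-memory banks.
static const CUtensorMapSwizzle kSwizzleModes[] = {
    CU_TENSOR_MAP_SWIZZLE_NONE,  // no swizzling
    CU_TENSOR_MAP_SWIZZLE_32B,
    CU_TENSOR_MAP_SWIZZLE_64B,
    CU_TENSOR_MAP_SWIZZLE_128B,
};
```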

Feature 3: Supporting TMA

TMA (Tensor Memory Accelerator) is a memory engine that asynchronously copies a tile of a given layout from global memory to shared memory. Because address generation is offloaded to the hardware, it saves a significant number of registers. In this step, our focus will be on supporting TMA in the IREE+MLIR framework and establishing its connection with wgmma.mma_async.

The IREE host side needs to build the TMA descriptor and pass it to the kernel. In the kernel, a single thread can then submit the TMA request.
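A hedged sketch of what the host side has to do, using the plain CUDA driver API rather than IREE's HAL (the f16 row-major layout and the `tile_m`/`tile_k` parameters are illustrative assumptions):

```cpp
#include <cuda.h>
#include <stdint.h>

// Encode a 2-D TMA descriptor (CUtensorMap) for a row-major f16 matrix.
// The resulting descriptor is passed to the kernel, e.g. as a
// __grid_constant__ parameter or through constant memory.
CUtensorMap make_tma_desc(void* global_ptr, uint64_t rows, uint64_t cols,
                          uint64_t row_stride_bytes, uint32_t tile_m,
                          uint32_t tile_k) {
  CUtensorMap desc;
  cuuint64_t global_dim[2] = {cols, rows};           // innermost dimension first
  cuuint64_t global_stride[1] = {row_stride_bytes};  // strides of outer dims only
  cuuint32_t box_dim[2] = {tile_k, tile_m};          // shared-memory tile shape
  cuuint32_t elem_stride[2] = {1, 1};
  cuTensorMapEncodeTiled(&desc, CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
                         /*tensorRank=*/2, global_ptr, global_dim, global_stride,
                         box_dim, elem_stride, CU_TENSOR_MAP_INTERLEAVE_NONE,
                         CU_TENSOR_MAP_SWIZZLE_128B,
                         CU_TENSOR_MAP_L2_PROMOTION_NONE,
                         CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  return desc;
}
```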

Supporting arrive/wait barriers: TMA requires split arrive/wait barriers (aka mbarriers), which we don't have in MLIR yet. See here for examples.
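Until such ops exist, here is roughly the pattern the lowered code needs, written as hedged CUDA inline-PTX helpers (the helper names are made up for illustration):

```cpp
#include <cstdint>

// Split arrive/wait barrier helpers around the PTX mbarrier.* instructions.
// `bar` must live in shared memory.
__device__ void mbarrier_init(uint64_t* bar, uint32_t arrive_count) {
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(bar));
  asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;\n"
               :
               : "r"(smem), "r"(arrive_count));
}

// One thread arrives and tells the barrier how many TMA bytes to expect
// before the current phase can complete.
__device__ void mbarrier_arrive_expect_tx(uint64_t* bar, uint32_t tx_bytes) {
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(bar));
  asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n"
               :
               : "r"(smem), "r"(tx_bytes));
}

// Spin until the barrier phase with the given parity has completed.
__device__ void mbarrier_wait(uint64_t* bar, uint32_t parity) {
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(bar));
  asm volatile(
      "{\n"
      ".reg .pred p;\n"
      "WAIT:\n"
      "mbarrier.try_wait.parity.shared::cta.b64 p, [%0], %1;\n"
      "@!p bra WAIT;\n"
      "}\n"
      :
      : "r"(smem), "r"(parity));
}
```

The TMA load names the barrier, so the hardware counts the transaction bytes as they arrive; waiting warps only poll the barrier, never the data itself.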

Feature 4: CTA Clustering

CTA clustering is a notable feature of the H100 architecture: Cooperative Thread Arrays (CTAs) can be grouped into clusters whose members can access each other's shared memory (distributed shared memory). This step is crucial to enable warp specialization. See here for examples.
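In plain CUDA the capability looks roughly like this toy example (illustrative, not IREE output); for codegen it means emitting cluster launch dimensions plus the cluster barrier and distributed-shared-memory addressing:

```cpp
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Toy kernel launched as a cluster of 2 CTAs; each CTA reads its neighbour's
// shared memory ("distributed shared memory"). Names are illustrative.
__global__ void __cluster_dims__(2, 1, 1) cluster_example(int* out) {
  __shared__ int smem[1];
  cg::cluster_group cluster = cg::this_cluster();

  smem[0] = cluster.block_rank();
  cluster.sync();  // every CTA in the cluster has published its value

  // Map the peer CTA's shared memory into this CTA's address space and read it.
  unsigned peer = (cluster.block_rank() + 1) % cluster.num_blocks();
  int* peer_smem = cluster.map_shared_rank(smem, peer);
  out[blockIdx.x] = peer_smem[0];

  cluster.sync();  // keep shared memory alive until all peers have read it
}
```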

Feature 5: Warp specialization

Split the warps into producers and consumers: the producer warps fetch data via TMA, and the consumer warps perform the GEMM. See here for examples.

There are multiple ways of doing codegen here. CUTLASS has three warp-specialized kernel schedules, and they require different codegen.
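A very rough skeleton of the producer/consumer split (names and structure are illustrative; the mbarrier pipeline, multibuffering, and register reallocation via setmaxnreg are elided):

```cpp
// Warp-specialized mainloop skeleton: one producer warpgroup feeds the
// shared-memory pipeline via TMA, the other warpgroups consume it with wgmma.
__global__ void warp_specialized_gemm(/* TMA descriptors, problem sizes */) {
  const int warpgroup = threadIdx.x / 128;  // 128 threads = one warpgroup

  if (warpgroup == 0) {
    // Producer: for each k-tile, issue a TMA load into the next pipeline
    // stage and signal the consumers (mbarrier arrive/expect-tx, Feature 3).
  } else {
    // Consumers: for each k-tile, wait on the stage's mbarrier, run
    // wgmma.mma_async on the staged tiles (Feature 1), then release the stage.
  }
}
```

The Cooperative and Pingpong schedules differ mainly in how the consumer warpgroups divide output tiles between them, which is why they need separate codegen paths.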

Feature 7: Optimized epilogue via TMA

By leveraging TMA for the stores back to global memory, we can also optimize the epilogue phase of the computation. This step focuses on enhancing epilogue performance through effective use of TMA.
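Concretely, instead of every thread writing its accumulator fragment to global memory, the output tile is first staged in shared memory and a single thread issues a bulk tensor store; a hedged inline-PTX sketch of that store path (descriptor handling and coordinates follow the 2-D `cp.async.bulk.tensor` form):

```cpp
#include <cstdint>
#include <cuda.h>

// Store one shared-memory tile back to global memory through TMA.
// `desc` is the output tensor's CUtensorMap; x/y are tile coordinates
// (innermost coordinate first).
__device__ void tma_store_2d(const CUtensorMap* desc, void* smem_tile,
                             int32_t x, int32_t y) {
  uint64_t desc_addr = reinterpret_cast<uint64_t>(desc);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_tile));
  // Make the preceding shared-memory writes visible to the async proxy.
  asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory");
  asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group "
               "[%0, {%2, %3}], [%1];\n"
               :
               : "l"(desc_addr), "r"(smem), "r"(x), "r"(y)
               : "memory");
  asm volatile("cp.async.bulk.commit_group;\n" ::: "memory");
  asm volatile("cp.async.bulk.wait_group.read 0;\n" ::: "memory");
}
```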

grypp commented 1 year ago

CUTLASS has 5 GEMM flavors for H100:

| Kernel | Required features |
| --- | --- |
| KernelMultistage | 1 |
| KernelTma | 1, 3 |
| KernelTmaWarpSpecialized | 1, 3, 4, 5 |
| KernelTmaWarpSpecializedCooperative | 1, 3, 4, 5, 7 (persistent threads, different codegen) |
| KernelTmaWarpSpecializedPingpong | 1, 3, 4, 5, 7 (different codegen) |

Feature 2 is required throughout to avoid shared-memory bank conflicts; we need it for every GEMM flavor.

IREE Codegen

KernelTma will perhaps give decent performance for a limited set of shapes. It can be a good starting point for implementing codegen.

KernelMultistage won't give good performance because it doesn't leverage TMA or CTA clusters. However, it can also be a good starting point for codegen since it requires only Feature 1.

To cover all shapes, we need KernelTmaWarpSpecialized, KernelTmaWarpSpecializedCooperative and KernelTmaWarpSpecializedPingpong.

grypp commented 1 year ago

Created #14126 to track Feature 3: Supporting TMA