NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Add support to store the outputs of the Mma operator using stmatrix #3395

Open protonu opened 1 week ago

protonu commented 1 week ago

This PR adds support for scheduling and generating indices for the stmatrix instruction following an Mma operation (with a cast operator in the middle).

image

We support TT and TN layouts for Hopper RS. This should work with NT as well for Hopper SS. We support using stmatrix .x2 and .x4 (16x8 and 16x16 tiles). TMA with different swizzle factors was tested.

We don't support using stmatrix .x1 (8x8).

Scheduling the output of stmatrix

Scheduling relies heavily on the function mma_utils::MmaSwizzler::scheduleMmaOutputAllocation as shown in the figures below.

image

When lowering to stmatrix.x4, the scheduling is slightly modified to:

image

Index generation for the output of stmatrix

Index generation (for the output of stmatrix) computes the offset into the output shared memory. I'll explain this with an example of storing a 64(M) x 16(N) piece of memory using stmatrix.x2 (16x8 tiles). Note that index generation assumes we have 128 threads (a warp group) and that a for-loop (if required) invokes the stmatrix.

image

In the first iteration of the for-loop, the 128 threads store warp box 0; in the next iteration, they store warp box 1. Each warp box has 4 warps, each issuing a stmatrix. Each individual stmatrix call stores a tile box of memory. Thus there are 4 tile boxes in each warp box.

Thus the index/offset location in shared memory is computed as: offset of the warp box + offset of the tile box in the warp box + offset of the thread in the tile box.
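
To make the three-term sum concrete, here is a minimal sketch for the 64(M) x 16(N), stmatrix.x2 example. It is not the index code from this PR. It assumes an unswizzled, row-major shared-memory buffer with a row stride of 16 elements, warp boxes laid out along N (one 64x8 box per loop iteration), tile boxes stacked along M (one 16x8 tile per warp), and lanes 0-15 of each warp supplying the row addresses, with lanes 16-31 mirroring them since their addresses are not used by .x2. All names (`stmatrixX2Offset`, the `k...` constants) are hypothetical.

```cpp
#include <cassert>

// Hypothetical constants for the 64(M) x 16(N) example with stmatrix.x2.
constexpr int kWarpSize = 32;
constexpr int kTileM = 16;      // rows stored by one stmatrix.x2 call
constexpr int kTileN = 8;       // columns stored by one stmatrix.x2 call
constexpr int kRowStride = 16;  // elements per row (assumes unswizzled row-major)

// Element offset of the row address that thread `tid` (0..127) passes to
// stmatrix in iteration `loop_iter` of the enclosing for-loop.
constexpr int stmatrixX2Offset(int loop_iter, int tid) {
  const int warp = tid / kWarpSize;  // which tile box inside the warp box
  const int lane = tid % kWarpSize;

  // Offset of the warp box: assumed to advance along N each iteration.
  const int warp_box_offset = loop_iter * kTileN;
  // Offset of the tile box in the warp box: assumed stacked along M.
  const int tile_box_offset = warp * kTileM * kRowStride;
  // Offset of the thread in the tile box: lane l addresses row (l % 16).
  const int thread_offset = (lane % kTileM) * kRowStride;

  return warp_box_offset + tile_box_offset + thread_offset;
}

// Thread 0 starts warp box 0 at offset 0 and warp box 1 eight columns later.
static_assert(stmatrixX2Offset(0, 0) == 0, "warp box 0 origin");
static_assert(stmatrixX2Offset(1, 0) == 8, "warp box 1 origin");
// Warp 1 (threads 32..63) starts 16 rows below warp 0.
static_assert(stmatrixX2Offset(0, 32) == 256, "tile box 1 origin");
```

With a swizzled layout (as used with TMA), the row and column terms would be remapped before forming the offset; the sketch leaves that step out.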

Follow-up PRs:

  1. See if we can support stmatrix (8x8) without too many changes.
  2. Reduce bank conflicts.

protonu commented 2 days ago

!build