This PR adds support for scheduling and generating indices for the stmatrix instruction when it follows an Mma operation (with a cast operator in between).
We support the TT and TN layouts for Hopper RS. This should also work with NT for Hopper SS.
We support stmatrix .x2 and .x4 (16x8 and 16x16 tiles, respectively).
We tested TMA with different swizzle factors.
We don't support stmatrix .x1 (8x8).
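For reference, the PTX ISA defines stmatrix.sync.aligned.m8n8 with .x1, .x2 and .x4, which store one, two or four 8x8 matrices of 16-bit elements. Lanes 0-7, 0-15 or 0-31 each supply the shared-memory address of one 8-element row. The sketch below tabulates this. The 16x8 and 16x16 tile shapes are an assumption: they stack the 8x8 matrices along M first and then along N, as in the Mma accumulator fragment. This is our reading, not something the PTX spec mandates.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StMatrixVariant:
    num: int      # number of 8x8 matrices stored by one instruction
    tile_m: int   # rows of the tile stored by one warp (assumed layout)
    tile_n: int   # columns of the tile stored by one warp (assumed layout)

    @property
    def address_lanes(self) -> range:
        # Each 8x8 matrix has 8 rows; one lane supplies the address of one row.
        return range(8 * self.num)


# Assumed tile shapes: 8x8 matrices stacked along M first, then along N.
VARIANTS = {
    "x1": StMatrixVariant(1, 8, 8),    # not supported by this PR
    "x2": StMatrixVariant(2, 16, 8),
    "x4": StMatrixVariant(4, 16, 16),
}

for name, v in VARIANTS.items():
    lanes = v.address_lanes
    print(f"stmatrix.{name}: {v.tile_m}x{v.tile_n} tile, "
          f"lanes {lanes.start}-{lanes.stop - 1} supply row addresses")
```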
Scheduling the output of stmatrix
Scheduling relies heavily on the function mma_utils::MmaSwizzler::scheduleMmaOutputAllocation, as shown in the figures below.
When lowering to stmatrix.x4, this schedule is slightly modified to:
Index generation for the output of stmatrix
Index generation (for the output of stmatrix) computes the offset into the output shared memory. I'll explain this with an example of storing a 64 (M) x 16 (N) piece of memory using stmatrix.x2 (16x8 tiles).
Note that index generation assumes 128 threads (one warp group), plus a for-loop, if required, that invokes stmatrix repeatedly.
In the first iteration of the for-loop, the 128 threads store warp box 0; in the next iteration, they store warp box 1. Each warp box has 4 warps, each issuing one stmatrix, and each stmatrix call stores one tile box of memory. Thus there are 4 tile boxes in each warp box.
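The loop structure described above can be sketched for the 64 (M) x 16 (N) example with stmatrix.x2. The sketch assumes that warp w of the warp group owns rows 16*w through 16*w + 15, as in Hopper warp-group Mma output. Under that assumption, a warp box spans all 64 rows and 8 columns, and successive warp boxes advance along N. The helper name is hypothetical. The code only enumerates which region each (iteration, warp) pair covers and checks that the whole output is stored exactly once.

```python
TILE_M, TILE_N = 16, 8        # tile box stored by one stmatrix.x2 call
WARPS_PER_GROUP = 4           # 128 threads = 4 warps
M, N = 64, 16                 # shape of the output stored in shared memory


def tile_box_origin(iteration: int, warp: int) -> tuple[int, int]:
    """Hypothetical helper: (row, col) of the tile box stored by `warp` in `iteration`.

    Assumes warp boxes advance along N and tile boxes stack along M.
    """
    return warp * TILE_M, iteration * TILE_N


num_iterations = N // TILE_N              # one warp box per iteration: 2
covered = set()
for it in range(num_iterations):
    for warp in range(WARPS_PER_GROUP):   # one tile box per warp
        r0, c0 = tile_box_origin(it, warp)
        for r in range(r0, r0 + TILE_M):
            for c in range(c0, c0 + TILE_N):
                covered.add((r, c))

print(num_iterations, len(covered))       # prints: 2 1024
```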
Thus the index (offset) into shared memory is computed as:
offset of the warp box + offset of the tile box within the warp box + offset of the thread within the tile box.
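Below is a minimal sketch of this three-term offset for the same example. It assumes an unswizzled, row-major 64x16 shared-memory buffer (row stride of 16 elements). It also assumes the PTX stmatrix.x2 convention that lanes 0-15 each supply the address of one 8-element row of the 16x8 tile box. Offsets are in elements. The function name is hypothetical, and the real implementation also has to account for TMA swizzling, which this sketch ignores.

```python
ROW_STRIDE = 16          # N extent of the row-major shared-memory buffer (elements)
TILE_M, TILE_N = 16, 8   # tile box stored by one stmatrix.x2 call


def stmatrix_x2_offset(iteration: int, warp: int, lane: int) -> int:
    """Hypothetical helper: element offset of the row whose address `lane` supplies."""
    warp_box = iteration * TILE_N                # warp boxes advance along N
    tile_box = warp * TILE_M * ROW_STRIDE        # tile boxes stack along M
    # Lanes 0-15 address rows 0-15 of the tile box. Lanes 16-31 are ignored
    # by .x2, so wrapping them onto rows 0-15 is harmless.
    thread = (lane % TILE_M) * ROW_STRIDE
    return warp_box + tile_box + thread


# Thread 37 of the warp group is lane 5 of warp 1.
tid = 37
print(stmatrix_x2_offset(iteration=1, warp=tid // 32, lane=tid % 32))
# prints 344, i.e. row 21, column 8 of the buffer
```

Splitting the offset into three terms keeps each term dependent on a single loop or thread index. This makes it straightforward to substitute a different warp-box stride for other tile sizes.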
Follow-up PRs:
See if we can support stmatrix .x1 (8x8) without too many changes.