ROCm / triton

Development repository for the Triton language and compiler
MIT License

[MFMA] Support 64x4 and 4x64 tile size #469

Closed. binarman closed this 6 months ago.

binarman commented 6 months ago

This PR enables two new MxN tile sizes: 64x4 and 4x64. Both of them use mfma 4x4 instructions.
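
Why a 4x4 instruction can cover a 64x4 or 4x64 tile comes down to simple arithmetic. The sketch below assumes the common 4x4 MFMA variants (for example the f16 4x4x4 instruction), which on CDNA GPUs compute 16 independent 4x4 output blocks across a 64-lane wavefront. The helper names are hypothetical and this is not code from the PR.

```python
# Hypothetical arithmetic check (not code from the PR): how many 4x4 output
# blocks a result tile contains, compared with the 16 independent 4x4 blocks
# that one 4x4 MFMA instruction computes across a 64-lane wavefront.
# Assumption: f16/bf16/f32/i8 4x4 variants on a CDNA GPU.
WAVE_SIZE = 64          # lanes per wavefront on CDNA GPUs
BLOCK = 4               # each MFMA block produces a 4x4 output
BLOCKS_PER_MFMA4 = 16   # independent 4x4 blocks per 4x4 MFMA instruction


def count_4x4_blocks(m: int, n: int) -> int:
    """Number of 4x4 output blocks needed to cover an m x n result tile."""
    if m % BLOCK or n % BLOCK:
        raise ValueError("tile dimensions must be multiples of 4")
    return (m // BLOCK) * (n // BLOCK)


def acc_values_per_lane(m: int, n: int) -> int:
    """Accumulator elements each lane holds if one wavefront owns the tile."""
    return (m * n) // WAVE_SIZE


for m, n in [(64, 4), (4, 64)]:
    blocks = count_4x4_blocks(m, n)
    print(f"{m}x{n}: {blocks} blocks, fits one instruction: "
          f"{blocks == BLOCKS_PER_MFMA4}, {acc_values_per_lane(m, n)} values/lane")
```

Both shapes contain exactly 16 blocks of 4x4, i.e. one 4x4 instruction's worth of output, with 4 accumulator values per lane; the two tile sizes differ only in whether those blocks are stacked along M or along N.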

zhanglx13 commented 6 months ago

@scxiao @vgokhale The heuristic for picking the mfma instruction size is as follows:

  1. If the result tile shape is larger than 32x32, pick mfma32
  2. If the tile shape is smaller than 32x32 but larger than 16x16, pick mfma16
  3. If the tile shape is smaller than 4x64 or 64x4, pick mfma4x4
  4. Otherwise, pick mfma4x64 or mfma64x4
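
The four rules can be sketched as a small function. This is a Python model, not the actual C++ chooseMfmaDimensions code; it assumes "larger than AxB" means both dimensions are >= A and >= B, and that mfma4x64 means mDim=4, nDim=64. The function name is hypothetical.

```python
from typing import Tuple


def choose_mfma_dims(m: int, n: int) -> Tuple[int, int]:
    """Hypothetical model of the tile-shape heuristic (rules 1-4).

    m, n: shape of the tt.dot result tile.
    Returns (mDim, nDim) of the chosen MFMA instruction.
    Assumption: "larger than AxB" means m >= A and n >= B.
    """
    if m >= 32 and n >= 32:
        return (32, 32)      # rule 1: mfma32
    if m >= 16 and n >= 16:
        return (16, 16)      # rule 2: mfma16
    if m >= 4 and n >= 64:
        return (4, 64)       # rule 4: wide tile -> mfma4x64
    if m >= 64 and n >= 4:
        return (64, 4)       # rule 4: tall tile -> mfma64x4
    return (4, 4)            # rule 3: smaller than 4x64 / 64x4 -> mfma4x4


for shape in [(128, 128), (16, 128), (4, 128), (128, 4), (8, 8)]:
    print(shape, choose_mfma_dims(*shape))
```

The code checks rule 4 before falling through to rule 3 so that rule 3 becomes the default case. Note that a 16x128 tile satisfies rule 2 and therefore gets mfma16, which is exactly the FA decode situation raised next.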

However, in the FA decode kernel the tile shape is 16x128, so the heuristic picks mfma16. Here the tile shape refers to the shape of the result tensor of tt.dot. The heuristic does not take num_warps into consideration, but we do not have warp layout information when choosing mfma dimensions. Therefore, the only solution here is to enable some user input that forces the choice of mfma4x64.

@alefimov-amd In the next PR, can you change chooseMfmaDimensions to pick 4x64 or 64x4 based on the tile shape when matrix_instr_nonkdim is 4?
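
One possible shape for the requested change is sketched below: when the user sets matrix_instr_nonkdim to 4, the instruction's orientation is derived from the result tile. The function name and the tie-breaking policy (a square tile of at least 64 picks 4x64) are hypothetical assumptions, not what the follow-up PR necessarily implements.

```python
from typing import Tuple


def choose_nonkdim4_dims(m: int, n: int) -> Tuple[int, int]:
    """Hypothetical matrix_instr_nonkdim == 4 branch of chooseMfmaDimensions.

    m, n: shape of the tt.dot result tile.
    Returns (mDim, nDim) for one of mfma4x4, mfma4x64 or mfma64x4.
    """
    if min(m, n) < 4:
        raise ValueError("both tile dimensions must be at least 4")
    if max(m, n) < 64:
        return (4, 4)        # no side long enough for a 64-wide instruction
    if n >= m:
        return (4, 64)       # wide tile, e.g. the 16x128 FA decode tile
    return (64, 4)           # tall tile


print(choose_nonkdim4_dims(16, 128))  # (4, 64)
```

With this branch, the 16x128 FA decode tile gets mfma4x64 when the user asks for nonkdim 4, while the default heuristic is left untouched for every other kernel.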

scxiao commented 6 months ago

> chooseMfmaDimensions

Can you specify a unique value of chooseMfmaDimensions to choose mfma4x64 and mfma64x4, like 464 and 644?
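
A minimal sketch of the suggested encoding, assuming the value is passed through matrix_instr_nonkdim (the user-facing parameter mentioned earlier in the thread) and translated with an explicit table. The table and function name are hypothetical; only 464 and 644 come from the suggestion.

```python
from typing import Tuple

# Hypothetical mapping from matrix_instr_nonkdim to MFMA (mDim, nDim).
# 464 and 644 are the encodings suggested in the thread.
NONKDIM_TO_DIMS = {
    4: (4, 4),
    16: (16, 16),
    32: (32, 32),
    464: (4, 64),    # mfma4x64
    644: (64, 4),    # mfma64x4
}


def decode_nonkdim(value: int) -> Tuple[int, int]:
    """Translate a user-supplied matrix_instr_nonkdim into (mDim, nDim)."""
    try:
        return NONKDIM_TO_DIMS[value]
    except KeyError:
        raise ValueError(f"unsupported matrix_instr_nonkdim: {value}") from None


print(decode_nonkdim(464), decode_nonkdim(644))
```

An explicit lookup table is used instead of splitting the digits because a string like "464" is ambiguous (4 and 64, or 46 and 4); the table also rejects unsupported values up front.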

binarman commented 6 months ago

> In the next PR, can you change chooseMfmaDimensions to pick 4x64 or 64x4 based on the tile shape when matrix_instr_nonkdim is 4?

Sure.

> Can you specify a unique value of chooseMfmaDimensions to choose mfma4x64 and mfma64x4, like 464 and 644?

This is a useful idea, thank you!