ROCm / triton

Development repository for the Triton language and compiler

[MFMA][FRONTEND] Add more options for forced mfma layout sizes #487

Closed · binarman closed this 6 months ago

binarman commented 8 months ago

This PR adds a per-operation (per-dot) option for forcing the MFMA instruction size, in addition to the existing kernel-wide option.

The per-operation option is used like this:

@triton.jit
def kernel(...):
    ...
    c = tl.dot(a, b, matrix_instr_nonkdim=[4, 64])
    ...

The MFMA size heuristic now looks like this:

  1. If the dot-specific option is set, pick it
  2. If the kernel-specific option is set, pick it
  3. If the result tile shape is larger than 32x32, pick mfma32
  4. If the tile shape is smaller than 32x32 but larger than 16x16, pick mfma16
  5. If the tile shape is smaller than 4x64 or 64x4, pick mfma4x4
  6. Otherwise, pick mfma4x64 or mfma64x4, depending on which one fits the tile
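A minimal Python sketch of this selection logic (illustrative only: the function name is hypothetical, boundary tie-breaking may differ from the real pass, and sizes are (M, N) tuples):

def pick_mfma_size(tile_m, tile_n, dot_option=None, kernel_option=None):
    # 1./2. Explicit options win: the per-dot option first, then the kernel-wide one.
    if dot_option is not None:
        return dot_option
    if kernel_option is not None:
        return kernel_option
    # 3./4. Large and medium tiles use the square MFMA sizes.
    if tile_m >= 32 and tile_n >= 32:
        return (32, 32)
    if tile_m >= 16 and tile_n >= 16:
        return (16, 16)
    # 6. Long, narrow tiles fit mfma64x4 / mfma4x64.
    if tile_m >= 64:
        return (64, 4)
    if tile_n >= 64:
        return (4, 64)
    # 5. Anything smaller than both 4x64 and 64x4 falls back to mfma4x4.
    return (4, 4)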
alefimov-amd commented 8 months ago

@zhanglx13 @scxiao If you like this per-dot option idea, I would also like to add istransposed, so we can easily experiment with various layouts in fused kernels.

P.S. I don't think upstream will accept changes to the frontend, but the backend code will not break even without the frontend part, so this change is safe.
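For illustration, the proposed flag might be used like this (hypothetical spelling, not part of this PR):

@triton.jit
def kernel(...):
    ...
    c = tl.dot(a, b, matrix_instr_nonkdim=[4, 64], istransposed=True)
    ...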

zhanglx13 commented 8 months ago

@alefimov-amd I like this idea. I can try to poke Phil to see if he is OK with making some changes in the frontend. Two questions:

  1. Can you give a concrete example that is not achievable by the old kernel option or the heuristics?
  2. Can you also add tests for this new option?

binarman commented 8 months ago

> Can you give a concrete example that is not achievable by the old kernel option or the heuristics?

Technically, everything is achievable with the heuristic, but you need to recompile the compiler to try something new. This option could be useful for experiments.

Let's consider FA with the following matrices: q has shape [4x128], k has shape [128x64], v has shape [64x128].

The current heuristic (which looks only at tensor sizes) will assign the following layouts:

qk[4x64]<mfma 4x64, transposed> = tl.dot(q[4x128], k[128x64])
qkv[4x128]<mfma 4x64, transposed> = tl.dot(qk[4x64], v[64x128])

@scxiao proposed a different combination:

qk[4x64]<mfma 4x64, transposed> = tl.dot(q[4x128], k[128x64])
qkv[4x128]<mfma 4x4, transposed> = tl.dot(qk[4x64], v[64x128])  <- layout here is different

This could be beneficial (we need to verify this), because the <mfma 4x64> layout is similar to the <mfma 4x4> operand layout, so we can skip the LDS layout conversion in some cases.
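For example, the per-dot option from this PR could force @scxiao's combination directly in the kernel (a sketch, assuming the option also accepts [4, 4]):

@triton.jit
def fa_kernel(...):
    ...
    qk = tl.dot(q, k, matrix_instr_nonkdim=[4, 64])   # qk[4x64] -> mfma 4x64
    qkv = tl.dot(qk, v, matrix_instr_nonkdim=[4, 4])  # qkv[4x128] -> mfma 4x4
    ...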

> Can you also add tests for this new option?

Will do

zhanglx13 commented 8 months ago

@alefimov-amd @scxiao I'm not sure if it generalizes to other cases, but I have a proposal for this particular example.

I can see that the problem with using mfma4x64 for the second dot is layout incompatibility, so LDS is required for the layout conversion. I think there is a way to use the mfma4x64 result directly as the operand of the second dot, which also uses mfma4x64. We need to do two things:

  • Do not swap operands for the second dot. Treat the second dot as a regular gemm
  • Use CBSZ and ABID of the mfma instruction to control the broadcast behavior of operand a of the second gemm

The result matrix of the first dot is 4x64, which has 16 blocks of 4x4 submatrices. Threads 0-3 hold elements of the first block, threads 4-7 hold elements of the second block, and so on. To use it as operand a of mfma4x64, we need all 64 threads to hold elements of the first block and issue the first mfma4x64, then all 64 threads to hold elements of the second block and issue the second mfma4x64, and so on. This can be achieved with the CBSZ and ABID flags of mfma, which choose which block (of the 16) to broadcast to the rest. This means we still need to issue 16 mfma4x64 instructions to compute (4x64) x (64x64) --> (4x64), the same as with mfma4x4, but the 16 instructions will have different CBSZ and ABID values. This is beneficial since one mfma4x64 is much faster than one mfma4x4.
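The arithmetic of this scheme can be emulated in NumPy (a sketch that models only the block decomposition, not register or lane ownership; it assumes each mfma4x64 issue consumes one 4-wide K-chunk):

import numpy as np

M, K, N = 4, 64, 64
A = np.random.rand(M, K).astype(np.float32)  # 4x64 result of the first dot
V = np.random.rand(K, N).astype(np.float32)

C = np.zeros((M, N), dtype=np.float32)
for blk in range(16):
    # CBSZ/ABID = blk conceptually broadcasts A-block blk (the 4x4 submatrix
    # held by thread group blk) to all 64 lanes; one emulated mfma4x64 issue
    # then accumulates that K-chunk into the 4x64 accumulator.
    a_block = A[:, 4 * blk : 4 * (blk + 1)]
    C += a_block @ V[4 * blk : 4 * (blk + 1), :]

assert np.allclose(C, A @ V, atol=1e-4)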

I think we are already doing hacky things for some interesting kernels; it should be fine if we make it even hackier.

scxiao commented 8 months ago

> @alefimov-amd @scxiao I'm not sure if it generalizes to other cases, but I have a proposal for this particular example.
>
> I can see that the problem with using mfma4x64 for the second dot is layout incompatibility, so LDS is required for the layout conversion. I think there is a way to use the mfma4x64 result directly as the operand of the second dot, which also uses mfma4x64. We need to do two things:
>
> • Do not swap operands for the second dot. Treat the second dot as a regular gemm
> • Use CBSZ and ABID of the mfma instruction to control the broadcast behavior of operand a of the second gemm
>
> The result matrix of the first dot is 4x64, which has 16 blocks of 4x4 submatrices. Threads 0-3 hold elements of the first block, threads 4-7 hold elements of the second block, and so on. To use it as operand a of mfma4x64, we need all 64 threads to hold elements of the first block and issue the first mfma4x64, then all 64 threads to hold elements of the second block and issue the second mfma4x64, and so on. This can be achieved with the CBSZ and ABID flags of mfma, which choose which block (of the 16) to broadcast to the rest. This means we still need to issue 16 mfma4x64 instructions to compute (4x64) x (64x64) --> (4x64), the same as with mfma4x4, but the 16 instructions will have different CBSZ and ABID values. This is beneficial since one mfma4x64 is much faster than one mfma4x4.
>
> I think we are already doing hacky things for some interesting kernels; it should be fine if we make it even more hacky.

Yes, this is a good idea. Maybe we can finish the current implementation first, then work on this as a separate PR?

binarman commented 8 months ago

I tried to use this dot-specific option today and found that in FA this attribute does not survive to the accelerateMatmul pass.

It is overridden by the combine pass, which merges dot and add/mul operations. This needs more investigation.

zhanglx13 commented 7 months ago

Let's take your example

> Let's consider FA with the following matrices: q has shape [4x128], k has shape [128x64], v has shape [64x128].

Ok. Now we have three candidates here:

  1. The old one in the triton-mlir branch, which picks mfma4x64 for both dots
  2. The new one implemented in this PR, which allows picking mfma4x64 for the first dot and mfma4x4 for the second one
  3. The one proposed in the comment above

Can you verify the following?

  1. The old one has layout conversion using LDS between the two dots
  2. The new one does not have layout conversion using LDS between the two dots

And what do you think about the proposed one?

binarman commented 7 months ago

> The old one has layout conversion using LDS between the two dots

Yes, it has.

> The new one does not have layout conversion using LDS between the two dots

This is what I am working on at the moment. I verified that it works with a simple dot->dot pattern, but there are still some problems with FA. I am experimenting with it in the draft https://github.com/ROCm/triton/pull/504

I have not tried the new proposal yet; I consider it a possible next step once the new one works as intended.

binarman commented 6 months ago

I've moved the changes from this PR into other PRs, so I think it is not needed anymore. Closing.