Closed: binarman closed this 6 months ago
@zhanglx13 @scxiao
If you like this per-dot options idea, I would also like to add `istransposed`, so we can easily experiment with various layouts in fused kernels.
P.S. I don't think upstream will accept changes to the frontend, but the backend code will not break even without the frontend part, so this change is safe.
@alefimov-amd I like this idea. I can try to poke Phil to see if he is ok to make some changes in the frontend. Just one question:
Can you also add tests for this new option?
Can you give a concrete example that is not achievable by the old kernel option or the heuristics?
Technically, everything is achievable with the heuristic, but you need to recompile the compiler to try something new. This option could be useful for experiments.
Let's consider FA with the following matrices: q has shape [4x128], k has shape [128x64], v has shape [64x128].
The current heuristic (which looks only at tensor sizes) will assign the following layouts:
```
qk[4x64]<mfma 4x64, transposed> = tl.dot(q[4x128], k[128x64])
qkv[4x128]<mfma 4x64, transposed> = tl.dot(qk[4x64], v[64x128])
```
@scxiao proposed different combination:
```
qk[4x64]<mfma 4x64, transposed> = tl.dot(q[4x128], k[128x64])
qkv[4x128]<mfma 4x4, transposed> = tl.dot(qk[4x64], v[64x128]) <- layout here is different
```
This could be beneficial (we need to verify this), because the `<mfma 4x64>` layout is similar to the `<mfma 4x4>` operand layout, so we can skip the LDS layout conversion in some cases.
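For reference, the shapes in this FA dot chain can be checked quickly (plain numpy, shape arithmetic only; this says nothing about the layout semantics discussed above):

```python
import numpy as np

# Shapes from the FA example: q [4x128], k [128x64], v [64x128]
q = np.zeros((4, 128))
k = np.zeros((128, 64))
v = np.zeros((64, 128))

qk = q @ k             # first dot
assert qk.shape == (4, 64)

qkv = qk @ v           # second dot
assert qkv.shape == (4, 128)
```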
> Can you also add tests for this new option?
Will do
@alefimov-amd @scxiao I'm not sure if this generalizes to other cases, but I have a proposal for this particular example.
The problem with using mfma4x64 for the second dot is layout incompatibility, so LDS is required for the layout conversion. I think there is a way to use the mfma4x64 result directly as an operand of the second dot, which also uses mfma4x64. We need to do two things:
- Do not swap operands for the second dot; treat the second dot as a regular gemm
- Use the CBSZ and ABID flags of the mfma instruction to control the broadcast behavior of operand a of the second gemm
The result matrix of the first dot is 4x64, which has 16 blocks of 4x4 submatrices. Threads 0-3 hold elements of the first block, threads 4-7 hold elements of the second block, and so on. To use it as operand a of mfma4x64, we need all 64 threads to hold elements of the first block and issue the first mfma4x64; then all 64 threads hold elements of the second block and we issue the second mfma4x64, and so on. This can be achieved with the CBSZ and ABID flags of mfma, which choose which block (of the 16) to broadcast to the rest. This means we still need to issue 16 mfma4x64 instructions to compute (4x64) x (64x64) --> (4x64), the same count as with mfma4x4, but the 16 mfma4x64 instructions will have different values for the CBSZ and ABID flags. This is beneficial since one mfma4x64 is much faster than mfma4x4.
I think we are already doing hacky things for some interesting kernels; it should be fine to make it even more hacky.
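The block-broadcast scheme above can be sketched numerically. This is a sketch only, not GPU code: each mfma4x64 issue is modeled as one plain 4x4-by-4x64 multiply-accumulate, and the CBSZ/ABID block selection is modeled as slicing out the i-th 4x4 column block of A; matrix contents are random.

```python
import numpy as np

# C = A @ B for a (4x64) x (64x64) product, computed as 16 accumulating
# steps; step i uses only the i-th 4x4 column block of A -- the block
# that CBSZ/ABID would broadcast to all lanes of one mfma4x64 issue.
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 64))   # result of the first dot (qk)
B = rng.standard_normal((64, 64))  # a 64x64 slice of v

C = np.zeros((4, 64))
for i in range(16):                       # one iteration per mfma4x64 issue
    a_block = A[:, 4 * i : 4 * i + 4]     # 4x4 block selected via CBSZ/ABID
    b_rows = B[4 * i : 4 * i + 4, :]      # matching 4 rows along K
    C += a_block @ b_rows                 # one broadcast multiply-accumulate

assert np.allclose(C, A @ B)  # 16 block-broadcast steps reproduce the full gemm
```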
Yes, this is a good idea. Maybe we can finish the current implementation first, then work on this as a separate PR?
I tried to use this dot-specific option today and found that in FA this attribute does not survive to the accelerateMatmul pass. It is overridden by the combine pass, which merges dot and add/mul operations. This approach needs more investigation.
Let's take your example:
> Let's consider FA with the following matrices: q has shape [4x128], k has shape [128x64], v has shape [64x128]
Ok. Now we have three candidates here:
Can you verify the following?
And what do you think about the proposed one?
> The old one has layout conversion using LDS between the two dots

Yes, it does.
> The new one does not have layout conversion using LDS between the two dots

This is what I am working on at the moment. I verified that it works with a simple dot->dot pattern, but there are some problems with FA. I am experimenting with it in the https://github.com/ROCm/triton/pull/504 draft.
I have not tried the new proposal yet; I consider it a possible next step once the new one works as intended.
I've moved the changes from this PR into other PRs. I think it is not needed anymore; closing.
This PR adds `matrix_instr_nonkdim`: 464 corresponds to 4(M)x64(N), 644 corresponds to 64(M)x4(N). The per-operation option is used like this:
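The usage snippet that originally followed is not included in this excerpt. Purely as an illustration of the digit-concatenation encoding described above, a hypothetical decoder (not part of the PR) could look like this:

```python
# Hypothetical decoder for the matrix_instr_nonkdim encoding described
# above: the value concatenates the M and N dims of the MFMA tile,
# e.g. 464 -> 4(M) x 64(N), 644 -> 64(M) x 4(N).
MFMA_DIMS = (4, 16, 32, 64)  # assumed set of valid per-dim MFMA sizes

def decode_nonkdim(value: int) -> tuple:
    for m in MFMA_DIMS:
        for n in MFMA_DIMS:
            if int(f"{m}{n}") == value:
                return (m, n)
    raise ValueError(f"not a valid matrix_instr_nonkdim encoding: {value}")

assert decode_nonkdim(464) == (4, 64)
assert decode_nonkdim(644) == (64, 4)
```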
MFMA size heuristic now looks like this:
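The heuristic code itself also did not survive into this excerpt. Under the assumption (stated earlier in the thread) that it looks only at tensor sizes, a shape-driven selection might look roughly like this; the function name and thresholds are hypothetical:

```python
# Hypothetical sketch of a size-only MFMA shape heuristic: pick the
# MFMA M x N tile from the dot result's shape alone, preferring larger
# instructions and falling back to skinny 4x64 / 64x4 tiles for thin
# results. Thresholds are assumptions, not the PR's actual values.
def pick_mfma_shape(m: int, n: int) -> tuple:
    if m >= 32 and n >= 32:
        return (32, 32)
    if m >= 16 and n >= 16:
        return (16, 16)
    if m < 16 and n >= 64:
        return (4, 64)   # skinny rows, wide columns (the FA example)
    if n < 16 and m >= 64:
        return (64, 4)   # wide rows, skinny columns
    return (4, 4)

assert pick_mfma_shape(4, 64) == (4, 64)    # qk in the FA example
assert pick_mfma_shape(4, 128) == (4, 64)   # qkv in the FA example
```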