ThomasRaoux opened this issue 2 years ago
Nice, this is also a thing I'm looking for!
So far we have been using ContractionOpInterface but we most likely need something more powerful.
There are some methods in `ContractionOpInterface` that may help, e.g., `isColumnMajorMatmul` and `isRowMajorMatmul`, though there are no such matmul ops currently. We may add some named ops, and implement a transform to fold `transpose + some_matmul` into `other_matmul`.
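The fold described above can be sketched with a toy model that treats each operand's indexing map as a tuple of loop-dimension names (plain Python, not the MLIR APIs; `fold_transpose_into_map` is a name made up for illustration):

```python
# Toy sketch: folding a producing transpose into the consumer's indexing map.
# If A = transpose(A0) with permutation `perm`, then the fused op reads the
# un-transposed buffer A0 through the permuted map.

def fold_transpose_into_map(consumer_map, perm):
    """Compose the transpose permutation into the consumer's indexing map."""
    return tuple(consumer_map[p] for p in perm)

# linalg.matmul reads A via (m, k); if A is produced by a transpose with
# perm (1, 0), the fused "transposed matmul" reads the original buffer
# via (k, m) instead.
a_map = ("m", "k")
print(fold_transpose_into_map(a_map, (1, 0)))  # ('k', 'm')
```

This is only meant to show the map composition; in MLIR the same idea would be expressed by composing `AffineMap`s.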
> Nice, this is also a thing I'm looking for!

Nice :)
> There are some methods in `ContractionOpInterface` that may help, e.g., `isColumnMajorMatmul` and `isRowMajorMatmul`, though there are no such matmul ops currently. We may add some named ops, and implement a transform to fold `transpose + some_matmul` into `other_matmul`.
Yes, that's the tradeoff. I wonder if we could have a new named op that takes an additional permutation to represent the transposition. My concern is that it is hard to predict all the variations we will need. Having powerful matchers may be more flexible.
We can define a new matmul op that has different indexing maps, and make it implement `LinalgContractionOpInterface`. Adding new methods to `LinalgContractionOpInterface` (e.g., `getAppliedPermutationMap`) may give us more powerful matchers?
Alright, so what is the consensus: new named ops, or make it work with generic ops? @ThomasRaoux @hanhanW @MaheshRavishankar
I’m not very familiar with the named op infrastructure. If there is an easy way to add an op that would support the different combinations, it is probably better. If not, I’d start with a generic op. If we need one op per combination, I think this is going to be annoying to handle.
@nicolasvasilache @hanhanW, any idea how hard it would be (if possible) to have a matmul/batchmatmul named op with an attribute to encode the transposition information?
It depends on the "kind" of ops to support. With GEMM, `transpose(A) * B`, `A * transpose(B)`, and `transpose(A) * transpose(B)` are just 3 more ops. If all of these just need an extra batch dimension as the most significant dimension, that's fairly straightforward to add.
But if the batch dimension is not always the outermost, then it would probably be easier to just use generic ops (it would need 12 more named ops). Not sure deciding configuration based on a named op then gives us much. Instead it should be fairly easy to get which dimension is the batch dimension, the preserved dimensions, and the contracted dimensions by looking at the indexing maps. We can just use that to drive compilation and heuristics. So just using generic ops makes the most sense to me.
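The dimension classification described above can be sketched with plain set operations, again modeling each indexing map as a tuple of loop-dimension names (a toy model, not the MLIR `AffineMap` APIs; `classify_dims` is an illustrative name):

```python
# Toy sketch: recover batch / preserved / contracted dimensions purely from
# the indexing maps of the two inputs and the output.

def classify_dims(lhs, rhs, out):
    lhs, rhs, out = set(lhs), set(rhs), set(out)
    batch = lhs & rhs & out           # appears in both inputs and the output
    contracted = (lhs & rhs) - out    # in both inputs but not the output
    preserved = (lhs ^ rhs) & out     # in exactly one input and the output
    return batch, preserved, contracted

# batch_matmul: A: (b, m, k), B: (b, k, n), C: (b, m, n)
batch, preserved, contracted = classify_dims(
    ("b", "m", "k"), ("b", "k", "n"), ("b", "m", "n"))
print(sorted(batch), sorted(preserved), sorted(contracted))
# ['b'] ['m', 'n'] ['k']
```

Note the classification is layout-independent: it gives the same answer no matter where the batch dimension sits, which is the point of driving heuristics from the maps rather than from op names.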
I agree, that’s pretty much what the CUDA backend currently does. I think we can clean it up and have improved common matchers then.
Looked at what the CUDA backend does. That works (might need to be made a bit tighter, but apart from that it's fine).
Re generic matching, I have started to add some support in the context of the transform dialect (see https://github.com/google/iree/blob/main/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.cpp#L142)
You could extend this to support permutations (either a filter for the specific one you want to match or a generator that can capture "any" permutation and return it).
Then you should be able to match any generic whose behavior matches that of a named op, modulo some permutation.
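The "generator that can capture any permutation and return it" can be sketched in the same toy dim-tuple model (plain Python, not the PDL/transform-dialect APIs linked above; `operand_permutation` is a hypothetical helper name):

```python
# Toy sketch: instead of matching one fixed layout, recover the permutation
# an operand's map applies relative to the canonical named-op map, and
# return it to the caller (identity means "not transposed").

def operand_permutation(actual, canonical):
    """Permutation p such that actual[i] == canonical[p[i]];
    None if the operand accesses different dims entirely."""
    if sorted(actual) != sorted(canonical):
        return None
    return tuple(canonical.index(d) for d in actual)

# Canonical matmul reads A via (m, k). A generic reading A via (k, m) is
# matmul with a transposed LHS:
print(operand_permutation(("k", "m"), ("m", "k")))  # (1, 0)
print(operand_permutation(("m", "k"), ("m", "k")))  # (0, 1) -> identity
```

With this, a single matcher covers every transposed variant of a named op and hands the captured permutation back to the transform, rather than needing one filter per layout.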
For prototyping, the `linalg.generic` op is good enough. We may need better matchers to make it work generally in IREE.
> have a matmul/batchmatmul named op with an attribute to encode the transposition information
Having the `transpose + matmul` pattern in IREE might need some changes around dispatch region formation, because we'd like them to be in the same dispatch region. It makes sense to have a single named op with these variants. What you're looking for is to embed this information into the named op (which might carry an `indexing_maps` attribute). We can have a method (e.g., `getMatmulNamedOp(linalg::ContractionOpInterface op, AffineMap map)`) to return the named op and use it in the pass. With this approach, we don't need changes in dispatch region formation. All the matmul named ops implement `linalg::ContractionOpInterface`, so we can extend what we have in configurations by looking into the indexing maps.
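The `getMatmulNamedOp` idea above boils down to a lookup keyed on the operand indexing maps; a minimal sketch in the same toy dim-tuple model (the transposed-op names in the table are illustrative, not necessarily existing ops):

```python
# Toy sketch: decide which (possibly transposed) named matmul covers a
# given pair of LHS/RHS indexing maps, assuming output map (m, n).

NAMED_MATMULS = {
    (("m", "k"), ("k", "n")): "linalg.matmul",
    (("k", "m"), ("k", "n")): "matmul_transpose_a",  # illustrative name
    (("m", "k"), ("n", "k")): "matmul_transpose_b",  # illustrative name
    (("k", "m"), ("n", "k")): "matmul_transpose_ab", # illustrative name
}

def get_matmul_named_op(lhs_map, rhs_map):
    """Return the covering named op for this layout, or None."""
    return NAMED_MATMULS.get((lhs_map, rhs_map))

print(get_matmul_named_op(("k", "m"), ("k", "n")))  # matmul_transpose_a
```

The real method would inspect the op's `indexing_maps` through `linalg::ContractionOpInterface` instead of taking tuples, but the decision table is the same shape.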
@hanhanW @ThomasRaoux @MaheshRavishankar @nicolasvasilache Old issue alert! Still active? P1 or P2?
It looks like something related to @okkwon's recent work. And that landed with default=false in IREE?
This is an old issue indeed. This is about fusing transposes with their consumers instead of their producers (which is what happens today). It's not what @okkwon is working on. I'll let Thomas decide the actual priority, but for me this is a low priority at the moment.
This is slightly different from what Okwan is doing, as the idea was to fuse transpose with consumers. Right now we decided not to do it. I think we will eventually still want some support for it, but probably not for a while. I agree it should be low priority at this point.
@nirvedhmeshram, I'm removing you from assignee since I don't expect you to work on it any time soon
This has been discussed in the past, and we are starting to see the missing fusion of matmul + transpose being one of the top bottlenecks for the Bert model. At this point we should be able to add a first version without significant design changes in IREE. Here are the steps we should take:
1. Add a fusion pattern that can fuse named matmul/batchMatmul ops with a generic op with some basic control (like the rest of the fusion patterns); there should be a significant amount of re-use from the rest of the fusion patterns.
2. Plumb this pattern into the IREE linalg fusion pass. For simplicity it should be done after genericOp fusion, so that we don't fuse anything we don't mean to into the matmul ops. Then the heuristic can be conservative at first and only allow the obvious wins (for instance, the most common one in Bert is batchmatmul mxbxk * bxkxn -> bxmxn).
3. Make sure all the backends handle the new genericOp and that it goes through the same code-generation flow as regular matmul. This will most likely require matchers to identify that those are transposed matmuls. So far we have been using ContractionOpInterface, but we most likely need something more powerful.
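As a concrete toy check of the Bert case mentioned above: the LHS is laid out (m, b, k) instead of the canonical (b, m, k), and a matcher can recover that permutation from the indexing map alone (same dim-tuple model as a sketch, hypothetical helper name):

```python
# Toy sketch: recognize the Bert batchmatmul variant (mxbxk * bxkxn -> bxmxn)
# as a batch matmul with a permuted LHS, by comparing the LHS indexing map
# against the canonical (b, m, k) layout.

CANONICAL_BMM_LHS = ("b", "m", "k")

def lhs_permutation(lhs_map):
    """Permutation relative to the canonical LHS map, or None if the map
    touches different dims; non-identity means a transposed operand."""
    if sorted(lhs_map) != sorted(CANONICAL_BMM_LHS):
        return None
    return tuple(CANONICAL_BMM_LHS.index(d) for d in lhs_map)

print(lhs_permutation(("m", "b", "k")))  # (1, 0, 2): batch dim not outermost
print(lhs_permutation(("b", "m", "k")))  # (0, 1, 2): plain batch matmul
```

A conservative first heuristic could accept only this specific permutation and fall through to the default path for anything else.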