iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Add fusion of transpose with matmul/batchmatmul named ops #8827

Open · ThomasRaoux opened this issue 2 years ago

ThomasRaoux commented 2 years ago

This has been discussed in the past, and we are starting to see the missing fusion of matmul + transpose being one of the top bottlenecks for the BERT model. At this point we should be able to add a first version without significant design changes in IREE. Here are the steps we should take:

  1. Add a fusion pattern that can fuse named matmul/batchMatmul ops with a generic op, with some basic control (like the rest of the fusion patterns); there should be a significant amount of re-use from the existing fusion patterns.

  2. Plumb this pattern into IREE's linalg fusion pass. For simplicity it should run after genericOp fusion so that we don't fuse anything we don't mean to into the matmul ops. The heuristic can be conservative at first and only allow the obvious wins (for instance, the most common case in BERT is batchmatmul mxbxk * bxkxn -> bxmxn; see the sketch after this list).

  3. Make sure all the backends handle the new genericOp and that it goes through the same code generation flow as a regular matmul. This will most likely require a matcher to identify that these are transposed matmuls. So far we have been using ContractionOpInterface, but we most likely need something more powerful.
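A minimal sketch of what the fused result of steps 1 and 2 could look like for the BERT case, assuming hypothetical shapes and the string-based iterator_types syntax (the exact textual form depends on the MLIR version): the transpose of the LHS is absorbed into the indexing maps of a single linalg.generic, so no separate transpose op remains.

```mlir
// Fused transpose + batch matmul, with the LHS laid out as m x b x k.
#lhs = affine_map<(b, m, n, k) -> (m, b, k)>
#rhs = affine_map<(b, m, n, k) -> (b, k, n)>
#out = affine_map<(b, m, n, k) -> (b, m, n)>
func.func @fused_transpose_batch_matmul(
    %lhs: tensor<128x4x64xf32>,     // m x b x k
    %rhs: tensor<4x64x256xf32>,     // b x k x n
    %init: tensor<4x128x256xf32>)   // b x m x n
    -> tensor<4x128x256xf32> {
  %0 = linalg.generic
      {indexing_maps = [#lhs, #rhs, #out],
       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%lhs, %rhs : tensor<128x4x64xf32>, tensor<4x64x256xf32>)
      outs(%init : tensor<4x128x256xf32>) {
  ^bb0(%a: f32, %b: f32, %acc: f32):
    %mul = arith.mulf %a, %b : f32
    %add = arith.addf %acc, %mul : f32
    linalg.yield %add : f32
  } -> tensor<4x128x256xf32>
  return %0 : tensor<4x128x256xf32>
}
```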

hanhanW commented 2 years ago

Nice, this is also a thing I'm looking for!

> So far we have been using ContractionOpInterface, but we most likely need something more powerful.

There are some methods in ContractionOpInterface that may help, e.g., isColumnMajorMatmul and isRowMajorMatmul, though there are no such matmul ops currently. We could add some named ops and implement a transform to fold transpose + some_matmul into other_matmul.

ThomasRaoux commented 2 years ago

> Nice, this is also a thing I'm looking for!

Nice :)

> There are some methods in ContractionOpInterface that may help, e.g., isColumnMajorMatmul and isRowMajorMatmul, though there are no such matmul ops currently. We could add some named ops and implement a transform to fold transpose + some_matmul into other_matmul.

Yes, that's the tradeoff. I wonder if we could have a new named op that takes an additional permutation to represent the transposition. My concern is that it is hard to predict all the variations we will need; having powerful matchers may be more flexible.

hanhanW commented 2 years ago

We can define a new matmul op that has different indexing maps and have it implement LinalgContractionOpInterface.

https://github.com/llvm/llvm-project/blob/c807141d27e6e60bf5829009c5af195f38205966/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml#L159-L229

Adding new methods to LinalgContractionOpInterface (e.g., getAppliedPermutationMap) may provide us with more powerful matchers?
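As a rough sketch of what such an op would compute, written (with hypothetical shapes, and a placeholder function name) as the equivalent linalg.generic rather than as an actual new named op: only the LHS indexing map differs from the linalg.matmul definition in the linked YAML, and the op is still structurally a contraction.

```mlir
// transpose(A) * B as a contraction: C(m, n) += A(k, m) * B(k, n).
#lhs_t = affine_map<(m, n, k) -> (k, m)>  // A is stored transposed (k x m)
#rhs   = affine_map<(m, n, k) -> (k, n)>
#acc   = affine_map<(m, n, k) -> (m, n)>
func.func @matmul_transpose_a(
    %lhs: tensor<64x128xf32>,    // k x m
    %rhs: tensor<64x256xf32>,    // k x n
    %init: tensor<128x256xf32>) -> tensor<128x256xf32> {
  %0 = linalg.generic
      {indexing_maps = [#lhs_t, #rhs, #acc],
       iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%lhs, %rhs : tensor<64x128xf32>, tensor<64x256xf32>)
      outs(%init : tensor<128x256xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %mul = arith.mulf %a, %b : f32
    %add = arith.addf %c, %mul : f32
    linalg.yield %add : f32
  } -> tensor<128x256xf32>
  return %0 : tensor<128x256xf32>
}
```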

nirvedhmeshram commented 2 years ago

Alright, so what is the consensus: new named ops, or make it work with generic ops? @ThomasRaoux @hanhanW @MaheshRavishankar

ThomasRaoux commented 2 years ago

> Alright, so what is the consensus: new named ops, or make it work with generic ops? @ThomasRaoux @hanhanW @MaheshRavishankar

I'm not very familiar with the named op infrastructure. If there is an easy way to add an op that supports the different combinations, that is probably better. If not, I'd start with a generic op. If we need one op per combination, I think this is going to be annoying to handle.

@nicolasvasilache @hanhanW, any idea how hard it would be (if possible) to have a matmul/batchmatmul named op with an attribute to encode the transposition information?

MaheshRavishankar commented 2 years ago

It depends on the "kind" of ops to support. With GEMM, transpose(A) * B, A * transpose(B), and transpose(A) * transpose(B) are just 3 more ops. If all of these just need an extra batch dimension as the most significant dimension, that's fairly straightforward to add. But if the batch dimension is not always the outermost, then it would probably be easier to just use generic ops (it would need 12 more named ops). I'm not sure deciding the configuration based on the named op gives us much then. Instead, it should be fairly easy to get which dimension is the batch dimension, which are the preserved dimensions, and which are the contracted dimensions by looking at the indexing maps. We can just use that to drive compilation and heuristics. So just using generic ops makes the most sense to me.
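To illustrate that last point with the maps of the fused generic sketched earlier (iteration space (b, m, n, k)): the classification falls out of where each dimension appears. This is just a reading of the maps, not an existing utility.

```mlir
#lhs = affine_map<(b, m, n, k) -> (m, b, k)>  // LHS (m x b x k)
#rhs = affine_map<(b, m, n, k) -> (b, k, n)>  // RHS (b x k x n)
#out = affine_map<(b, m, n, k) -> (b, m, n)>  // result (b x m x n)
// b: parallel, used by both inputs and the result   -> batch dimension.
// m: parallel, used by the LHS and the result only  -> preserved (M).
// n: parallel, used by the RHS and the result only  -> preserved (N).
// k: reduction, used by both inputs, not the result -> contracted (K).
```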

ThomasRaoux commented 2 years ago

> It depends on the "kind" of ops to support. With GEMM, transpose(A) * B, A * transpose(B), and transpose(A) * transpose(B) are just 3 more ops. If all of these just need an extra batch dimension as the most significant dimension, that's fairly straightforward to add. But if the batch dimension is not always the outermost, then it would probably be easier to just use generic ops (it would need 12 more named ops). I'm not sure deciding the configuration based on the named op gives us much then. Instead, it should be fairly easy to get which dimension is the batch dimension, which are the preserved dimensions, and which are the contracted dimensions by looking at the indexing maps. We can just use that to drive compilation and heuristics. So just using generic ops makes the most sense to me.

I agree, that's pretty much what the CUDA backend currently does. I think we can clean it up and have improved common matchers then.

MaheshRavishankar commented 2 years ago

> > It depends on the "kind" of ops to support. With GEMM, transpose(A) * B, A * transpose(B), and transpose(A) * transpose(B) are just 3 more ops. If all of these just need an extra batch dimension as the most significant dimension, that's fairly straightforward to add. But if the batch dimension is not always the outermost, then it would probably be easier to just use generic ops (it would need 12 more named ops). I'm not sure deciding the configuration based on the named op gives us much then. Instead, it should be fairly easy to get which dimension is the batch dimension, which are the preserved dimensions, and which are the contracted dimensions by looking at the indexing maps. We can just use that to drive compilation and heuristics. So just using generic ops makes the most sense to me.

> I agree, that's pretty much what the CUDA backend currently does. I think we can clean it up and have improved common matchers then.

Looked at what the CUDA backend does. That works (might need to be made a bit tighter, but apart from that it's fine).

nicolasvasilache commented 2 years ago

Re generic matching, I have started to add some support in the context of the transform dialect (see https://github.com/google/iree/blob/main/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.cpp#L142)

You could extend this to support permutations (either a filter for the specific one you want to match or a generator that can capture "any" permutation and return it).

Then you should be able to match any generic whose behavior matches that of a named op, modulo some permutation.

hanhanW commented 2 years ago

For prototyping, a linalg.generic op is good enough. We may need better matchers to make it work generally in IREE.

> have a matmul/batchmatmul named op with an attribute to encode the transposition information

Having a transpose + matmul pattern in IREE might need some changes around dispatch region formation, because we'd like them to end up in the same dispatch region. It makes sense to have a single named op covering these variants. What you're looking for is to embed this information into the named op (which might carry an indexing_maps attribute). We could have a method (e.g., getMatmulNamedOp(linalg::ContractionOpInterface op, AffineMap map)) to return the named op and use it in the pass. With this approach, we don't need changes to dispatch region formation. All the matmul named ops implement linalg::ContractionOpInterface, so we can extend what we have in configurations by looking at the indexing maps.
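For reference, a minimal sketch (hypothetical shapes) of the unfused pattern being discussed: the transpose written as a linalg.generic feeding a plain linalg.batch_matmul. Folding the transpose's permutation into the contraction's indexing maps, or into a named op returned by something like the getMatmulNamedOp helper suggested above, avoids having to place two ops in the same dispatch region.

```mlir
// Transpose (m x b x k -> b x m x k) followed by a regular batch matmul.
#in  = affine_map<(b, m, k) -> (m, b, k)>
#out = affine_map<(b, m, k) -> (b, m, k)>
func.func @transpose_then_batch_matmul(
    %lhs: tensor<128x4x64xf32>,      // m x b x k
    %rhs: tensor<4x64x256xf32>,      // b x k x n
    %init_t: tensor<4x128x64xf32>,   // b x m x k
    %init: tensor<4x128x256xf32>) -> tensor<4x128x256xf32> {
  %t = linalg.generic
      {indexing_maps = [#in, #out],
       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%lhs : tensor<128x4x64xf32>)
      outs(%init_t : tensor<4x128x64xf32>) {
  ^bb0(%a: f32, %ignored: f32):
    linalg.yield %a : f32
  } -> tensor<4x128x64xf32>
  %0 = linalg.batch_matmul
      ins(%t, %rhs : tensor<4x128x64xf32>, tensor<4x64x256xf32>)
      outs(%init : tensor<4x128x256xf32>) -> tensor<4x128x256xf32>
  return %0 : tensor<4x128x256xf32>
}
```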

allieculp commented 2 years ago

@hanhanW @ThomasRaoux @MaheshRavishankar @nicolasvasilache Old issue alert! Still active? P1 or P2?

hanhanW commented 2 years ago

It looks like something related to @okkwon's recent work, and those changes landed with default=false in IREE?

MaheshRavishankar commented 2 years ago

This is an old issue indeed. This is about fusing transposes with their consumers instead of their producers (which is what happens today). It's not what @okkwon is working on. I'll let Thomas decide the actual priority, but for me this is low priority at the moment.

ThomasRaoux commented 2 years ago

This is slightly different from what Okwan is doing, as the idea here was to fuse transposes with their consumers. For now we decided not to do it. I think we will eventually still want some support for it, but probably not for a while. I agree it should be low priority at this point.

ThomasRaoux commented 2 years ago

@nirvedhmeshram, I'm removing you as the assignee since I don't expect you to work on this any time soon.