Max191 opened this issue 1 month ago
Regarding matching generic to mmt4d ukernels: this can be viewed as 2 separate steps:
1. linalg.generic to linalg.mmt4d. This is the new code that needs to be written. It can be a separate pass or pattern, independent of ukernels.
2. linalg.mmt4d to the mmt4d ukernel. This lowering already exists.
So all you need to do here is step 1.
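For illustration, here is a minimal sketch of what that step-1 pattern could look like, written against upstream MLIR's linalg::isaContractionOpInterface and linalg::inferContractionDims helpers. This is an assumption-laden sketch, not the actual implementation: the dimension-count check is a simplified placeholder, and the indexing-map/body verification that a real pattern needs is only noted in a comment.

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {
// Sketch only: rewrite a linalg.generic that already has mmt4d-like
// contraction structure into the linalg.mmt4d named op, so the existing
// mmt4d -> ukernel lowering can take over.
struct GenericToMmt4DPattern : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = cast<linalg::LinalgOp>(genericOp.getOperation());
    if (!linalg::isaContractionOpInterface(linalgOp))
      return failure();
    FailureOr<linalg::ContractionDimensions> dims =
        linalg::inferContractionDims(linalgOp);
    if (failed(dims))
      return failure();
    // mmt4d has no batch dimension and exactly one outer + one inner tile
    // dimension for each of M, N, and K.
    if (!dims->batch.empty() || dims->m.size() != 2 || dims->n.size() != 2 ||
        dims->k.size() != 2)
      return failure();
    // A real pattern must also verify the exact indexing maps
    // (LHS = MxKxM0xK0, RHS = NxKxN0xK0, ACC = MxNxM0xN0) and that the body
    // is a mul/add before doing this replacement.
    rewriter.replaceOpWithNewOp<linalg::Mmt4DOp>(
        genericOp, genericOp->getResultTypes(), genericOp.getDpsInputs(),
        genericOp.getDpsInits());
    return success();
  }
};
} // namespace
```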
There is a caveat about what "support" means. The mmt4d ukernel supports arbitrary linalg.mmt4d op shapes, for the supported combinations of element types, in the sense that the lowering always succeeds and there is always an actual ukernel to handle any case. The caveat is about performance: that wide support is achieved thanks to a slow fallback (which is slower than any generic codegen), so there is a performance cliff here that one can silently fall off of.
The worst-case scenario, then, is that implementing the linalg.generic to linalg.mmt4d rewrite causes a workload to hit the slow ukernel path. However, that problem already exists for workloads that use linalg.mmt4d ops. The reason we haven't prioritized solving it is a combination of two things: it hasn't been a big deal (because we have added ukernel fast paths as needed), and it isn't trivial to fix, because it requires enlarging the contact surface between ukernels and the compiler. However, if we are willing to do that, then the solution is a simple enough PR: https://github.com/iree-org/iree/pull/16880 . GitHub shows an outdated diff base for that PR, but most of that diff is already in main because https://github.com/iree-org/iree/pull/16879 was merged: we have already structured the ukernels code to have a declarative table enumerating the fast code paths, so what is left to do in that PR is simply exposing that table to the compiler.
> There is a caveat about what "support" means. The mmt4d ukernel supports arbitrary linalg.mmt4d op shapes, for the supported combinations of element types, in the sense that the lowering always succeeds and there is always an actual ukernel to handle any case. The caveat is about performance: that wide support is achieved thanks to a slow fallback (which is slower than any generic codegen), so there is a performance cliff here that one can silently fall off of.
I see, thanks for the context. I think it makes sense to start by unconditionally converting linalg.generic to mmt4d. It would be best to complicate the current next steps as little as we can. We can leave it as a potential TODO in case we hit any future performance pitfalls in the wild.
Overview
This issue is a tracker for the next steps of enabling the data tiling fusion improvements, which are currently behind the --iree-dispatch-creation-experimental-data-tiling flag. Once the steps here are complete, there should be fairly reasonable support for general grouped quantized matmul operations through data tiling. The main steps are as follows:
1. SetEncoding
This change should be very simple. SetEncoding should not behave any differently with multiple batch, M, N, or K dimensions; we just need to remove the pass's restriction to a single batch/M/N/K dimension.
2. MaterializeEncoding
This is a more involved step, since materializing a contraction with multiple batch/M/N/K dimensions cannot create an MMT4D op. Instead, we need to create a linalg.generic op with the appropriate iterator_types and indexing_maps. There are useful functions for checking whether generic ops are contractions, and for getting the mapping from loop dimensions to batch, M, N, and K dimensions, here:
https://github.com/llvm/llvm-project/blob/030c6da7af826b641db005be925b20f956c3a6bb/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h#L50-L71
Here is an example of how this is used to find the narrow M and N sizes of a matmul:
https://github.com/iree-org/iree/blob/febe0edba1ea4b29135d23d64d84417b7a106be8/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp#L56-L67
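As a small illustration of the API linked above (a standalone sketch, not code taken from IREE), linalg::inferContractionDims returns the batch/M/N/K loop groups, and with data tiling each group can contain several dimensions (outer dims plus inner tile dims):

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// Illustration only: classify the loops of a contraction-like op into
// batch/M/N/K groups and print their static trip counts.
static LogicalResult printContractionDims(linalg::LinalgOp op) {
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(op);
  if (failed(dims))
    return failure();
  auto loopRanges = op.getStaticLoopRanges();
  auto dumpGroup = [&](llvm::StringRef name, ArrayRef<unsigned> group) {
    llvm::outs() << name << ":";
    for (unsigned d : group)
      llvm::outs() << " d" << d << "=" << loopRanges[d];
    llvm::outs() << "\n";
  };
  dumpGroup("batch", dims->batch);
  dumpGroup("m", dims->m);
  dumpGroup("n", dims->n);
  dumpGroup("k", dims->k);
  return success();
}
```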
I have written some IR for what I would expect the final transformation to look like below. Note that the only materialization pattern that should be needed is the one for the final contraction operation. The patterns for set_encoding, unset_encoding, and dequant ops should already be supported. I only included them in this example so it is clear where this all comes from.
3. Lowering Path
The next thing we need is a lowering path for the resulting data-tiled matmul: this primarily means setting an appropriate lowering configuration for the op, but there will likely be issues to solve in whatever lowering pipeline this goes down. There are a few changes to be made for lowering configurations:
- A lowering configuration needs to be set for linalg.generic ops for which linalg::isaContractionOpInterface returns true. There is something similar done in the GPU configurations: https://github.com/iree-org/iree/blob/febe0edba1ea4b29135d23d64d84417b7a106be8/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp#L854-L864 In general, I think contraction ops with multiple M, N, or K dimensions should have all outer M/N/K dimensions tiled to 1 as an initial strategy (see the sketch after this list).
- The resulting dispatches should go down the TileRootFuseConsumerProducer pipeline. There may be some issues along the way to solve as more dispatches are tested on the new pipeline.
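A rough sketch of that initial tiling strategy is below. It only uses upstream MLIR helpers; how the resulting sizes get attached as a lowering config is left as a hypothetical setRootLoweringConfig helper, and the code assumes the inner data-tiling dims (M0/N0/K0) are the innermost loops of each group:

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

using namespace mlir;

// Sketch of the "tile all outer M/N/K dimensions to 1" starting strategy.
// Assumption: the materialized generic keeps the inner data-tiling dims as
// the innermost loops, so the last entry of each group returned by
// inferContractionDims is the inner tile dim and everything before it is an
// outer dim. Batch dims are also tiled to 1 here.
static FailureOr<SmallVector<int64_t>>
getInitialTileSizes(linalg::LinalgOp op) {
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(op);
  if (failed(dims))
    return failure();
  // 0 means "do not tile this loop".
  SmallVector<int64_t> tileSizes(op.getNumLoops(), 0);
  for (unsigned d : dims->batch)
    tileSizes[d] = 1;
  for (ArrayRef<unsigned> group :
       {ArrayRef<unsigned>(dims->m), ArrayRef<unsigned>(dims->n),
        ArrayRef<unsigned>(dims->k)})
    for (unsigned d : group.drop_back()) // keep the inner tile dim untiled
      tileSizes[d] = 1;
  return tileSizes;
}

// Hypothetical usage: attach the sizes via whatever helper the backend uses,
// e.g. something like setRootLoweringConfig(op, *getInitialTileSizes(op)).
```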
Lower to UKernels
UKernels are rooted on MMT4D ops right now, and we don't want to lose the benefits of ukernels when we materialize generic ops instead of mmt4d. The CPULowerToUKernels pass converts a set of named operations into ukernel ops based on some matching patterns called here: https://github.com/iree-org/iree/blob/febe0edba1ea4b29135d23d64d84417b7a106be8/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp#L584-L585 A new matchDAGForUKernel matcher will be needed for linalg.generic ops. The new matcher will need to check the shapes of the generic op inputs against the supported ukernel shapes, since it is not guaranteed that an arbitrary generic op will match the shape of a ukernel. It would be good to get @bjacob's input on this, since matching against all possible ukernel shapes seems a little cumbersome.
EDIT: See Benoit's comment above. There should be no need to match against specific shapes for now. The worst case is bad performance, but that can be something to look into later.
Another detail is that we will most likely want to do some unit dim folding after tiling to get rid of outer unit dimensions on the contractions. This will make matching ukernels much simpler.
Final Comments
The sequence of steps in the overview is listed in some order, but that is not a strict dependency order. In particular, steps 2 and 3 could be done in either order, since I have provided some example IR after materialization. Doing 3 first could provide some lessons for the particular form we want as a result of 2 in general. Although 3 is dependent on 2 for e2e enablement, they could be worked on in parallel.