iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Conversion of einsum-like operations into matmul-like operations. #13528

Open MaheshRavishankar opened 1 year ago

MaheshRavishankar commented 1 year ago

For getting reasonable performance on the current code-generation paths, einsum-like operations need to be converted into named matmul-like operations.

Since other front-ends might need the same transformation, it may be worth doing it on Linalg itself. This issue documents some thoughts on how that could be done at the Linalg level. The description below considers lowering einsum-like operations into batch matmul, which is the most general of the matmul-like operations.

As an example, consider the following einsum-like operation:

#map_lhs = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3, d2, d7)>
#map_rhs = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d4, d3, d6, d2, d7)>
#map_out = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d4, d5, d6)>
%0 = linalg.generic {
    indexing_maps = [#map_lhs, #map_rhs, #map_out],
    iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "parallel", "reduction"]}
    ins(%lhs, %rhs : tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
    outs(%init : tensor<?x?x?x?x?x?xf32>) {
  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
    %1 = arith.mulf %b0, %b1 : f32
    %2 = arith.addf %b2, %1 : f32
    linalg.yield %2 : f32
  } -> tensor<?x?x?x?x?x?xf32>

The first thing to do is characterize the dimensions. Batch matmul has four types of dimensions: the batch dimensions, the M dimensions (rows of the LHS and result), the N dimensions (columns of the RHS and result), and the K dimensions (the reduction).

Multiple dimensions of the original op collapse into these dimensions of the final batch matmul. The characterization of the dimensions of the original op can be done using these rules:

- A parallel dimension that appears in the indexing maps of the LHS, the RHS, and the result is a batch dimension.
- A parallel dimension that appears in the indexing maps of the LHS and the result, but not the RHS, is an M dimension.
- A parallel dimension that appears in the indexing maps of the RHS and the result, but not the LHS, is an N dimension.
- A reduction dimension that appears in the indexing maps of the LHS and the RHS, but not the result, is a K dimension.

For the example above this gives batch = {d0, d2}, M = {d1, d5}, N = {d4, d6}, and K = {d3, d7}.
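As a sketch only (the function and names below are hypothetical illustrations, not the actual MLIR/IREE implementation), the classification rules can be expressed over the result tuples of the indexing maps, with loop dimensions referred to by index:

```python
# Hypothetical sketch: classify each loop dimension of an einsum-like
# linalg.generic for lowering to batch_matmul. lhs/rhs/out are the result
# tuples of the indexing maps; reduction dims are those absent from `out`.
def classify_dims(lhs, rhs, out):
    lhs_s, rhs_s, out_s = set(lhs), set(rhs), set(out)
    dims = {"batch": [], "m": [], "n": [], "k": []}
    for d in sorted(lhs_s | rhs_s | out_s):
        if d in lhs_s and d in rhs_s and d in out_s:
            dims["batch"].append(d)  # in LHS, RHS, and result
        elif d in lhs_s and d in out_s:
            dims["m"].append(d)      # in LHS and result only
        elif d in rhs_s and d in out_s:
            dims["n"].append(d)      # in RHS and result only
        elif d in lhs_s and d in rhs_s:
            dims["k"].append(d)      # reduction: in LHS and RHS only
        else:
            raise ValueError(f"d{d} does not fit any matmul dimension kind")
    return dims

# Indexing maps from the example above.
lhs = (0, 1, 5, 3, 2, 7)
rhs = (0, 4, 3, 6, 2, 7)
out = (0, 1, 2, 4, 5, 6)
print(classify_dims(lhs, rhs, out))
# {'batch': [0, 2], 'm': [1, 5], 'n': [4, 6], 'k': [3, 7]}
```

Running this on the example recovers the classification stated above: d0 and d2 are batch dimensions, d1 and d5 are M, d4 and d6 are N, and d3 and d7 are K.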

A further constraint is that the body of the operation has to be a multiply + add.

A further generalization of these constraints is needed to handle the case where one or more operands is broadcast. To handle those, the checks for the dimensions need to happen in this order

The key is

Once we have a classification of the dimensions, we need to introduce transposes to get the dimensions of each operand into the following order:

- LHS: (batch dimensions, M dimensions, K dimensions)
- RHS: (batch dimensions, K dimensions, N dimensions)
- Result: (batch dimensions, M dimensions, N dimensions)

Reshapes then collapse each of the sets above into a single dimension, yielding the 3-D operands of a batch matmul.
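To illustrate the transpose step (again a hypothetical sketch, not the actual implementation), the permutation for each operand can be derived from the classification; after transposing, each dimension group occupies contiguous positions and can be collapsed by a reshape:

```python
# Hypothetical sketch: compute the transpose permutation that brings an
# operand's dimensions into batch_matmul order (LHS: batch, M, K;
# RHS: batch, K, N). `operand_map` is the operand's indexing-map result
# tuple; `ordered_groups` lists the loop-dimension groups in target order.
def transpose_perm(operand_map, ordered_groups):
    # Loop dims in target order, restricted to those the operand carries.
    target = [d for group in ordered_groups for d in group if d in operand_map]
    # Position of each target loop dim within the operand's current map.
    return [operand_map.index(d) for d in target]

batch, m, n, k = [0, 2], [1, 5], [4, 6], [3, 7]
lhs = (0, 1, 5, 3, 2, 7)
rhs = (0, 4, 3, 6, 2, 7)

# LHS transposed to (d0, d2 | d1, d5 | d3, d7):
print(transpose_perm(lhs, [batch, m, k]))  # [0, 4, 1, 2, 3, 5]
# RHS transposed to (d0, d2 | d3, d7 | d4, d6):
print(transpose_perm(rhs, [batch, k, n]))  # [0, 4, 2, 5, 1, 3]
```

After these transposes, collapsing the index groups (0, 1), (2, 3), and (4, 5) of each operand produces the (B, M, K) x (B, K, N) -> (B, M, N) shape that linalg.batch_matmul expects.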

Adding this pattern to Linalg (it could be done in MLIR) will allow deduplicating the patterns added in #13519 and #13468.

MaheshRavishankar commented 1 year ago

@allieculp this is the description for @NatashaKnk to make progress on the einsum -> batch_matmul conversion. Please add this to the appropriate sprints.

MaheshRavishankar commented 1 year ago

cc @silvasean and @rsuderman to verify my logic above.

silvasean commented 1 year ago

This seems right to me.

I wonder if part of this issue should be to generalize all linalg.matmul/linalg.batch_matmul ops to linalg.generic, let them fuse with the reshapes/transposes around them (and possibly other ops, like reductions), and then finally canonicalize them back into linalg.matmul/linalg.batch_matmul. Do you think that is useful or in scope? I feel it could give us more performance stability across the different ways users can write the same thing. Cases like https://github.com/openxla/iree/issues/12214 have an open-coded broadcast + batch_matmul that could benefit from being handled the same way as the corresponding einsum, had the user written it that way (in that case the broadcast + batch_matmul is equivalent to a regular matmul with a larger LHS preserved dimension).