
Handling reshape propagation for attention ops. #17673

Open MaheshRavishankar opened 1 week ago

MaheshRavishankar commented 1 week ago

Recently we saw a use case for propagating reshapes across attention ops the same way we propagate reshapes across Linalg ops. For now we added a one-off folder pattern (https://github.com/iree-org/iree/commit/d2ca77402becf4c6476893845ba96116b61df9c1) that mimics the end-state, but we should be able to reuse some of the same techniques as we have for Linalg ops.

To provide some context, this is the input IR that we are looking at:

%attention = iree_linalg_ext.attention {
    indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
    ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16)
    outs(%empty : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
%split = arith.divsi %d0, %c2 : index
%expanded = tensor.expand_shape %attention [[0, 1], [2], [3]] output_shape [2, %split, %d1, %d4]
    : tensor<?x?x?xf16> into tensor<2x?x?x?xf16>

If we increase the dimensionality of the attention op, we can make it produce the expanded output shape directly:

%expanded_arg0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] ...
%expanded_arg1 = tensor.expand_shape %arg1 [[0, 1], [2], [3]] ...
%expanded_arg2 = tensor.expand_shape %arg2 [[0, 1], [2], [3]] ...
%expanded_empty = tensor.expand_shape %empty [[0, 1], [2], [3]] ...
%expanded_attention = iree_linalg_ext.attention {
    indexing_maps = [affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d1, d2)>,
                     affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d3, d2)>,
                     affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d3, d4)>,
                     affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d1, d4)>]}
    ins(%expanded_arg0, %expanded_arg1, %expanded_arg2, %arg3 : tensor<2x?x?x?xf16>, tensor<2x?x?x?xf16>, tensor<2x?x?x?xf16>, f16)
    outs(%expanded_empty : tensor<2x?x?x?xf16>) -> tensor<2x?x?x?xf16>

This is essentially the same as what is done in the foldReshapeByExpansion transformation on Linalg ops.

For now we can borrow a lot from the implementation there and essentially replicate it in IREE so that it can be applied to LinalgExt ops. (The pie-in-the-sky goal for LinalgExt ops is to move them into MLIR, but that's for a later time.) The load-bearing piece in the implementation for Linalg ops is the ExpansionInfo. It takes the reassociation maps of the consumer expand_shape operation (as well as the source collapsed and expanded shapes). This information is then used to compute:

(a) whether the op is expandable (here),
(b) the indexing map in the expanded op for every indexing map in the original op (here),
(c) the type of each operand in the expanded op, given the type of that operand in the original op and the indexing map used to access it (here), and
(d) the reassociation indices to be used for the expand_shape ops that have to be generated with the original operands of the attention op as sources (here).
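
As a rough illustration of step (b), here is a minimal sketch, loosely modeled on how ExpansionInfo is used for Linalg ops in MLIR. The helper name getExpandedIndexingMap and the expandedDimsPerLoop structure are hypothetical, and the sketch assumes the indexing maps are projected permutations (which holds for the attention maps above):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
using namespace mlir;

// Hypothetical helper: remap an indexing map of the original op into the
// expanded loop space. expandedDimsPerLoop[i] lists the loop dimensions of
// the expanded op that original loop dimension i was split into (a single
// entry if the dimension is not split).
static AffineMap getExpandedIndexingMap(
    AffineMap origMap, ArrayRef<SmallVector<int64_t>> expandedDimsPerLoop,
    unsigned numExpandedDims, MLIRContext *ctx) {
  SmallVector<AffineExpr> newResults;
  for (AffineExpr expr : origMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    // Each original result dimension is replaced by the run of expanded
    // dimensions it was split into, e.g. d0 -> (d0, d00) above.
    for (int64_t newPos : expandedDimsPerLoop[pos])
      newResults.push_back(getAffineDimExpr(newPos, ctx));
  }
  return AffineMap::get(numExpandedDims, /*numSymbols=*/0, newResults, ctx);
}

For the IR above, with d0 split into (d0, d00), this turns the Q map (d0, d1, d2) into (d0, d00, d1, d2), matching the expanded indexing maps shown earlier.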

This logic needs to be replicated in IREE (for now) and used to generate the expanded attention op, the same way a LinalgOp is expanded to a higher-dimensional op here.
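
In the same spirit, a hedged sketch of steps (c) and (d): computing, for one operand, its shape in the expanded op and the reassociation indices for the tensor.expand_shape to generate on it. Again, the helper name and the expandedShapePerLoop structure are hypothetical; expandedShapePerLoop[i] is assumed to hold the sizes (ShapedType::kDynamic for dynamic ones) that loop dimension i expands into:

#include <utility>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" // ReassociationIndices
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
using namespace mlir;

// Hypothetical helper: compute (c) the operand shape in the expanded op
// and (d) the reassociation indices for the tensor.expand_shape that
// reshapes the original operand into that shape.
static std::pair<SmallVector<int64_t>, SmallVector<ReassociationIndices>>
getExpandedOperandShapeAndReassociation(
    AffineMap origMap, ArrayRef<SmallVector<int64_t>> expandedShapePerLoop) {
  SmallVector<int64_t> expandedShape;
  SmallVector<ReassociationIndices> reassociation;
  for (AffineExpr expr : origMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    // All expanded dimensions coming from one original operand dimension
    // form one reassociation group, e.g. [[0, 1], [2], [3]] above.
    ReassociationIndices group;
    for (int64_t size : expandedShapePerLoop[pos]) {
      group.push_back(static_cast<int64_t>(expandedShape.size()));
      expandedShape.push_back(size);
    }
    reassociation.push_back(group);
  }
  return {expandedShape, reassociation};
}

With the expanded operand types and reassociations in hand, the rewrite can create one tensor.expand_shape per tensor operand (the f16 scale passes through unchanged) and rebuild the iree_linalg_ext.attention op with the remapped indexing maps, mirroring the Linalg expansion linked above.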

Once this expansion is done, it unlocks more fusion opportunities. For example, after the reshape is propagated "up" through the attention op, it can more easily fuse with the transpose operation here.

MaheshRavishankar commented 1 week ago

cc @Groverkss FYI since you were interested.