
Fusion opportunities for matmul + reshape + elementwise ops in MiniLM bert model #9278

Closed: hanhanW closed this issue 4 months ago

hanhanW commented 2 years ago

I found a new case where we might want to fuse these ops into a single dispatch. The result of the matmul is only used by a flow.tensor.reshape, and the result of the flow.tensor.reshape is only used by the element-wise op. I think we can reorder this to matmul -> elementwise -> reshape, and then fuse the matmul + elementwise into a single dispatch (see the hand-written sketch after the IR below).

#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map6 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map7 = affine_map<(d0, d1, d2) -> (d0, d2)>
module {
  flow.executable private @predict_dispatch_6 {
    flow.dispatch.entry public @predict_dispatch_6 attributes {workgroup_rank = 3 : index}
    builtin.module {
      func.func @predict_dispatch_6(%arg0: !flow.dispatch.tensor<readonly:128x384xf32>, %arg1: !flow.dispatch.tensor<readonly:384x384xf32>, %arg2: !flow.dispatch.tensor<writeonly:128x384xf32>) {
        %cst = arith.constant 0.000000e+00 : f32
        %0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [128, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x384xf32> -> tensor<128x384xf32>
        %1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [384, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:384x384xf32> -> tensor<384x384xf32>
        %2 = linalg.init_tensor [128, 384] : tensor<128x384xf32>
        %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<128x384xf32>) -> tensor<128x384xf32>
        %4 = linalg.matmul ins(%0, %1 : tensor<128x384xf32>, tensor<384x384xf32>) outs(%3 : tensor<128x384xf32>) -> tensor<128x384xf32>
        flow.dispatch.tensor.store %4, %arg2, offsets = [0, 0], sizes = [128, 384], strides = [1, 1] : tensor<128x384xf32> -> !flow.dispatch.tensor<writeonly:128x384xf32>
        return
      }
    }
  }
  flow.executable private @predict_dispatch_7 {
    flow.dispatch.entry public @predict_dispatch_7 attributes {workgroup_rank = 3 : index}
    builtin.module {
      func.func @predict_dispatch_7(%arg0: !flow.dispatch.tensor<readonly:128x12x32xf32>, %arg1: !flow.dispatch.tensor<readonly:12x32xf32>, %arg2: !flow.dispatch.tensor<writeonly:12x128x32xf32>) {
        %0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0, 0], sizes = [128, 12, 32], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:128x12x32xf32> -> tensor<128x12x32xf32>
        %1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [12, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:12x32xf32> -> tensor<12x32xf32>
        %2 = linalg.init_tensor [12, 128, 32] : tensor<12x128x32xf32>
        %3 = linalg.generic {indexing_maps = [#map6, #map7, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0, %1 : tensor<128x12x32xf32>, tensor<12x32xf32>) outs(%2 : tensor<12x128x32xf32>) {
        ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
          %4 = arith.addf %arg3, %arg4 : f32
          linalg.yield %4 : f32
        } -> tensor<12x128x32xf32>
        flow.dispatch.tensor.store %3, %arg2, offsets = [0, 0, 0], sizes = [12, 128, 32], strides = [1, 1, 1] : tensor<12x128x32xf32> -> !flow.dispatch.tensor<writeonly:12x128x32xf32>
        return
      }
    }
  }

  func.func @predict(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> (!hal.buffer_view, !hal.buffer_view) {
    ...
    %10 = flow.dispatch @predict_dispatch_6::@predict_dispatch_6[%c384, %c128, %c1](%8, %cst_32) : (tensor<128x384xf32>, tensor<384x384xf32>) -> tensor<128x384xf32>
    %11 = flow.tensor.reshape %10 : tensor<128x384xf32> -> tensor<128x12x32xf32>
    %12 = flow.dispatch @predict_dispatch_7::@predict_dispatch_7[%c32, %c128, %c12](%11, %cst_0) : (tensor<128x12x32xf32>, tensor<12x32xf32>) -> tensor<12x128x32xf32>
    ...
  }
}
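
Roughly, the reordered form would look like the hand-written sketch below (not compiler output). %lhs, %rhs, %bias, %fill, and %init are placeholders, the 12x32 bias is assumed to collapse to 384 elements, and the transpose that predict_dispatch_7 also performs is left out for brevity:

%mm = linalg.matmul ins(%lhs, %rhs : tensor<128x384xf32>, tensor<384x384xf32>)
                    outs(%fill : tensor<128x384xf32>) -> tensor<128x384xf32>
%add = linalg.generic {
         indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                          affine_map<(d0, d1) -> (d1)>,
                          affine_map<(d0, d1) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel"]}
       ins(%mm, %bias : tensor<128x384xf32>, tensor<384xf32>)
       outs(%init : tensor<128x384xf32>) {
       ^bb0(%a: f32, %b: f32, %out: f32):
         %s = arith.addf %a, %b : f32
         linalg.yield %s : f32
       } -> tensor<128x384xf32>
// matmul + elementwise can now go into one dispatch; the reshape stays outside as pure metadata.
%r = flow.tensor.reshape %add : tensor<128x384xf32> -> tensor<128x12x32xf32>
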
MaheshRavishankar commented 2 years ago

It's hard to see at this level. The elementwise op fusion is supposed to rewrite reshape -> elementwise op into elementwise op -> reshape. We need to look at the whole model, though, to see why this didn't happen. The IR before elementwise fusion would be the place to start. @hanhanW do you have that handy?

hanhanW commented 2 years ago

Yes, I have the IR before and after elementwise fusion: https://gist.githubusercontent.com/hanhanW/0df90c3751be3df5ce59515c36d3ad79/raw

okkwon commented 2 years ago
[image: graph of the dispatches, with the flow.tensor.reshape (%11) shown as a separate node]
MaheshRavishankar commented 2 years ago

https://github.com/google/iree/blob/a6b96c1706aac6e18493a268636420ee875cca67/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp#L152 is meant to handle this case.

okkwon commented 2 years ago

Looking into it

okkwon commented 2 years ago

The folding didn't happen because the indexing map for the reshaped input has a transpose. The folding has logic that checks whether the dim sequence is preserved, and that check seems to be too conservative.

MaheshRavishankar commented 2 years ago

Ok, thanks for looking into it. I don't think the requirement is conservative. Here is what I think the IR looks like before dispatch region formation:

%11 = tensor.expand_shape %10 [[0], [1, 2]] : tensor<128x384xf32> into tensor<128x12x32xf32>
%12 = linalg.generic {
    indexing_maps = [
        affine_map<(d0, d1, d2) -> (d1, d0, d2)>, 
        affine_map<(d0, d1, d2) -> (d0, d2)>,
        affine_map<(d0, d1, d2) -> (d0, d1, d2)>] ...} ins(%11 : tensor<128x12x32xf32>) ....

One of the things that needs to be done carefully when fusing the reshape with its producer (by collapsing the dimensions of the consumer op) is keeping the indexing maps of the collapsed op as permutations. In this case, based on the indexing map for the operand that is to be fused, the dimensions d0 and d2 of the consumer linalg.generic are to be collapsed. But the output is computed using the indexing map affine_map<(d0, d1, d2) -> (d0, d1, d2)>, which does not have d0 and d2 as consecutive dimensions. In such cases the folded op would have to use mods and divs to recreate the values of d0 and d2, and that interferes with all transformations later on. So it is a deliberate choice not to fold the reshape with the generic op.
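
To make that concrete, here is a hand-written sketch (not compiler output) of what the folded op would have to look like if d0 and d2 were collapsed into a single 384-wide dimension; %bias and %init are placeholders, and note that the output map is no longer a permutation:

%13 = linalg.generic {
        indexing_maps = [
          affine_map<(d0, d1) -> (d1, d0)>,                          // 128x384 input (was the expanded operand)
          affine_map<(d0, d1) -> (d0)>,                              // bias, collapsed to 384
          affine_map<(d0, d1) -> (d0 floordiv 32, d1, d0 mod 32)>],  // output stays 12x128x32
        iterator_types = ["parallel", "parallel"]}
      ins(%10, %bias : tensor<128x384xf32>, tensor<384xf32>)
      outs(%init : tensor<12x128x32xf32>) { ... } -> tensor<12x128x32xf32>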

Is this a pessimization? I think not. It might look as though we haven't fused the generic op with the matmul, but if you look at the original code the ops have been fused with other uses, so in the end the number of dispatches created is the same. If I had to guess, the initial sequence of operations would have been:

%10 = linalg.matmul
%11 = linalg.generic ... ins(%10, ... : tensor<128x384xf32>, ...) // 2D op for bias-add
%12 = tensor.expand_shape %11 [[0], [1, 2]] : tensor<128x384xf32> -> tensor<128x12x32xf32>
%13 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, ...] ...} ... // Transpose operation

If you didn't propagate the tensor.expand_shape above the first linalg.generic, you would have fused that op with the linalg.matmul, but the last linalg.generic would be in its own dispatch. Now the first linalg.generic is fused with the second linalg.generic, so it is still two dispatches.
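
In other words, either placement of the expand_shape results in two dispatches for this subgraph (a rough summary, not actual pass output; the dispatch names are illustrative):

// (a) expand_shape propagated above the bias-add (what happens today):
//     dispatch_6: linalg.matmul
//     dispatch_7: bias-add + transpose fused into one linalg.generic
//
// (b) expand_shape left in place:
//     dispatch_A: linalg.matmul + 2D bias-add linalg.generic
//     dispatch_B: transpose linalg.generic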

(I did go deep into looking at MiniLM, and I vaguely remember something like this, but I had forgotten the details.)

If there is something better that can be done here, I am all for it (but I don't see it). The main constraint, though, is that we need to keep the indexing maps as permutations.

okkwon commented 2 years ago

I am afraid the expand shape is <128x384xf32> -> <128x12x32xf32>. Does this change anything in your comment? The expand_shape operation itself does nothing to the data if it is laid out contiguously, which I believe is the case.

What is the cost of having %11 as a separate node, as shown in the graph? Shouldn't it be a no-op?

MaheshRavishankar commented 2 years ago

I am afraid the expand shape is <128x384xf32> -> <128x12x32xf32>. Does this change anything in your comment? The expand_shape operation itself does nothing to the data if it is laid out contiguously, which I believe is the case.

Sorry, my bad. I fixed the comment, but the core of it holds. The reshape itself just changes metadata, but when you try to move the expand_shape past the generic op, you have to collapse the generic op to 2D, and doing that introduces mods and divs in the indexing maps, which is not allowed/not good for Linalg ops.

What is the cost of having %11 as a separate node, as shown in the graph? Shouldn't it be a no-op?

The node in the graph itself is a no-op.

okkwon commented 2 years ago

Thanks. It seems like a low priority then, unless we have a compelling reason to fuse it into a dispatch region. It may give more opportunity for advanced fusions, such as reshaping a whole dispatch region of purely element-wise ops.

Changing the priority to 2.

hanhanW commented 4 months ago

Closing stale issue.