iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

Investigate a fusion opportunity on Layer Normalization #9139

Closed: okkwon closed this issue 2 years ago

okkwon commented 2 years ago

ONNX pattern-matches multiple layer normalization patterns (with/without bias and skip data) and uses a single CUDA kernel for them. On the IREE side, we have three dispatches: mean (reduction), variance (reduction), and normalization (element-wise operations).
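
For reference, here is a rough NumPy sketch of that three-dispatch structure (an illustration only, not IREE output; gamma, beta, and eps are the usual layer-norm parameters and are assumptions here, not taken from the model):

import numpy as np

def layer_norm_rows(x, gamma, beta, eps=1e-5):
    # dispatch 1: reduction producing the per-row mean
    mean = x.mean(axis=-1, keepdims=True)
    # dispatch 2: reduction producing the per-row variance (re-reads x)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    # dispatch 3: element-wise normalization (re-reads x again)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

Every stage reads x, so fusing them would let the input be read once instead of three times.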

ONNX got a ~7% perf improvement by pattern-matching 23 layer normalization instances in miniLM.

benvanik commented 2 years ago

To clarify, this just means we should be fusing those dispatches together such that they are accessing the source data once for all 3 reductions, correct?

("pattern matching ML layers" is precisely what we don't want to do, so just want to make sure the terminology is right here that it's a fusion opportunity at the linalg level, where multiple linalg ops consuming the same data should be fused into a single dispatch)

okkwon commented 2 years ago

Yes, I carefully chose "investigate a fusion opportunity" instead of "do pattern matching" according to our design principles. ;) Thanks Ben!

benvanik commented 2 years ago

Cool! This will be a really great improvement!

julianwa commented 2 years ago

@okkwon, you mentioned this was dependent on some in-progress work from @MaheshRavishankar. Is that tracked with an issue yet?

@benvanik, yep, we have been looking at ONNX fusions for inspiration and empirical data, but we are only considering opportunities that have a meaningful impact and are generalizable at the linalg level.

okkwon commented 2 years ago

@MaheshRavishankar is working on https://github.com/google/iree/pull/8970. But we also need to fuse operations that do not have a producer-consumer relationship (what some call horizontal fusion) in order to fuse the reductions.

MaheshRavishankar commented 2 years ago

That's the PR; it needs some upstream fixes. I already sent a patch out for that, and I have to work on one more.

I am not sure we need "horizontal fusion". The mean is used in the variance computation, so there is a dependence between the two reductions.

MaheshRavishankar commented 2 years ago

It might be related to this (I haven't tried layer norm by itself), but softmax has similar access properties. I looked into what would be needed to get all of softmax fused into one dispatch. Here is the IR for softmax after elementwise fusion today:

func.func @compute_simulated(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 0xFF800000 : f32
  %cst_1 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<6x3xf32>
  %0 = linalg.init_tensor [6] : tensor<6xf32>
  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<6xf32>) -> tensor<6xf32>
  %2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%cst_1 : tensor<6x3xf32>) outs(%1 : tensor<6xf32>) {
    ^bb0(%arg1: f32, %arg2: f32):
      %9 = arith.maxf %arg1, %arg2 : f32
      linalg.yield %9 : f32
    } -> tensor<6xf32>
  %3 = linalg.init_tensor [6, 3] : tensor<6x3xf32>
  %4 = linalg.generic {
       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
       ins(%cst_1, %2 : tensor<6x3xf32>, tensor<6xf32>)
       outs(%3 : tensor<6x3xf32>) {
    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
      %9 = arith.subf %arg1, %arg2 : f32
      %10 = math.exp %9 : f32
      linalg.yield %10 : f32
    } -> tensor<6x3xf32>
  %5 = linalg.fill ins(%cst : f32) outs(%0 : tensor<6xf32>) -> tensor<6xf32>
  %6 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%4 : tensor<6x3xf32>) outs(%5 : tensor<6xf32>) {
    ^bb0(%arg1: f32, %arg2: f32):
      %9 = arith.addf %arg1, %arg2 : f32
      linalg.yield %9 : f32
    } -> tensor<6xf32>
  %7 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%4, %6 : tensor<6x3xf32>, tensor<6xf32>)
      outs(%3 : tensor<6x3xf32>) {
    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
      %9 = arith.divf %arg1, %arg2 : f32
      linalg.yield %9 : f32
    } -> tensor<6x3xf32>
  %8 = hal.tensor.export %7 : tensor<6x3xf32> -> !hal.buffer_view
  return %8 : !hal.buffer_view
}
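
In NumPy terms (a rough illustration of the IR above, not compiler output), the four stages compute:

import numpy as np

def softmax_rows(x):                     # x : 6x3
    m = x.max(axis=1, keepdims=True)     # %2: max reduction        -> 6xf32
    e = np.exp(x - m)                    # %4: subtract + exp       -> 6x3xf32
    s = e.sum(axis=1, keepdims=True)     # %6: sum reduction        -> 6xf32
    return e / s                         # %7: element-wise divide  -> 6x3xf32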

Step 1: Fuse the middle two generic ops.

They are not fused today since %4 has two uses. When #9303 is addressed, the IR would be:

func.func @compute_simulated(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 0xFF800000 : f32
  %cst_1 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<6x3xf32>
  %0 = linalg.init_tensor [6] : tensor<6xf32>
  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<6xf32>) -> tensor<6xf32>
  %2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%cst_1 : tensor<6x3xf32>) outs(%1 : tensor<6xf32>) {
    ^bb0(%arg1: f32, %arg2: f32):
      %9 = arith.maxf %arg1, %arg2 : f32
      linalg.yield %9 : f32
    } -> tensor<6xf32>
  %3 = linalg.init_tensor [6, 3] : tensor<6x3xf32>
  %5 = linalg.fill ins(%cst : f32) outs(%0 : tensor<6xf32>) -> tensor<6xf32>
  %6:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%cst_1, %2 : tensor<6x3xf32>, tensor<6xf32>) outs(%5, %3 : tensor<6xf32>, tensor<6x3xf32>) {
    ^bb0(%arg1: f32, %arg2 : f32, %arg3: f32, %arg4 : f32):
      %9 = arith.subf %arg1, %arg2 : f32
      %10 = math.exp %9 : f32
      %11 = arith.addf %10, %arg3 : f32
      linalg.yield %11, %10 : f32, f32
    } -> (tensor<6xf32>, tensor<6x3xf32>)
  %7 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%6#1, %6#0 : tensor<6x3xf32>, tensor<6xf32>)
      outs(%3 : tensor<6x3xf32>) {
    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
      %9 = arith.divf %arg1, %arg2 : f32
      linalg.yield %9 : f32
    } -> tensor<6x3xf32>
  %8 = hal.tensor.export %7 : tensor<6x3xf32> -> !hal.buffer_view
  return %8 : !hal.buffer_view
}

(Note that this IR fails the LinalgOp verifier since the verifier is currently too constrained. The tip of this branch fixes the verifier issue.)
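
In scalar terms, the fused multi-result generic above does the following in a single pass over each row (a minimal sketch, not generated code):

import math

def fused_exp_and_sum(row, row_max):
    exps, total = [], 0.0
    for v in row:
        e = math.exp(v - row_max)   # subf + exp
        total += e                  # addf into the reduction result
        exps.append(e)
    return total, exps              # two results, like the 6xf32 / 6x3xf32 pair above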

Step 2: Add fusion of the root operation with its producer if the parallel iteration spaces match.

Currently, dispatch region formation does not fuse a LinalgOp with the producers of its ins operands. This is a simple choice made to avoid fusing GEMMs with the producers of their operands (which is undesirable). But it is valid to fuse a generic op with its producer when the parallelism in the producer and the consumer is the same. There might be a more robust way of doing this, but one approach is to just check the indexing map used to access the fused value in the producer and in the consumer: if they match, the consumer can be fused with the producer. This is implemented in the commit before the tip of this branch. With this, the first two generic ops can be fused. I think this change makes sense in general and can be made part of the default heuristics in dispatch region formation.
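
As hypothetical pseudocode (not IREE's actual API), the check described above amounts to:

def can_fuse_with_producer(producer_result_map: str, consumer_operand_map: str) -> bool:
    # Fuse a generic op with the producer of one of its ins operands only when
    # the shared value is accessed through identical indexing maps on both
    # sides, i.e. the parallel iteration structure matches.
    return producer_result_map == consumer_operand_map

# In the softmax IR above, the max reduction writes its result with
# (d0, d1) -> (d0) and the exp/sum generic reads it with (d0, d1) -> (d0),
# so the two ops qualify for fusion.
assert can_fuse_with_producer("(d0, d1) -> (d0)", "(d0, d1) -> (d0)")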

Now we are down to two dispatches for the softmax operation.

Step 3: Fuse the root with its consumer even if it means pessimizing the parallelism in the consumer.

The current dispatch region heuristics fuse a producer with its consumer only if the amount of parallelism in the consumer and the producer is the same. To reduce the number of dispatches, it might be valid to pessimize the parallelism in the consumer to match the producer (as would be needed here). One way to do this is to check that the producer and the consumer use the same expressions to access the value (you can't check direct equality of the indexing maps here since the iteration spaces of the producer and the consumer are different in cases like gemm + bias-add fusion). This is implemented at the tip of this branch. This, along with some fixes to handle multi-result root operations, gets the entire softmax into a single dispatch. I think this commit needs to be experimented with before landing. (Note that for this step I also used the -iree-flow-enable-multi-result-dispatches and -iree-flow-ensure-inplacable-consumer=false flags to sidestep some unrelated issues, but all of this needs to be fleshed out.)

With these changes, the softmax gets fused into a single dispatch:

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
module {
  func.func @compute_simulated(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    %c6 = arith.constant 6 : index
    %c3 = arith.constant 3 : index
    %c1 = arith.constant 1 : index
    %cst = arith.constant 0xFF800000 : f32
    %cst_0 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<6x3xf32>
    %0 = flow.tensor.splat %cst : tensor<6xf32>
    %1 = flow.dispatch.workgroups[%c3, %c6, %c1](%cst_0, %0) : (tensor<6x3xf32>, tensor<6xf32>) -> tensor<6x3xf32> =
        (%arg1: !flow.dispatch.tensor<readonly:6x3xf32>, %arg2: !flow.dispatch.tensor<readonly:6xf32>, %arg3: !flow.dispatch.tensor<writeonly:6x3xf32>) {
      %cst_1 = arith.constant 0.000000e+00 : f32
      %3 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [6, 3], strides = [1, 1] : !flow.dispatch.tensor<readonly:6x3xf32> -> tensor<6x3xf32>
      %4 = flow.dispatch.tensor.load %arg2, offsets = [0], sizes = [6], strides = [1] : !flow.dispatch.tensor<readonly:6xf32> -> tensor<6xf32>
      %5 = linalg.init_tensor [6, 3] : tensor<6x3xf32>
      %6 = linalg.init_tensor [6] : tensor<6xf32>
      %7 = linalg.fill ins(%cst_1 : f32) outs(%6 : tensor<6xf32>) -> tensor<6xf32>
      %8 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%3 : tensor<6x3xf32>) outs(%4 : tensor<6xf32>) {
      ^bb0(%arg4: f32, %arg5: f32):
        %11 = arith.maxf %arg4, %arg5 : f32
        linalg.yield %11 : f32
      } -> tensor<6xf32>
      %9:2 = linalg.generic {indexing_maps = [#map0, #map1, #map1, #map0], iterator_types = ["parallel", "reduction"]} ins(%3, %8 : tensor<6x3xf32>, tensor<6xf32>) outs(%7, %5 : tensor<6xf32>, tensor<6x3xf32>) {
      ^bb0(%arg4: f32, %arg5: f32, %arg6: f32, %arg7: f32):
        %11 = arith.subf %arg4, %arg5 : f32
        %12 = math.exp %11 : f32
        %13 = arith.addf %12, %arg6 : f32
        linalg.yield %13, %12 : f32, f32
      } -> (tensor<6xf32>, tensor<6x3xf32>)
      %10 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%9#1, %9#0 : tensor<6x3xf32>, tensor<6xf32>) outs(%5 : tensor<6x3xf32>) {
      ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
        %11 = arith.divf %arg4, %arg5 : f32
        linalg.yield %11 : f32
      } -> tensor<6x3xf32>
      flow.dispatch.tensor.store %10, %arg3, offsets = [0, 0], sizes = [6, 3], strides = [1, 1] : tensor<6x3xf32> -> !flow.dispatch.tensor<writeonly:6x3xf32>
      flow.return
    }
    %2 = hal.tensor.export %1 : tensor<6x3xf32> -> !hal.buffer_view
    return %2 : !hal.buffer_view
  }
}

There are some details to be worked out, but this was just a straw-man evaluation of the changes that would enable fusing softmax into a single dispatch. The fallout from each of these changes needs to be evaluated and the changes landed, as well as ensuring that the backends can handle such dispatch regions. But this was mainly to check how far we are from having dispatch region formation automatically allow fusing the operations from softmax into a single dispatch (it took me just a couple of hours to hash this out once I had a unit test for softmax). I suspect the same things will be needed for layer norm as well (I will have to get an IR unit test for layer norm to prove this, but I am fairly confident of it).

I am not planning to push on this in the next few weeks; I am just putting the patches up in case someone is interested in pushing this forward. If not, I will probably circle back to it in a few weeks.

FYI @ThomasRaoux .

benvanik commented 2 years ago

Really great breakdown, Mahesh! These kinds of patterns show up a lot, and each of those steps is useful across a large number of situations (especially ones open-coded in frameworks like torch instead of using a fixed op set). Those looking at fusion opportunities should see if what they're doing can be decomposed similarly - I suspect there will be overlap, and these nice isolated changes will take care of things in a scalable way.

okkwon commented 2 years ago

Image

The work is being done in two parts, which is a little different from the approach above.

  1. Fuse reduction + elementwise: there are two groups of fusions, (dispatch 10 & 11) and (dispatch 12 & 13).
  2. Fuse the two groups together: there are two uses flowing from the first group into the second group, but the second group fully consumes both uses.

okkwon commented 2 years ago

The first part is being done at https://github.com/iree-org/iree/issues/9523.

okkwon commented 2 years ago

Unlike layer normalization, softmax does not have a dependency on the original input.

Softmax

Image

LayerNorm

Image
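
Spelling out the dataflow difference in NumPy (based on the math rather than the attached diagrams):

import numpy as np

x = np.random.rand(6, 3).astype(np.float32)

# Softmax: the final divide consumes only the intermediates e and s.
e = np.exp(x - x.max(axis=1, keepdims=True))
s = e.sum(axis=1, keepdims=True)
softmax_out = e / s                                  # no direct use of x here

# Layer norm: the final normalization re-reads the original input x.
mean = x.mean(axis=1, keepdims=True)
var = ((x - mean) ** 2).mean(axis=1, keepdims=True)
layer_norm_out = (x - mean) / np.sqrt(var + 1e-5)    # x appears again in the last stage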

okkwon commented 2 years ago

A PR for the first part is under review: https://github.com/iree-org/iree/pull/9693

okkwon commented 2 years ago

There is an ongoing discussion about how to solve the general fusion problems. No ETA yet.

allieculp commented 2 years ago

From today's meeting: assigning to @MaheshRavishankar

allieculp commented 2 years ago

Adding @okkwon to this issue in case there have been any updates while Mahesh is out.

MaheshRavishankar commented 2 years ago

This is on my plate. The resolution is about 2-3 weeks out.

allieculp commented 2 years ago

From meeting 8/25: Draft PR in place, still looking into backend issues

MaheshRavishankar commented 2 years ago

This is implemented at HEAD but is guarded by the --iree-flow-enable-aggressive-fusion flag. It is disabled by default for now due to a backend issue. So the work for this bug itself is done.