Closed: okkwon closed this issue 2 years ago.
To clarify, this just means we should be fusing those dispatches together such that they access the source data once across all three of them, correct?
("Pattern matching ML layers" is precisely what we don't want to do, so I just want to make sure the terminology is right here: this is a fusion opportunity at the linalg level, where multiple linalg ops consuming the same data should be fused into a single dispatch.)
Yes, I carefully chose "investigate a fusion opportunity" instead of "do pattern matching" according to our design principles. ;) Thanks Ben!
Cool! This will be a really great improvement!
@okkwon, you mentioned this was dependent on some in-progress work from @MaheshRavishankar. Is that tracked with an issue yet?
@benvanik, yep, we have been looking at ONNX fusions for inspiration and empirical data, but are only considering opportunities that have meaningful impact and are generalizable at the linalg level.
@MaheshRavishankar is working on https://github.com/google/iree/pull/8970. But we also need to fuse unrelated operations that share an input (what some call horizontal fusion) to fuse the reductions.
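For illustration only (the function name, shapes, and the choice of sum/max are made up, and this is not IR from the PR): fusing two unrelated reductions that read the same input means turning them into a single multi-result linalg.generic so the input is only traversed once.
// Illustrative sketch: two independent row-wise reductions (sum and max) over
// the same input expressed as one multi-result linalg.generic.
func.func @fused_reductions_sketch(%input: tensor<6x3xf32>) -> (tensor<6xf32>, tensor<6xf32>) {
  %zero = arith.constant 0.000000e+00 : f32
  %neg_inf = arith.constant 0xFF800000 : f32
  %init = linalg.init_tensor [6] : tensor<6xf32>
  %sum_init = linalg.fill ins(%zero : f32) outs(%init : tensor<6xf32>) -> tensor<6xf32>
  %max_init = linalg.fill ins(%neg_inf : f32) outs(%init : tensor<6xf32>) -> tensor<6xf32>
  // A single traversal of %input updates both accumulators.
  %fused:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%input : tensor<6x3xf32>)
      outs(%sum_init, %max_init : tensor<6xf32>, tensor<6xf32>) {
    ^bb0(%in: f32, %acc_sum: f32, %acc_max: f32):
      %0 = arith.addf %in, %acc_sum : f32
      %1 = arith.maxf %in, %acc_max : f32
      linalg.yield %0, %1 : f32, f32
  } -> (tensor<6xf32>, tensor<6xf32>)
  return %fused#0, %fused#1 : tensor<6xf32>, tensor<6xf32>
}
The multi-result linalg.generic in the softmax IR later in this thread has a similar shape.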
That's the PR; it needs some upstream fixes. I sent a patch out for that already, and have to work on one more.
I am not sure we need "horizontal fusion". The mean is used in the variance, so there is a dependence.
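For reference, with the standard definitions (not something taken from this thread's IR):

$$\mu = \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad \sigma^2 = \frac{1}{N}\sum_{i=1}^{N} \left(x_i - \mu\right)^2$$

The variance reduction consumes the result of the mean reduction, so the two reductions form a producer/consumer chain rather than two unrelated computations.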
It might be related to this (I haven't tried layer norm by itself), but softmax has similar access properties. I tried to look into what would be needed to get all of softmax fused into one dispatch. Here is the IR for softmax after elementwise fusion today:
func.func @compute_simulated(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 0xFF800000 : f32
%cst_1 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<6x3xf32>
%0 = linalg.init_tensor [6] : tensor<6xf32>
%1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<6xf32>) -> tensor<6xf32>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%cst_1 : tensor<6x3xf32>) outs(%1 : tensor<6xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%9 = arith.maxf %arg1, %arg2 : f32
linalg.yield %9 : f32
} -> tensor<6xf32>
%3 = linalg.init_tensor [6, 3] : tensor<6x3xf32>
%4 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%cst_1, %2 : tensor<6x3xf32>, tensor<6xf32>)
outs(%3 : tensor<6x3xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%9 = arith.subf %arg1, %arg2 : f32
%10 = math.exp %9 : f32
linalg.yield %10 : f32
} -> tensor<6x3xf32>
%5 = linalg.fill ins(%cst : f32) outs(%0 : tensor<6xf32>) -> tensor<6xf32>
%6 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%4 : tensor<6x3xf32>) outs(%5 : tensor<6xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%9 = arith.addf %arg1, %arg2 : f32
linalg.yield %9 : f32
} -> tensor<6xf32>
%7 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%4, %6 : tensor<6x3xf32>, tensor<6xf32>)
outs(%3 : tensor<6x3xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%9 = arith.divf %arg1, %arg2 : f32
linalg.yield %9 : f32
} -> tensor<6x3xf32>
%8 = hal.tensor.export %7 : tensor<6x3xf32> -> !hal.buffer_view
return %8 : !hal.buffer_view
}
They are not fused today since %4 has two uses. When #9303 is addressed, the IR would be:
func.func @compute_simulated(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 0xFF800000 : f32
%cst_1 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<6x3xf32>
%0 = linalg.init_tensor [6] : tensor<6xf32>
%1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<6xf32>) -> tensor<6xf32>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%cst_1 : tensor<6x3xf32>) outs(%1 : tensor<6xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%9 = arith.maxf %arg1, %arg2 : f32
linalg.yield %9 : f32
} -> tensor<6xf32>
%3 = linalg.init_tensor [6, 3] : tensor<6x3xf32>
%5 = linalg.fill ins(%cst : f32) outs(%0 : tensor<6xf32>) -> tensor<6xf32>
%6:2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "reduction"]}
ins(%cst_1, %2 : tensor<6x3xf32>, tensor<6xf32>) outs(%5, %3 : tensor<6xf32>, tensor<6x3xf32>) {
^bb0(%arg1: f32, %arg2 : f32, %arg3: f32, %arg4 : f32):
%9 = arith.subf %arg1, %arg2 : f32
%10 = math.exp %9 : f32
%11 = arith.addf %10, %arg3 : f32
linalg.yield %11, %10 : f32, f32
} -> (tensor<6xf32>, tensor<6x3xf32>)
%7 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%6#1, %6#0 : tensor<6x3xf32>, tensor<6xf32>)
outs(%3 : tensor<6x3xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%9 = arith.divf %arg1, %arg2 : f32
linalg.yield %9 : f32
} -> tensor<6x3xf32>
%8 = hal.tensor.export %7 : tensor<6x3xf32> -> !hal.buffer_view
return %8 : !hal.buffer_view
}
(Note that this fails the linalg op verifier because the verifier is currently too constrained; the tip of this branch fixes the verifier issue.)
Currently dispatch region formation does not fuse a LinalgOp with the producers of its ins operands. This is a simple choice to avoid fusing GEMMs with the producers of their operands (which is undesirable). But it is valid to fuse a generic op with its producer when the parallelism in the producer and the consumer is the same. There might be a more robust way of doing this, but one approach is to check the indexing map used to access the fused value in the producer and in the consumer: if they match, the consumer can be fused with the producer. This is implemented in the commit before the tip of this branch. With this, the first two generic ops can be fused. I think this change makes sense in general and can be made part of the default heuristics in dispatch region formation.
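As a hand-written sketch of what the check compares (the function and value names below are made up; this is not IR from the branch): the producer writes the shared value %m with affine_map<(d0, d1) -> (d0)>, and the consumer reads %m with the same map, so placing both in one dispatch loses no parallelism. A linalg.matmul, by contrast, reads its LHS with (d0, d1, d2) -> (d0, d2), which does not match an elementwise producer's (d0, d1) -> (d0, d1), so that pairing stays unfused under this rule.
func.func @matching_maps_sketch(%x: tensor<6x3xf32>) -> tensor<6x3xf32> {
  %neg_inf = arith.constant 0xFF800000 : f32
  %init0 = linalg.init_tensor [6] : tensor<6xf32>
  %fill = linalg.fill ins(%neg_inf : f32) outs(%init0 : tensor<6xf32>) -> tensor<6xf32>
  // Producer: writes %m using affine_map<(d0, d1) -> (d0)>.
  %m = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%x : tensor<6x3xf32>) outs(%fill : tensor<6xf32>) {
    ^bb0(%a: f32, %b: f32):
      %0 = arith.maxf %a, %b : f32
      linalg.yield %0 : f32
  } -> tensor<6xf32>
  %init1 = linalg.init_tensor [6, 3] : tensor<6x3xf32>
  // Consumer: reads %m using the same affine_map<(d0, d1) -> (d0)>.
  %e = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%x, %m : tensor<6x3xf32>, tensor<6xf32>) outs(%init1 : tensor<6x3xf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32):
      %1 = arith.subf %a, %b : f32
      %2 = math.exp %1 : f32
      linalg.yield %2 : f32
  } -> tensor<6x3xf32>
  return %e : tensor<6x3xf32>
}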
Now we are down to two dispatches for the softmax operation.
Current dispatch region heuristics only fuse a producer with its consumer if the amount of parallelism in the consumer and the producer is the same. To reduce the number of dispatches, it might be valid to pessimize the parallelism in the consumer to match the producer (as would be needed here). One way to do this is to check that the producer and the consumer use the same expressions to access the value (you can't check direct equality of the maps here, since the iteration spaces of the producer and the consumer are different in cases like GEMM + bias-add fusion; see the sketch below). This is implemented at the tip of this branch. This, along with some fixes to handle multi-result root operations, gets the entire softmax into a single dispatch. I think this commit needs to be experimented with before landing. (Note that for this step I also used -iree-flow-enable-multi-result-dispatches and -iree-flow-ensure-inplacable-consumer=false to sidestep some unrelated issues, but all of this needs to be fleshed out.)
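To make the GEMM + bias-add case concrete (again a hand-written sketch with made-up names and shapes, not IR from the branch): inside linalg.matmul the result is accessed with affine_map<(d0, d1, d2) -> (d0, d1)> over a 3-D iteration space, while the bias-add consumer reads that result with affine_map<(d0, d1) -> (d0, d1)> over a 2-D space. The two maps are not equal as maps, but the expressions used to index the value, (d0, d1), are the same, and that is the property the relaxed check compares.
func.func @gemm_bias_add_sketch(%lhs: tensor<4x8xf32>, %rhs: tensor<8x16xf32>, %bias: tensor<16xf32>) -> tensor<4x16xf32> {
  %zero = arith.constant 0.000000e+00 : f32
  %init = linalg.init_tensor [4, 16] : tensor<4x16xf32>
  %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<4x16xf32>) -> tensor<4x16xf32>
  // Producer: the matmul result is written with (d0, d1, d2) -> (d0, d1).
  %gemm = linalg.matmul ins(%lhs, %rhs : tensor<4x8xf32>, tensor<8x16xf32>)
      outs(%fill : tensor<4x16xf32>) -> tensor<4x16xf32>
  // Consumer: reads the matmul result with (d0, d1) -> (d0, d1).
  %biased = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%gemm, %bias : tensor<4x16xf32>, tensor<16xf32>)
      outs(%init : tensor<4x16xf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32):
      %0 = arith.addf %a, %b : f32
      linalg.yield %0 : f32
  } -> tensor<4x16xf32>
  return %biased : tensor<4x16xf32>
}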
With these changes, softmax gets fused into a single dispatch:
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
module {
func.func @compute_simulated(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c6 = arith.constant 6 : index
%c3 = arith.constant 3 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0xFF800000 : f32
%cst_0 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<6x3xf32>
%0 = flow.tensor.splat %cst : tensor<6xf32>
%1 = flow.dispatch.workgroups[%c3, %c6, %c1](%cst_0, %0) : (tensor<6x3xf32>, tensor<6xf32>) -> tensor<6x3xf32> =
(%arg1: !flow.dispatch.tensor<readonly:6x3xf32>, %arg2: !flow.dispatch.tensor<readonly:6xf32>, %arg3: !flow.dispatch.tensor<writeonly:6x3xf32>) {
%cst_1 = arith.constant 0.000000e+00 : f32
%3 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [6, 3], strides = [1, 1] : !flow.dispatch.tensor<readonly:6x3xf32> -> tensor<6x3xf32>
%4 = flow.dispatch.tensor.load %arg2, offsets = [0], sizes = [6], strides = [1] : !flow.dispatch.tensor<readonly:6xf32> -> tensor<6xf32>
%5 = linalg.init_tensor [6, 3] : tensor<6x3xf32>
%6 = linalg.init_tensor [6] : tensor<6xf32>
%7 = linalg.fill ins(%cst_1 : f32) outs(%6 : tensor<6xf32>) -> tensor<6xf32>
%8 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%3 : tensor<6x3xf32>) outs(%4 : tensor<6xf32>) {
^bb0(%arg4: f32, %arg5: f32):
%11 = arith.maxf %arg4, %arg5 : f32
linalg.yield %11 : f32
} -> tensor<6xf32>
%9:2 = linalg.generic {indexing_maps = [#map0, #map1, #map1, #map0], iterator_types = ["parallel", "reduction"]} ins(%3, %8 : tensor<6x3xf32>, tensor<6xf32>) outs(%7, %5 : tensor<6xf32>, tensor<6x3xf32>) {
^bb0(%arg4: f32, %arg5: f32, %arg6: f32, %arg7: f32):
%11 = arith.subf %arg4, %arg5 : f32
%12 = math.exp %11 : f32
%13 = arith.addf %12, %arg6 : f32
linalg.yield %13, %12 : f32, f32
} -> (tensor<6xf32>, tensor<6x3xf32>)
%10 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%9#1, %9#0 : tensor<6x3xf32>, tensor<6xf32>) outs(%5 : tensor<6x3xf32>) {
^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
%11 = arith.divf %arg4, %arg5 : f32
linalg.yield %11 : f32
} -> tensor<6x3xf32>
flow.dispatch.tensor.store %10, %arg3, offsets = [0, 0], sizes = [6, 3], strides = [1, 1] : tensor<6x3xf32> -> !flow.dispatch.tensor<writeonly:6x3xf32>
flow.return
}
%2 = hal.tensor.export %1 : tensor<6x3xf32> -> !hal.buffer_view
return %2 : !hal.buffer_view
}
}
There are some details to be worked out, but this was just a straw-man evaluation of the changes that would enable fusing softmax into a single dispatch. The fallout from each of these changes needs to be evaluated before landing, and we need to ensure that the backends can handle such dispatch regions. But this was mainly to check how far we are from having dispatch region formation automatically fuse the softmax operations into a single dispatch (it took me just a couple of hours to hash this out once I had a unit test for softmax). I suspect the same things will be needed for layer norm as well (I will have to get an IR unit test for layer norm to prove this, but I am fairly confident of it).
I am not planning to push on this in the next few weeks. I am just putting the patches up in case someone is interested in pushing this forward; if not, I will probably circle back to it in a few weeks.
FYI @ThomasRaoux.
Really great breakdown, Mahesh! These kinds of patterns show up a lot, and each of those steps is useful across a large number of situations (especially ones open-coded in frameworks like torch instead of using a fixed op set). Those looking at fusion opportunities should see if what they're doing can be decomposed similarly; I suspect there will be overlap, and these nice, isolated changes will take care of things in a scalable way.
The work is being done in two parts, which is a little different from the approach above.
The first part is being done at https://github.com/iree-org/iree/issues/9523.
Unlike layer normalization, softmax does not have a dependency on the original input.
(Dataflow diagrams omitted: Softmax and LayerNorm.)
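One way to read this, using the standard formulas (my own framing, not taken from the issue):

$$\text{softmax:}\quad e_i = \exp\!\left(x_i - \max_j x_j\right), \qquad y_i = \frac{e_i}{\sum_j e_j}$$
$$\text{layer norm:}\quad y_i = \gamma\,\frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

The final softmax stage (the division) reads only the intermediate e, while the final layer-norm stage reads the original x again, which is presumably the dependency difference being referred to.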
A PR for the first part is under review: https://github.com/iree-org/iree/pull/9693.
There is an ongoing discussion about how to solve the general fusion problem; no ETA yet.
From today's meeting: assigning to @MaheshRavishankar
Adding @okkwon to this issue in case there have been any updates while Mahesh is out.
This is on my plate. The resolution is about 2-3 weeks out.
From meeting 8/25: Draft PR in place, still looking into backend issues
This is implemented at HEAD but is guarded by the flag --iree-flow-enable-aggressive-fusion. It is disabled by default for now due to a backend issue, but the work for this bug itself is done.
ONNX pattern-matches multiple layer normalization patterns (with/without bias and skip data) and uses a single CUDA kernel for each. On the IREE side, we currently have three dispatches: mean (reduction), variance (reduction), and the normalization itself (element-wise operations).
ONNX got a ~7% perf improvement by pattern matching 23 layer normalization instances from MiniLM.