iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Missing fusion opportunities in BERT attention layer #12214

Open silvasean opened 1 year ago

silvasean commented 1 year ago

What happened?

Here is the IR for a typical attention layer of the bert-base-uncased HuggingFace model with 1 sentence and sequence length 128: gist (see the bottom for the version with full weights). Here is the program that generates it, if that is of interest.

There are some missing fusion opportunities, marked in the dispatch graph (generated with --iree-flow-dump-dispatch-graph for the IR). For this particular workload, getting these fusions should give about a 33% speedup (the unfused code occupies 25% of the execution time). The main issue is that fusions into the outputs of the batch_matmul ops are not happening. There are missing bias-add fusions and an output transpose/reshape too.
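The 33% figure follows from treating the unfused dispatches as time that disappears entirely once fused. A minimal sketch of that arithmetic, assuming the fused work adds nothing to the consuming dispatches (so it is an upper bound, not a prediction); the function name is illustrative:

```python
def speedup_if_removed(fraction_removed: float) -> float:
    """Overall speedup when `fraction_removed` of the runtime is eliminated outright.

    Assumes the fused-away work costs nothing inside the consumer dispatch,
    so the result is an upper bound.
    """
    if not 0.0 <= fraction_removed < 1.0:
        raise ValueError("fraction_removed must be in [0, 1)")
    return 1.0 / (1.0 - fraction_removed)


# The unfused code is 25% of execution time: 1 / 0.75 is about 1.33, i.e. ~33% faster.
print(f"{speedup_if_removed(0.25):.2f}x")
```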

Steps to reproduce your issue

$ iree-compile --iree-hal-target-backends=cuda --iree-hal-cuda-llvm-target-arch=sm_80 /tmp/attention_layer.mlir

What component(s) does this issue relate to?

Compiler

Version information

python -m pip list| grep -E "iree|torch"
iree-compiler      20230213.429
iree-runtime       20230213.429
pytorch-triton     2.0.0+0d7e753227
torch              2.0.0.dev20230210+cu117
torch-mlir         20230211.746

Additional context

No response

ThomasRaoux commented 1 year ago

Interestingly enough, the original IR looks simpler than the one we get at region formation time:

    %11 = linalg.batch_matmul ins(%3, %10 : tensor<1x128x768xf32>, tensor<1x768x768xf32>) outs(%6 : tensor<1x128x768xf32>) -> tensor<1x128x768xf32>
    %12 = linalg.generic {indexing_maps = [#map2, #map5, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%11, %cst_3 : tensor<1x128x768xf32>, tensor<768xf32>) outs(%2 : tensor<1x128x768xf32>) {
    ^bb0(%in: f32, %in_16: f32, %out: f32):
      %47 = arith.addf %in, %in_16 : f32
      linalg.yield %47 : f32
    } -> tensor<1x128x768xf32>
    %expanded = tensor.expand_shape %12 [[0], [1], [2, 3]] : tensor<1x128x768xf32> into tensor<1x128x12x64xf32>
    %13 = tensor.empty() : tensor<1x12x128x64xf32>
    %14 = linalg.generic {indexing_maps = [#map6, #map7], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded : tensor<1x128x12x64xf32>) outs(%13 : tensor<1x12x128x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x12x128x64xf32>
    %15 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst_0 : tensor<768x768xf32>) outs(%0 : tensor<768x768xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<768x768xf32>
    %16 = linalg.generic {indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15 : tensor<768x768xf32>) outs(%4 : tensor<1x768x768xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x768x768xf32>
    %17 = linalg.batch_matmul ins(%3, %16 : tensor<1x128x768xf32>, tensor<1x768x768xf32>) outs(%6 : tensor<1x128x768xf32>) -> tensor<1x128x768xf32>

Here it is very clear that we can fuse the batch_matmul (%11) with the bias add (%12). Then the transpose could be fused into the following batch_matmul to create a batch_matmul_transpose_b. Right now, in FusionOfTensorOpsPass, we reorder the expand_shape and the linalg.generic because we think we can't fuse the transpose with the following matmul. Once the expand_shape and the generic have been reordered, there are no more opportunities to fuse.

One simple solution would be to opportunistically fuse the transpose of the RHS operand into the matmul. On GPU this tends to be a slightly better layout anyway, and I suspect that is the case on all platforms.

I had filed an issue about this when we were looking at BERT performance. We got stuck at the time because we weren't sure this was the right direction, but I think we should revisit it so that in such cases we fuse the transpose of B into the matmul. I added a named op, matmul_transpose_b, to support exactly this kind of case.
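To make the proposal concrete, here is a plain-Python sketch (not IREE code; all function names are hypothetical) of what a single fused dispatch would compute: a matmul that reads its RHS in transposed form, i.e. the `matmul_transpose_b` contraction C[m][n] = sum_k A[m][k] * Bt[n][k] as I understand that named op, with the bias add applied as an epilogue. The unfused pipeline (materialized transpose, matmul, separate bias add) is included for comparison.

```python
import random


def transpose(x):
    """Materialize x^T, like a standalone transpose dispatch."""
    return [list(col) for col in zip(*x)]


def matmul(a, b):
    """C[m][n] = sum_k A[m][k] * B[k][n]."""
    return [[sum(a[m][k] * b[k][n] for k in range(len(b))) for n in range(len(b[0]))]
            for m in range(len(a))]


def add_bias(c, bias):
    """Broadcast a length-N bias over the rows of C (the separate bias-add dispatch)."""
    return [[v + bias[n] for n, v in enumerate(row)] for row in c]


def fused_matmul_transpose_b_bias(a, bt, bias):
    """One loop nest: read Bt[n][k] directly (K innermost) and add the bias at the end."""
    return [[sum(a[m][k] * bt[n][k] for k in range(len(a[m]))) + bias[n]
             for n in range(len(bt))]
            for m in range(len(a))]


random.seed(0)
M, K, N = 3, 4, 2
a = [[random.randint(-5, 5) for _ in range(K)] for _ in range(M)]
bt = [[random.randint(-5, 5) for _ in range(K)] for _ in range(N)]  # stored N x K
bias = [random.randint(-5, 5) for _ in range(N)]

unfused = add_bias(matmul(a, transpose(bt)), bias)
fused = fused_matmul_transpose_b_bias(a, bt, bias)
assert fused == unfused
```

Integer inputs keep the comparison exact; with floats the two forms would only agree up to rounding.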

ThomasRaoux commented 1 year ago

Actually, looking at the IR in more detail, it is not really a matmul_transpose_b, as the batch and N dimensions are also swapped for B. That said, since the K dimension is still the innermost, we could probably support such a transposition in the matmul without significant performance degradation. This means we need to generate a linalg.generic op and make sure the rest of the fusions handle it well (and don't fuse producer ops).

The alternative would be to fuse the expand_shape into the producer matmul dispatch region.

  %5 = linalg.matmul ins(%collapsed, %cst_0 : tensor<128x768xf32>, tensor<768x768xf32>) outs(%2 : tensor<128x768xf32>) -> tensor<128x768xf32>
  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5, %cst_5 : tensor<128x768xf32>, tensor<768xf32>) outs(%1 : tensor<128x768xf32>) {
  ^bb0(%in: f32, %in_11: f32, %out: f32):
    %24 = arith.addf %in, %in_11 : f32
    linalg.yield %24 : f32
  } -> tensor<128x768xf32>
  %expanded = tensor.expand_shape %6 [[0], [1, 2]] : tensor<128x768xf32> into tensor<128x12x64xf32>
 ...
  %12 = tensor.empty() : tensor<12x64x128xf32>
  %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded : tensor<128x12x64xf32>) outs(%12 : tensor<12x64x128xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<12x64x128xf32>
  %14 = tensor.empty() : tensor<12x128x128xf32>
  %15 = linalg.fill ins(%cst_3 : f32) outs(%14 : tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
  %16 = linalg.batch_matmul ins(%11, %13 : tensor<12x128x64xf32>, tensor<12x64x128xf32>) outs(%15 : tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
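In the IR above, %13 reads %expanded (sequence x heads x head_dim) and writes out[d0][d2][d1] = in[d1][d0][d2], i.e. B[b][k][n] = E[n][b][k]. The following plain-Python sketch, with small stand-in sizes (2 heads, sequence length 4, head dim 3 instead of 12, 128, 64) and hypothetical names, shows that the consuming batch_matmul can read E through that permuted index instead of materializing %13. Note that k remains the innermost index into E, which is the property the comment above relies on.

```python
import random


def transpose_rhs(e):
    """Materialize B[b][k][n] = E[n][b][k], the permutation %13 performs."""
    n_dim, b_dim, k_dim = len(e), len(e[0]), len(e[0][0])
    return [[[e[n][b][k] for n in range(n_dim)] for k in range(k_dim)]
            for b in range(b_dim)]


def batch_matmul(a, b):
    """C[b][m][n] = sum_k A[b][m][k] * B[b][k][n] on a materialized RHS."""
    return [[[sum(ab[m][k] * bb[k][n] for k in range(len(bb))) for n in range(len(bb[0]))]
             for m in range(len(ab))]
            for ab, bb in zip(a, b)]


def batch_matmul_permuted_rhs(a, e):
    """Same contraction, reading E[n][b][k] directly; k is still the innermost index."""
    return [[[sum(a[b][m][k] * e[n][b][k] for k in range(len(a[b][m])))
              for n in range(len(e))]
             for m in range(len(a[b]))]
            for b in range(len(a))]


random.seed(0)
HEADS, SEQ, HEAD_DIM = 2, 4, 3  # stand-ins for 12, 128, 64
a = [[[random.randint(-3, 3) for _ in range(HEAD_DIM)] for _ in range(SEQ)]
     for _ in range(HEADS)]  # HEADS x SEQ x HEAD_DIM, like %11
e = [[[random.randint(-3, 3) for _ in range(HEAD_DIM)] for _ in range(HEADS)]
     for _ in range(SEQ)]    # SEQ x HEADS x HEAD_DIM, like %expanded
assert batch_matmul(a, transpose_rhs(e)) == batch_matmul_permuted_rhs(a, e)
```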

MaheshRavishankar commented 1 year ago

Sorry, I haven't had time to look at it, but we are rehashing old things again... I can take a look at it, but trying to fuse the transpose with the gemm to get a gemm transpose is not what I think is the first-order solution. If you want the gemm to be the transpose variant, then you should do that and then cancel out the transposes... I have to look at this case in more detail.

ThomasRaoux commented 1 year ago

> Sorry, I haven't had time to look at it, but we are rehashing old things again... I can take a look at it, but trying to fuse the transpose with the gemm to get a gemm transpose is not what I think is the first-order solution. If you want the gemm to be the transpose variant, then you should do that and then cancel out the transposes... I have to look at this case in more detail.

Why? If both versions of gemm have similar performance, I don’t see why we wouldn’t want to fuse those opportunistically.

If we have another practical solution, that’s fine with me too, but this is something we want to fix.

ThomasRaoux commented 1 year ago

Looking at the Inductor trace, for instance, they use a mix of matmul and transposed matmul. We might have a better solution, but I don't see why we would not consider this.

MaheshRavishankar commented 1 year ago

As is always the case with such programs, it is useful to take a full view of the program and not zoom into small portions before trying to devise a strategy. The issue here starts with:

#map2 = affine_map<(d0, d1, d2) -> (0, d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map4 = affine_map<(d0, d1, d2) -> (d1, d2)>
    %2 = tensor.empty() : tensor<1x12x768xf32>
    %3 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x12x768xf32>) outs(%2 : tensor<1x12x768xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x12x768xf32>
    %4 = tensor.empty() : tensor<1x768x768xf32>
    %5 = linalg.generic {indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1 : tensor<768x768xf32>) outs(%4 : tensor<1x768x768xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x768x768xf32>
    %6 = linalg.fill ins(%cst_6 : f32) outs(%2 : tensor<1x12x768xf32>) -> tensor<1x12x768xf32>
    %7 = linalg.batch_matmul ins(%3, %5 : tensor<1x12x768xf32>, tensor<1x768x768xf32>) outs(%6 : tensor<1x12x768xf32>) -> tensor<1x12x768xf32>
    %8 = linalg.generic {indexing_maps = [#map2, #map5, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7, %cst_5 : tensor<1x12x768xf32>, tensor<768xf32>) outs(%2 : tensor<1x12x768xf32>) {
    ^bb0(%in: f32, %in_16: f32, %out: f32):
      %46 = arith.addf %in, %in_16 : f32
      linalg.yield %46 : f32
    } -> tensor<1x12x768xf32>

Two things to look at. The batch_matmul is actually just a matmul. Also, there are a lot of "broadcast-like" operations here that are actually just copies if you drop the unit dimensions. Maps like #map2 = affine_map<(d0, d1, d2) -> (0, d1, d2)> are problematic: they lead to lots of corner cases and make analysis of indexing maps unnecessarily complicated. There is an attempt in IREE (built in MLIR) to get rid of such spurious unit dimensions. If I do a "cleanup" of the graph to replace batch_matmul with matmul and rewrite the uses of #map2 into a more canonical representation, i.e. start with this input

func.func @forward(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %cst = arith.constant -3.40282347E+38 : f32
  %cst_0 = arith.constant 0.000000e+00 : f32
  %cst_1 = arith.constant dense_resource<__elided__> : tensor<768xf32>
  %cst_2 = arith.constant dense_resource<__elided__> : tensor<768x768xf32>
  %cst_3 = arith.constant dense<8.000000e+00> : tensor<f32>
  %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<1x12x768xf32>
  %1 = tensor.empty() : tensor<768x768xf32>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%cst_2 : tensor<768x768xf32>) outs(%1 : tensor<768x768xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<768x768xf32>
  %3 = tensor.empty() : tensor<12x768xf32>
  %collapsed = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<1x12x768xf32> into tensor<12x768xf32>
  %8 = linalg.fill ins(%cst_0 : f32) outs(%3 : tensor<12x768xf32>) -> tensor<12x768xf32>
  %9 = linalg.matmul ins(%collapsed, %2 : tensor<12x768xf32>, tensor<768x768xf32>) outs(%8 : tensor<12x768xf32>) -> tensor<12x768xf32>
  %10 = tensor.empty() : tensor<12x768xf32>
  %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %cst_1 : tensor<12x768xf32>, tensor<768xf32>) outs(%10 : tensor<12x768xf32>) {
  ^bb0(%in: f32, %in_9: f32, %out: f32):
    %45 = arith.addf %in, %in_9 : f32
    linalg.yield %45 : f32
  } -> tensor<12x768xf32>
  %expanded_6 = tensor.expand_shape %11 [[0], [1, 2]] : tensor<12x768xf32> into tensor<12x12x64xf32>
  %12 = tensor.empty() : tensor<12x12x64xf32>
  %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_6 : tensor<12x12x64xf32>) outs(%12 : tensor<12x12x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<12x12x64xf32>
  %14 = tensor.empty() : tensor<12x64x12xf32>
  %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%13 : tensor<12x12x64xf32>) outs(%14 : tensor<12x64x12xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<12x64x12xf32>
  %16 = tensor.empty() : tensor<12x12x64xf32>
  %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%13 : tensor<12x12x64xf32>) outs(%16 : tensor<12x12x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<12x12x64xf32>
  %18 = tensor.empty() : tensor<12x64x12xf32>
  %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15 : tensor<12x64x12xf32>) outs(%18 : tensor<12x64x12xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<12x64x12xf32>
  %20 = tensor.empty() : tensor<12x12x12xf32>
  %21 = linalg.fill ins(%cst_0 : f32) outs(%20 : tensor<12x12x12xf32>) -> tensor<12x12x12xf32>
  %22 = linalg.batch_matmul ins(%17, %19 : tensor<12x12x64xf32>, tensor<12x64x12xf32>) outs(%21 : tensor<12x12x12xf32>) -> tensor<12x12x12xf32>
  %23 = tensor.empty() : tensor<12x12x12xf32>
  %24 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%22, %cst_3 : tensor<12x12x12xf32>, tensor<f32>) outs(%23 : tensor<12x12x12xf32>) {
  ^bb0(%in: f32, %in_9: f32, %out: f32):
    %45 = arith.divf %in, %in_9 : f32
    linalg.yield %45 : f32
  } -> tensor<12x12x12xf32>
  %25 = tensor.empty() : tensor<12x12xf32>
  %26 = linalg.fill ins(%cst : f32) outs(%25 : tensor<12x12xf32>) -> tensor<12x12xf32>
  %27 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%24 : tensor<12x12x12xf32>) outs(%26 : tensor<12x12xf32>) {
  ^bb0(%in: f32, %out: f32):
    %45 = arith.maxf %in, %out : f32
    linalg.yield %45 : f32
  } -> tensor<12x12xf32>
  %28 = tensor.empty() : tensor<12x12x12xf32>
  %29 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%24, %27 : tensor<12x12x12xf32>, tensor<12x12xf32>) outs(%28 : tensor<12x12x12xf32>) {
  ^bb0(%in: f32, %in_9: f32, %out: f32):
    %45 = arith.subf %in, %in_9 : f32
    linalg.yield %45 : f32
  } -> tensor<12x12x12xf32>
  %30 = tensor.empty() : tensor<12x12x12xf32>
  %31 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%29 : tensor<12x12x12xf32>) outs(%30 : tensor<12x12x12xf32>) {
  ^bb0(%in: f32, %out: f32):
    %45 = math.exp %in : f32
    linalg.yield %45 : f32
  } -> tensor<12x12x12xf32>
  %32 = tensor.empty() : tensor<12x12xf32>
  %33 = linalg.fill ins(%cst_0 : f32) outs(%32 : tensor<12x12xf32>) -> tensor<12x12xf32>
  %34 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%31 : tensor<12x12x12xf32>) outs(%33 : tensor<12x12xf32>) {
  ^bb0(%in: f32, %out: f32):
    %45 = arith.addf %in, %out : f32
    linalg.yield %45 : f32
  } -> tensor<12x12xf32>
  %35 = tensor.empty() : tensor<12x12x12xf32>
  %36 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%31, %34 : tensor<12x12x12xf32>, tensor<12x12xf32>) outs(%35 : tensor<12x12x12xf32>) {
  ^bb0(%in: f32, %in_9: f32, %out: f32):
    %45 = arith.divf %in, %in_9 : f32
    linalg.yield %45 : f32
  } -> tensor<12x12x12xf32>
  %37 = tensor.empty() : tensor<12x12x12xf32>
  %38 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%36 : tensor<12x12x12xf32>) outs(%37 : tensor<12x12x12xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<12x12x12xf32>
  %39 = tensor.empty() : tensor<12x12x64xf32>
  %40 = linalg.fill ins(%cst_0 : f32) outs(%39 : tensor<12x12x64xf32>) -> tensor<12x12x64xf32>
  %41 = linalg.batch_matmul ins(%38, %17 : tensor<12x12x12xf32>, tensor<12x12x64xf32>) outs(%40 : tensor<12x12x64xf32>) -> tensor<12x12x64xf32>
  %42 = tensor.empty() : tensor<12x12x64xf32>
  %43 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%41 : tensor<12x12x64xf32>) outs(%42 : tensor<12x12x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<12x12x64xf32>
  %expanded_7 = tensor.expand_shape %43 [[0, 1], [2], [3]] : tensor<12x12x64xf32> into tensor<1x12x12x64xf32>
  %collapsed_8 = tensor.collapse_shape %expanded_7 [[0], [1], [2, 3]] : tensor<1x12x12x64xf32> into tensor<1x12x768xf32>
  %44 = hal.tensor.export %collapsed_8 : tensor<1x12x768xf32> -> !hal.buffer_view
  return %44 : !hal.buffer_view
}

Then I end up with 7 dispatches instead of 13, as was the case in the bug above. Here is the IR after some relevant passes: https://gist.github.com/MaheshRavishankar/11cf9d4ff1beca0d004551f7f7d807b2

Looking at the remaining dispatches, there is a transpose at the very beginning. We can discuss whether we want to fuse it with the next dispatch, which does a matmul. That is a possibility, but as I said above, a better metric to go for is to make the matmul get the better layout and then try to fuse the relayout operation with other operations (instead of indexing on the transpose and trying to fuse that with a matmul). The goal should be to get the matmul into the right layout.

The other remaining inefficiencies are two dispatches with a single generic op each. These are doing two different transposes. It might be worth looking into where those transposes are coming from and then trying to fix them at the right place. My point is, when it comes to fusion, it's useful to look at the entire IR from the input until dispatch region formation and try to "clean up" the IR into something that requires very few actual rules within the fusion itself...
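As an illustration of the unit-dimension cleanup discussed above, here is a toy Python model (hypothetical, not the IREE/MLIR implementation) of dropping extent-1 loop dimensions from indexing maps. Applied to #map2 and #map3 over the 1x12x768 iteration space of %3, both collapse to the 2-D identity map, which is why that generic is just a copy.

```python
def drop_unit_dims(results, loop_ranges):
    """Drop extent-1 loop dimensions from a projected-permutation indexing map.

    `results` holds one entry per operand dimension: a loop position i (for d_i)
    or None for a constant-0 result, as in affine_map<(d0, d1, d2) -> (0, d1, d2)>.
    Assumes, as in the IR above, that a constant 0 only indexes a size-1 operand dim.
    Returns (new_results, new_loop_ranges).
    """
    kept = [d for d, extent in enumerate(loop_ranges) if extent != 1]
    renumber = {old: new for new, old in enumerate(kept)}
    new_results = [renumber[r] for r in results
                   if r is not None and loop_ranges[r] != 1]
    return new_results, [loop_ranges[d] for d in kept]


loops = [1, 12, 768]                  # iteration space of %3
map2, map3 = [None, 1, 2], [0, 1, 2]  # #map2 (input) and #map3 (output)
print(drop_unit_dims(map2, loops))    # -> ([0, 1], [12, 768])
print(drop_unit_dims(map3, loops))    # -> ([0, 1], [12, 768])
```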

ThomasRaoux commented 1 year ago

The last IR I included did have the cleanup of batch_matmul to matmul. Of course, looking at the whole graph is what we need to do; I don't see how that goes against what I was talking about.

stellaraccident commented 1 year ago

> > Sorry, I haven't had time to look at it, but we are rehashing old things again... I can take a look at it, but trying to fuse the transpose with the gemm to get a gemm transpose is not what I think is the first-order solution. If you want the gemm to be the transpose variant, then you should do that and then cancel out the transposes... I have to look at this case in more detail.
>
> Why? If both versions of gemm have similar performance, I don’t see why we wouldn’t want to fuse those opportunistically.
>
> If we have another practical solution, that’s fine with me too, but this is something we want to fix.

For my own edification: Are the gemm and gemm transpose equivalent in terms of perf? (and is that an invariant for all targets?) What would be the heuristic for choosing one over the other?

Generally, I always consider the presence/lack of transposes in the program to be a non-useful thing to key off of. Sometimes they derive from a programmer over-eagerly trying to massage things into a form good for their case or for any number of practical reasons that have nothing to do with being optimal. Keying off of them in a source program is over-indexing on what happened to be written and isn't really related to what we can know to compile well, ime.

If a heuristic ever favors a gemm vs gemm-transpose (i.e. due to better access pattern or ability to vectorize, etc), then it is better to fixate that at the whole program level first (which will involve inserting a transpose for cases that need to be switched) and then canonicalizing to cancel/propagate. There are some corner cases, but I think this is pretty standard procedure in many systems I have seen.
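The "fixate, then canonicalize" procedure can be illustrated on a toy straight-line op list. This is a hypothetical Python model, not IREE's canonicalization patterns: a layout decision inserts a compensating transpose in front of the matmul, and folding adjacent transposes removes it again when it meets the transpose that was already in the program. Real programs are graphs rather than chains, which this sketch ignores.

```python
def compose(first, second):
    """Permutation equivalent to transposing by `first`, then by `second`.

    Convention (numpy-style): result dim i comes from input dim perm[i].
    """
    return tuple(first[i] for i in second)


def canonicalize(ops):
    """Fold adjacent transposes in a straight-line op list; drop identity results."""
    out = []
    for op in ops:
        if op[0] == "transpose" and out and out[-1][0] == "transpose":
            merged = compose(out.pop()[1], op[1])
            if merged != tuple(range(len(merged))):
                out.append(("transpose", merged))
        else:
            out.append(op)
    return out


# Source program: a transpose feeding a plain matmul. Switching the matmul to a
# transpose-B form inserts a compensating transpose in front of it...
program = [("transpose", (1, 0)), ("transpose", (1, 0)), ("matmul_transpose_b",)]
# ...and canonicalization cancels the pair, leaving no standalone transpose.
print(canonicalize(program))
```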

Maybe it doesn't matter on this target and either is perfectly valid. I know it does on some.

stellaraccident commented 1 year ago

(looks like there was a response from Thomas with some data that showed up in the email thread but not here)

I think I buy that if high level optimizations fail to fixate things to be optimal, not fusing such transposes that remain is the worst of all worlds. Following this train of thought, doesn't it imply that we want both? High level optimizations which attempt to assign optimal layouts for the matmuls and then fallback fusion that swallows any adjacent transposes? Seems separable.

ThomasRaoux commented 1 year ago

Sorry, I included an incorrect graph in my previous reply. My experience is that normal matmul and transposed matmul are close, but I can't find the data at the moment. I can measure to confirm if this is a blocker, but I was assuming they are indeed close.

I don't know about other targets but I'm guessing this is not the case. Although when combined with packing it might be.

> If a heuristic ever favors a gemm vs gemm-transpose (i.e. due to better access pattern or ability to vectorize, etc), then it is better to fixate that at the whole program level first (which will involve inserting a transpose for cases that need to be switched) and then canonicalizing to cancel/propagate. There are some corner cases, but I think this is pretty standard procedure in many systems I have seen.

I agree with that in general. My point was that, since performance is roughly equivalent, we can take advantage of it to remove dispatches in cases where propagating is not easy. Here, with the reshape, I'm not sure there is a way to propagate (maybe there is, but I haven't figured it out yet). If we have a more general solution that also applies to all platforms, of course it is better.

ThomasRaoux commented 1 year ago

> (looks like there was a response from Thomas with some data that showed up in the email thread but not here)

Yes, sorry, I realized I included the wrong data.

> I think I buy that if high level optimizations fail to fixate things to be optimal, not fusing such transposes that remain is the worst of all worlds. Following this train of thought, doesn't it imply that we want both? High level optimizations which attempt to assign optimal layouts for the matmuls and then fallback fusion that swallows any adjacent transposes? Seems separable.

Yes I think we want both.

silvasean commented 1 year ago

> Two things to look at. The batch_matmul is actually just a matmul.

The unit dimensions are specific to the batch=1 case (which almost never happens during training or datacenter inference). Here is a gist with IR for the batch=3 case (I set the sequence length to 17 just so that number is easy to trace through as well). In a user application, this would be, e.g., the number of sentences being processed at once by a transformer-based model architecture, so in training and datacenter inference it will always be a significant number. IREE performs much worse than the alternatives I benchmarked for such cases because we materialize a deep broadcast of the whole weight matrix feeding into the batch_matmul. This can be reduced to a regular matmul as well, but through a more general chain of reasoning.

Arguably the frontend should special case this and emit a linalg.matmul in the first place, but it seems reasonable that our backstop fusion heuristics should at least fuse the broadcast of the weight matrix in these cases though instead of materializing the broadcasted weight matrix. The invariant that frontends would expect is "replacing an operand of an op by a broadcast should not deep-materialize the broadcast in memory"; if we don't want to aim for that, then it means that frontends need to aim for "don't emit broadcasts that create tensors asymptotically bigger than those that were present in the original eager execution of the program". Thoughts? I think we already provide a similar invariant -- multiple subsequent elementwise ops do not materialize intermediates -- and as such, as a frontend we don't worry about emitting separate elementwise ops as separate linalg.generics because of this.
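The invariant above can be checked on a small example. A plain-Python sketch (hypothetical names, not IREE code) using batch=3 and sequence length 17 as in the gist, with small K and N standing in for 768: the materialized broadcast stores batch x K x N weight elements, while the fused form reads the same K x N weight for every batch, and collapsing (batch, M) turns it into an ordinary matmul.

```python
import copy
import random


def broadcast_weight(w, batch):
    """Materialize a batch x K x N copy of the K x N weight (the deep broadcast)."""
    return [copy.deepcopy(w) for _ in range(batch)]


def batch_matmul(x, w3):
    """C[b][m][n] = sum_k X[b][m][k] * W3[b][k][n], using a materialized 3-D weight."""
    return [[[sum(xb[m][k] * wb[k][n] for k in range(len(wb))) for n in range(len(wb[0]))]
             for m in range(len(xb))]
            for xb, wb in zip(x, w3)]


def batch_matmul_broadcast_rhs(x, w):
    """Fused form: every batch reads the same 2-D weight W[k][n]; nothing is copied."""
    return [[[sum(xb[m][k] * w[k][n] for k in range(len(w))) for n in range(len(w[0]))]
             for m in range(len(xb))]
            for xb in x]


def as_plain_matmul(x, w):
    """Collapse (batch, M) into one dimension, run an ordinary matmul, split back."""
    rows = [row for xb in x for row in xb]  # (batch * M) x K
    flat = [[sum(r[k] * w[k][n] for k in range(len(w))) for n in range(len(w[0]))]
            for r in rows]
    m = len(x[0])
    return [flat[i:i + m] for i in range(0, len(flat), m)]


random.seed(3)
BATCH, SEQ, K, N = 3, 17, 8, 6  # K and N are small stand-ins for 768
x = [[[random.randint(-3, 3) for _ in range(K)] for _ in range(SEQ)] for _ in range(BATCH)]
w = [[random.randint(-3, 3) for _ in range(N)] for _ in range(K)]

materialized = broadcast_weight(w, BATCH)
reference = batch_matmul(x, materialized)
assert batch_matmul_broadcast_rhs(x, w) == reference
assert as_plain_matmul(x, w) == reference

stored = sum(len(row) for wb in materialized for row in wb)
print(f"weight elements: materialized={stored}, fused={K * N}")
```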


> I don't know about other targets but I'm guessing this is not the case. Although when combined with packing it might be.

Yeah, in Ruy we handle all 8 dst/lhs/rhs row/col-major layouts, and packing papers over it. One detail is that since Ruy writes into an unpacked dst, the asm kernels branch and store in either row- or column-major layout, but our deeper ability to analyze the program might be able to eliminate that as well.
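For readers unfamiliar with the "8 layouts" count: each of dst, lhs, and rhs can be row- or column-major, giving 2^3 combinations. The naive Python sketch below is not Ruy's code (Ruy packs operands into its own internal layout before running optimized kernels); it only indexes flat buffers through per-operand layouts to show that one loop nest can serve all 8 cases. Names are hypothetical.

```python
from itertools import product


def offset(layout, rows, cols, i, j):
    """Flat offset of element (i, j) in a rows x cols matrix, row- or column-major."""
    return i * cols + j if layout == "row" else j * rows + i


def store(mat, layout):
    """Flatten a nested-list matrix into a buffer with the given layout."""
    rows, cols = len(mat), len(mat[0])
    buf = [0] * (rows * cols)
    for i in range(rows):
        for j in range(cols):
            buf[offset(layout, rows, cols, i, j)] = mat[i][j]
    return buf


def load(buf, layout, rows, cols):
    """Read a flat buffer back into a nested-list matrix."""
    return [[buf[offset(layout, rows, cols, i, j)] for j in range(cols)] for i in range(rows)]


def strided_matmul(lhs, lhs_layout, rhs, rhs_layout, dst_layout, m, k, n):
    """Naive dst = lhs @ rhs on flat buffers, each operand in its own layout."""
    dst = [0] * (m * n)
    for i in range(m):
        for j in range(n):
            acc = 0
            for p in range(k):
                acc += lhs[offset(lhs_layout, m, k, i, p)] * rhs[offset(rhs_layout, k, n, p, j)]
            dst[offset(dst_layout, m, n, i, j)] = acc
    return dst


a = [[1, 2, 3], [4, 5, 6]]       # 2 x 3
b = [[7, 8], [9, 10], [11, 12]]  # 3 x 2
expected = [[58, 64], [139, 154]]

layouts = list(product(("row", "col"), repeat=3))  # (dst, lhs, rhs): 8 combinations
for dst_l, lhs_l, rhs_l in layouts:
    out = strided_matmul(store(a, lhs_l), lhs_l, store(b, rhs_l), rhs_l, dst_l, 2, 3, 2)
    assert load(out, dst_l, 2, 2) == expected
```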

stellaraccident commented 1 year ago

> Arguably the frontend should special case this and emit a linalg.matmul in the first place, but it seems reasonable that our backstop fusion heuristics should at least fuse the broadcast of the weight matrix in these cases though instead of materializing the broadcasted weight matrix. The invariant that frontends would expect is "replacing an operand of an op by a broadcast should not deep-materialize the broadcast in memory"; if we don't want to aim for that, then it means that frontends need to aim for "don't emit broadcasts that create tensors asymptotically bigger than those that were present in the original eager execution of the program". Thoughts? I think we already provide a similar invariant -- multiple subsequent elementwise ops do not materialize intermediates -- and as such, as a frontend we don't worry about emitting separate elementwise ops as separate linalg.generics because of this.

I consider the current state a bug and how you state the invariant is likely helpful. Thinking back to code I've seen, I think there are cases where materializing instead of fusing the broadcast is more optimal, but for this kind of compiler, I would rather treat that as the optimization/special case and have the default case select towards limiting materialization.

We can't expect frontends to "get this right". People programming such things (or the libraries emitting them) are conditioned to think of these things as negligible cost -- mostly because library systems almost always handle this with a metadata transform on a tensor object instead of a copy/expansion.

ThomasRaoux commented 1 year ago

So, getting back to concrete next steps here. The first obvious step is to convert batch_matmul into matmul when the batch size is 1. The reason it helps is not that we handle matmul better but that we remove unit dims from linalg.generic ops; without it we end up with extra shape casts. I sent a patch for that here: https://reviews.llvm.org/D144294
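Why the batch=1 rewrite is semantics-preserving, as a plain-Python sketch. This is not the code in D144294, and the names are hypothetical: with a batch of 1, the batch index is always 0, so dropping it (the role tensor.collapse_shape/tensor.expand_shape play in the IR), running a matmul, and re-adding it gives the same result.

```python
import random


def matmul(a, b):
    """C[m][n] = sum_k A[m][k] * B[k][n]."""
    return [[sum(a[m][k] * b[k][n] for k in range(len(b))) for n in range(len(b[0]))]
            for m in range(len(a))]


def batch_matmul(x, y):
    """C[b][m][n] = sum_k X[b][m][k] * Y[b][k][n]."""
    return [[[sum(x[b][m][k] * y[b][k][n] for k in range(len(y[b])))
              for n in range(len(y[b][0]))]
             for m in range(len(x[b]))]
            for b in range(len(x))]


def unit_batch_matmul_as_matmul(x, y):
    """Hypothetical rewrite for batch == 1: drop the unit batch dim, matmul, re-add it."""
    if len(x) != 1 or len(y) != 1:
        raise ValueError("rewrite only applies when the batch size is 1")
    return [matmul(x[0], y[0])]


random.seed(1)
M, K, N = 4, 3, 5
x = [[[random.randint(-4, 4) for _ in range(K)] for _ in range(M)]]  # 1 x M x K
y = [[[random.randint(-4, 4) for _ in range(N)] for _ in range(K)]]  # 1 x K x N
assert unit_batch_matmul_as_matmul(x, y) == batch_matmul(x, y)
```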

This doesn't solve the whole problem, since there are still transposes between the matmul and the batch_matmul that end up in their own dispatch regions. There are two solutions there: 1) propagating the transpose, 2) greedily fusing the transpose into the matmul when propagation didn't remove it.

Most likely we need both solutions in the long term. How we can apply solution (1) here is not clear to me, since the reshape seems to prevent it. I'm happy to try things out if we have a solution for it. Solution (2) is something we can do right now and would most likely get most of the performance for this case. It is not clear to me whether this is good on all targets, so my plan is to work on enabling this fusion and measure the performance impact on different targets; then it can either be enabled in the general fusion or be done as a target-specific fusion.
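Solution (1) moves a transpose across neighbouring ops. For an elementwise op this is legal as long as every operand's indexing map is permuted along with it. A minimal plain-Python sketch (illustrative names) using a bias add like the one in this layer: moving the transpose above the bias add changes the bias operand's broadcast map from (d0, d1) -> (d1) to (d0, d1) -> (d0). The sketch does not attempt to model the reshape case that the comment says is blocking.

```python
import random


def transpose(x):
    """Materialize x^T."""
    return [list(col) for col in zip(*x)]


def add_bias_cols(x, bias):
    """y[i][j] = x[i][j] + bias[j]; bias map (d0, d1) -> (d1)."""
    return [[v + bias[j] for j, v in enumerate(row)] for row in x]


def add_bias_rows(x, bias):
    """y[i][j] = x[i][j] + bias[i]; bias map (d0, d1) -> (d0)."""
    return [[v + bias[i] for v in row] for i, row in enumerate(x)]


random.seed(2)
M, N = 3, 4
x = [[random.randint(-9, 9) for _ in range(N)] for _ in range(M)]
bias = [random.randint(-9, 9) for _ in range(N)]

# Before: bias add, then transpose. After: transpose moved above the bias add,
# with the bias operand's indexing map permuted to match.
before = transpose(add_bias_cols(x, bias))
after = add_bias_rows(transpose(x), bias)
assert before == after
```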

MaheshRavishankar commented 1 year ago

Could we first measure how much we are talking about before building out a solution....

MaheshRavishankar commented 1 year ago

Also this specific shape falls into the "good shape" category. Is this the only shape we care about.... I'd rather fix the big ticket items first than going after few small percentages.

ThomasRaoux commented 1 year ago

> Could we first measure how much we are talking about before building out a solution....

There is some estimate in Sean's slides.

> Also this specific shape falls into the "good shape" category. Is this the only shape we care about.... I'd rather fix the big ticket items first than going after few small percentages.

I don't understand what this has to do with shapes? The problem would happen on any shape. I'm sure there are other opportunities but this is something we need to fix no matter what. Why do you think we should not address this?

MaheshRavishankar commented 1 year ago

> > Arguably the frontend should special case this and emit a linalg.matmul in the first place, but it seems reasonable that our backstop fusion heuristics should at least fuse the broadcast of the weight matrix in these cases though instead of materializing the broadcasted weight matrix. The invariant that frontends would expect is "replacing an operand of an op by a broadcast should not deep-materialize the broadcast in memory"; if we don't want to aim for that, then it means that frontends need to aim for "don't emit broadcasts that create tensors asymptotically bigger than those that were present in the original eager execution of the program". Thoughts? I think we already provide a similar invariant -- multiple subsequent elementwise ops do not materialize intermediates -- and as such, as a frontend we don't worry about emitting separate elementwise ops as separate linalg.generics because of this.
>
> I consider the current state a bug and how you state the invariant is likely helpful. Thinking back to code I've seen, I think there are cases where materializing instead of fusing the broadcast is more optimal, but for this kind of compiler, I would rather treat that as the optimization/special case and have the default case select towards limiting materialization.

Fusing broadcasts with the matmul/batch_matmul layer is definitely worth it. I haven't looked at this in a bit, but I think I explicitly disallow this fusion because it is a matter of the backends being able to handle it. Based on my understanding of the CPU backends, it shouldn't be too hard to support this. For CUDA backends, since operands are promoted to shared memory, this is one place where the broadcast of the LHS should be managed.

Another approach we could take here is to generalize such operations to a linalg.generic and have the backends use the ContractionOpInterface to drive it down the same pipeline as gemm/batch_matmul. Then I think we can easily enable fusion of producer broadcasts into the GEMMs.
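To illustrate the generalization idea, here is a plain-Python toy (not the ContractionOpInterface or any IREE API; names are hypothetical) of one contraction routine parameterized only by indexing maps, loosely mirroring linalg.generic. A broadcast weight and a transposed weight are then just different maps on the same (b, m, n, k) loop nest, which is what would let one backend pipeline cover matmul, batch_matmul, and their fused variants.

```python
import random
from itertools import product


def zeros(shape):
    """Nested list of zeros with the given (non-empty) shape."""
    if len(shape) == 1:
        return [0] * shape[0]
    return [zeros(shape[1:]) for _ in range(shape[0])]


def get(t, idx):
    """Index a nested list with a sequence of indices."""
    for i in idx:
        t = t[i]
    return t


def generic_contraction(lhs, rhs, lhs_map, rhs_map, out_map, loop_ranges):
    """out[out_map(iv)] += lhs[lhs_map(iv)] * rhs[rhs_map(iv)] for every loop point iv.

    Each map is a tuple of loop-dimension positions (a projected permutation);
    loop dims missing from out_map act as reduction dims.
    """
    out = zeros([loop_ranges[d] for d in out_map])
    for iv in product(*(range(r) for r in loop_ranges)):
        row = get(out, [iv[d] for d in out_map[:-1]])
        row[iv[out_map[-1]]] += get(lhs, [iv[d] for d in lhs_map]) * get(rhs, [iv[d] for d in rhs_map])
    return out


random.seed(4)
B, M, N, K = 2, 3, 4, 5
x = [[[random.randint(-3, 3) for _ in range(K)] for _ in range(M)] for _ in range(B)]
w = [[random.randint(-3, 3) for _ in range(N)] for _ in range(K)]  # K x N, shared by all b
wt = [[w[k][n] for k in range(K)] for n in range(N)]                # N x K (transposed)
loops = (B, M, N, K)  # d0=b, d1=m, d2=n, d3=k (reduction)

reference = [[[sum(x[b][m][k] * w[k][n] for k in range(K)) for n in range(N)]
              for m in range(M)] for b in range(B)]
bcast = generic_contraction(x, w, (0, 1, 3), (3, 2), (0, 1, 2), loops)   # broadcast RHS
trans = generic_contraction(x, wt, (0, 1, 3), (2, 3), (0, 1, 2), loops)  # transposed RHS
assert bcast == reference and trans == reference
```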

MaheshRavishankar commented 1 year ago

> > Could we first measure how much we are talking about before building out a solution....
>
> There is some estimate in Sean's slides.

Sean's slides suggest these are 5%. Also, these are two transposes that could be overlapped in execution. So "fusion" is going to benefit 5% overall.

> > Also this specific shape falls into the "good shape" category. Is this the only shape we care about.... I'd rather fix the big ticket items first than going after few small percentages.
>
> I don't understand what this has to do with shapes? The problem would happen on any shape. I'm sure there are other opportunities but this is something we need to fix no matter what. Why do you think we should not address this?

The shapes here put the batch_matmul on the good path. Do all shapes fall in the same place? What is more important: going for a 40-50% improvement due to better GEMM/batch_matmul performance, or a potential 10% improvement (if not less) from fusion?

I am not opposed to looking at fusion, but first would be good to understand where these transposes are coming from. They might be something introduced earlier that we could avoid to start with.

ThomasRaoux commented 1 year ago

Fusing broadcasts with the matmul/batch_matmul layer is definitely worth it. I haven't looked at this in a while, but I think I explicitly disallow this fusion because it depends on whether the backends can handle it. Based on my understanding of the CPU backends, it shouldn't be too hard to support. For the CUDA backends, since operands are promoted to shared memory, this is one place where the broadcast of the LHS would have to be managed.

Another approach we could take here is to generalize such operations to linalg.generic and have the backends use the ContractionOpInterface to drive them down the same pipeline as gemm/batch_matmul. Then I think we can easily enable fusion of broadcasts into their consumer GEMMs.

The generalization can also be done in the backend if needed. For both the SPIR-V and CUDA backends, I think all such fusions would work by converting the op to a linalg.generic.

Sean's slides suggest these are 5%. Also, these are two transposes that could be overlapped in execution, so "fusion" is going to yield about 5% overall.

The 5% is for the last output transpose. There are also some transposes in the bias operations, and those dispatch regions still exist, so I assume the impact is closer to 25%.

The shapes here put the batch_matmul on the good path. Do all shapes fall in the same place? What is more important: going for a 40-50% improvement from better GEMM/batch_matmul performance, or a potential 10% improvement (if not less) from fusion?

You mean the problem with scaling? This is already being looked at; the main reason is that tile size selection is not tuned. It doesn't have to be a serialized effort.

I am not opposed to looking at fusion, but first would be good to understand where these transposes are coming from. They might be something introduced earlier that we could avoid to start with.

Looking at the links sent by Sean, the transpose comes from the input. Maybe Sean has a better understanding, as I'm not sure why the batch size gets collapsed/expanded.

Improving fusion seems like something we will need no matter what. What is the reason for pushing back on it?

MaheshRavishankar commented 1 year ago

Improving fusion seems like something we will need no matter what. What is the reason for pushing back on it?

We should avoid extrapolating from a single case (as we have done many times in the past). We need to look at the impact on all models and then decide... I don't think the discussion above with Sean and Stella has anything to do with the transpose; that is more related to broadcasts of inputs... I would rather collect numbers on different shapes and then decide. My main concern is that a significant amount of effort was spent cleaning up a whole host of ad-hoc fusion decisions that were made for individual use cases. It was impossible to reason about these decisions in a uniform way until it was basically blown up and redone. I do not want to get back into that state.

stellaraccident commented 1 year ago

@silvasean Can you make a call? You've got the workload visibility to do so.

MaheshRavishankar commented 1 year ago

@silvasean maybe one thing to understand is where this sequence is coming from:

  %expanded = tensor.expand_shape %8 [[0], [1], [2, 3]] : tensor<3x17x768xf32> into tensor<3x17x12x64xf32>
  %9 = tensor.empty() : tensor<3x12x17x64xf32>
  %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded : tensor<3x17x12x64xf32>) outs(%9 : tensor<3x12x17x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<3x12x17x64xf32>
  %11 = tensor.empty() : tensor<3x12x64x17xf32>
  %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%10 : tensor<3x12x17x64xf32>) outs(%11 : tensor<3x12x64x17xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<3x12x64x17xf32>
  %collapsed = tensor.collapse_shape %10 [[0, 1], [2], [3]] : tensor<3x12x17x64xf32> into tensor<36x17x64xf32>
  %collapsed_3 = tensor.collapse_shape %12 [[0, 1], [2], [3]] : tensor<3x12x64x17xf32> into tensor<36x64x17xf32>

This is from the IR before the iree-flow-fusion-of-tensor-ops pass (before any IREE fusion kicks in). This sequence of expand/collapse ops seems fairly esoteric. If we get an idea of where it is coming from, that could point to a solution.

In terms of next steps, could we also factor in different shapes? We know that this particular shape is very common and that we have tuned implementations for it. I think the profiles would look very different for other shapes (ideally I would have something completely shape-agnostic, but that is a stretch).
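The IR sequence above can be replayed in NumPy for readers less familiar with it. This is a sketch, and `split_heads_like_ir` is a hypothetical name. It appears to correspond to the HuggingFace split-heads code discussed below:
- the first transpose moves heads ahead of the sequence dimension;
- the second transposes the key for the Q @ K^T product.

```python
import numpy as np


def split_heads_like_ir(x, num_heads):
    """Replay the expand_shape / transpose / collapse_shape sequence from the IR above."""
    b, s, e = x.shape                                  # e.g. 3 x 17 x 768
    d = e // num_heads                                 # size_of_head, e.g. 64
    expanded = x.reshape(b, s, num_heads, d)           # tensor.expand_shape [[0], [1], [2, 3]]
    t10 = expanded.transpose(0, 2, 1, 3)               # %10: out map (d0, d2, d1, d3) -> 3x12x17x64
    t12 = t10.transpose(0, 1, 3, 2)                    # %12: out map (d0, d1, d3, d2) -> 3x12x64x17
    collapsed = t10.reshape(b * num_heads, s, d)       # %collapsed:   36x17x64
    collapsed_3 = t12.reshape(b * num_heads, d, s)     # %collapsed_3: 36x64x17
    return collapsed, collapsed_3


x = np.zeros((3, 17, 768), dtype=np.float32)
q, k_t = split_heads_like_ir(x, 12)
print(q.shape, k_t.shape, np.matmul(q, k_t).shape)  # batch_matmul of the two gives 36x17x17 scores
```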

silvasean commented 1 year ago

Looking at the links sent by Sean, the transpose comes from the input. Maybe Sean has a better understanding, as I'm not sure why the batch size gets collapsed/expanded.

This is how it is written in the user's source code (HuggingFace example 1, HuggingFace example 2). It is also written this way in the HuggingFace JAX version, and similarly in nanoGPT.

This reshape/transpose is fundamental to how transformers work. I will provide some context on how the self-attention layer works, for intuition (a simple intro with diagrams is here).

To summarize the discussion below: there are 5 matmuls with various reshapes, forming a pattern something like matvecish(softmax_rows(outer_productish(linear0(x), linear1(x))), linear2(x)). The "matvec"/"outer product" are "ish" because, instead of the special case of a dimension of size 1, that dimension has size 64 or 128, which for decently sized workloads is very small compared to the other dimensions. The matvecish(softmax_rows(outer_productish(a, b)), c) computation is called "scaled dot product attention" (it is what FlashAttention, for example, computes without materializing the full outer-product matrix in memory). linear[0-2] are standard matmul-biasadd layers.
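A minimal single-head NumPy sketch of that pattern follows. The names are hypothetical, and the multi-head split is left out for brevity. The 1/sqrt(d) factor is the standard scaling of scaled dot product attention; for size_of_head = 64 it is 1/8.

```python
import numpy as np


def softmax_rows(s):
    """Numerically stable softmax over the last axis."""
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)


def linear(x, w, b):
    """A standard matmul-biasadd layer."""
    return x @ w + b


def single_head_attention(x, params):
    """x: [seq, emb]; params: three (weight [emb, d], bias [d]) pairs."""
    (w0, b0), (w1, b1), (w2, b2) = params
    q = linear(x, w0, b0)                  # matmul 1
    k = linear(x, w1, b1)                  # matmul 2
    v = linear(x, w2, b2)                  # matmul 3
    d = q.shape[-1]
    scores = (q @ k.T) / np.sqrt(d)        # matmul 4: "outer-product-ish", [seq, seq], inner dim d
    return softmax_rows(scores) @ v        # matmul 5: "matvec-ish", output [seq, d]


rng = np.random.default_rng(1)
x = rng.standard_normal((5, 8))
params = [(rng.standard_normal((8, 4)), rng.standard_normal(4)) for _ in range(3)]
print(single_head_attention(x, params).shape)
```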

Summarizing the below discussion on sizes, we are talking about

The embedding_size is split into what are called "heads" (this is what gives "multi-headed attention" its name). embedding_size = num_heads * size_of_head; size_of_head is relatively constant, ranging from 64 to 128 across 3 orders of magnitude of model size. The transposes/reshapes come from the fact that the first 3 matmuls contract over embedding_size, while the other two treat num_heads as a batch dimension, which requires reshapes/transposes (since we are calling into named ops with fixed dimension orderings). If we used linalg.generic ops with reduction dimensions to model higher-D matmul-like ops (incorporating transposes and multiple reduction dims), then we could absorb all the transposes/reshapes quite cleanly, since we could properly model embedding_size as two dimensions, num_heads and size_of_head.

For context on the actual magnitude of the numbers: Transformer models are known to be relatively insensitive to particular choices of many of the hyperparameters for a given compute/parameter budget, and there is a known "sweet spot" for language models. So there is actually a pretty small space of meaningful parameters for transformer-based language models, since many of the hyperparameters are related. As such I will use the sizes from the GPT-3 paper as representative (especially Table 2.1).

The input to the self-attention layer consists of vectors of a particular embedding size (embedding_size, or $d_{model}$ in Table 2.1). For the 125M-parameter "GPT-3 Small" this size is 768, the same as in the example IR I gave, which was for bert-base-uncased (110M parameters; note how similar the sizes and choices are between GPT-3 and BERT). For the 175B-parameter "GPT-3" this size is 12288 (note: about 1000x more parameters, but this dimension only increases by 16x). There are [number of sentences, sentence length in tokens] of these vectors. This can vary between training and inference, but the sequence length for all GPT-3-family models is 2048, while for BERT it is often 512. The batch size in a datacenter inference scenario might be 16 or 32, more or less depending on the latency budget, while for training a GPT-3-like model it will be on the order of 16-64. So we are talking about inputs of size [num_sentences, sequence_length, embedding_size] = [16-1000ish, 128-2048, 768-12288]. They are usually chosen as nicely aligned sizes.
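Here is a NumPy sketch of the two formulations of the attention-score matmul; both functions are hypothetical illustrations.
- The named-op version needs explicit transposes and reshapes to fit a fixed-layout batch_matmul.
- The higher-D version keeps embedding_size split as (num_heads, size_of_head) and puts the "transposes" into the contraction's index maps.

NumPy's einsum may still move data internally. The point here is only the representation, not how NumPy executes it.

```python
import numpy as np


def scores_named_ops(q, k, num_heads):
    """Reshape/transpose so that a fixed-layout batch_matmul can be used."""
    b, s, e = q.shape
    d = e // num_heads
    qh = q.reshape(b, s, num_heads, d).transpose(0, 2, 1, 3).reshape(b * num_heads, s, d)
    kh = k.reshape(b, s, num_heads, d).transpose(0, 2, 3, 1).reshape(b * num_heads, d, s)
    return np.matmul(qh, kh).reshape(b, num_heads, s, s)


def scores_higher_d(q, k, num_heads):
    """Keep (heads, head_dim) as separate dims; only pure expand_shape-style reshapes."""
    b, s, e = q.shape
    d = e // num_heads
    q4 = q.reshape(b, s, num_heads, d)   # no data movement, just a view
    k4 = k.reshape(b, s, num_heads, d)
    # loops (b, h, i, j, d): batch dims b and h, contraction dim d
    return np.einsum("bihd,bjhd->bhij", q4, k4)


rng = np.random.default_rng(2)
q = rng.standard_normal((2, 5, 12))
k = rng.standard_normal((2, 5, 12))
print(np.allclose(scores_named_ops(q, k, 3), scores_higher_d(q, k, 3)))
```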
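A small hypothetical helper makes the size relationships concrete. It uses these configurations:
- GPT-3 Small and bert-base: 12 heads of size 64.
- GPT-3 175B: 96 heads of size 128.

The byte count assumes fp32 and a fully materialized score matrix for a single layer, which is what FlashAttention-style fusion avoids.

```python
def attention_shapes(num_sentences, seq_len, embedding_size, num_heads, bytes_per_elem=4):
    """Shapes in one self-attention layer, with embedding_size = num_heads * size_of_head."""
    size_of_head, rem = divmod(embedding_size, num_heads)
    if rem:
        raise ValueError("embedding_size must be a multiple of num_heads")
    return {
        "input": (num_sentences, seq_len, embedding_size),
        "q_k_v_per_head": (num_sentences, num_heads, seq_len, size_of_head),
        "scores": (num_sentences, num_heads, seq_len, seq_len),
        "scores_bytes": num_sentences * num_heads * seq_len * seq_len * bytes_per_elem,
    }


# bert-base-uncased / GPT-3 Small, one sentence of 128 tokens.
print(attention_shapes(1, 128, 768, 12))
# GPT-3 175B sizes with 16 sentences of 2048 tokens.
big = attention_shapes(16, 2048, 12288, 96)
print(big["q_k_v_per_head"], big["scores_bytes"] / 2**30, "GiB of fp32 scores")
```

For the GPT-3 175B case this comes out to 24 GiB of scores for one layer.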

Fusing broadcasts with the matmul/batch_matmul layer is definitely worth it. I haven't looked at this in a while, but I think I explicitly disallow this fusion because it depends on whether the backends can handle it. Based on my understanding of the CPU backends, it shouldn't be too hard to support. For the CUDA backends, since operands are promoted to shared memory, this is one place where the broadcast of the LHS would have to be managed.

FWIW, in Torch-MLIR we had this kind of gridlock requiring all backends to support something before one backend could make progress, and it was quite problematic. We solved that in Torch-MLIR by allowing backends some limited control over the target-independent lowering process. Over time we are going to grow more and more backends with different sets of maintainers and it is going to be increasingly problematic to atomically update all of them in order to make progress.

silvasean commented 1 year ago

Concretely I think we have the following action items:

  1. The named ops are really biting us in terms of the embedding_size = num_heads * size_of_head dimension split and forcing reshape/transposes. We already solve this for elementwise ops by using higher-D representations and indexing map permutations to absorb reshape/transposes, and I think we need a similar solution here. I think Nicolas and Thomas are iterating on some tech that could potentially pave the way to reliable codegen of arbitrary linalg.generic that boils down to stamping out an inner matmul primitive. This is probably not a "1 week project" but I think if we put our heads together we can solve it (and probably boost our matmul performance along the way too).

    • Note: this also probably solves the other 2 issues and should improve our general einsum handling too.
    • Idea: Let's form up a small "tiger team" for "generalized matmul codegen" with minimally Thomas, Mahesh, and Nicolas for guiding us on how to best apply the structured codegen tech to this problem. We will want to agree on 1) what is the canonical target-independent higher-D flow-level representation we want for this code 2) given that representation, how do we codegen it efficiently across targets.
  2. The broadcast fusion issue -- we can't currently even benchmark the num_sentences > 1 case and get useful numbers due to the broadcast fusion issue (IREE's performance is so bad it isn't even meaningful to measure). I am rewriting my example in JAX and maybe we will get lucky and it won't have such broadcasts. @powderluv -- shouldn't we have hit this big performance issue when doing HF transformers through Torch-MLIR with >1 sentences batched?

  3. I think we need to find a way to get the biasadd fusions. I ran my example with sequence length 512 instead of 128 and the biasadd and transpose fusions are still 15% of the runtime (25% in the original example). I will try to gather numbers for larger attention layers, where I suspect that the larger matmuls will dominate the time more, making the fusion less important.

    • That said, we are trying to build a state of the art ML compiler and biasadd fusion is table stakes. To be clear, the most straightforward PyTorch eager execution gets this biasadd fusion -- currently other than removing dispatching overhead we are still sub-PyTorch-eager on this workload (more dispatches AND slower dispatches). With sequence length 512 we are spending 254us compute time and PyTorch eager is 183us (matmul compute time is 193us for IREE vs 144us for PyTorch, so a significant fraction is still non-matmul cost). Once we get into training and large inference workloads IREE's dispatch efficiency won't hide the underlying performance characteristics here.
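The timing claim in point 3 can be checked with simple arithmetic. The sketch assumes that "non-matmul" time is just total compute time minus matmul time, ignoring any overlap. The helper name is hypothetical.

```python
def non_matmul_breakdown(total_us, matmul_us):
    """Split a measured compute time into matmul and everything else."""
    other_us = total_us - matmul_us
    return other_us, other_us / total_us


# Numbers quoted above for sequence length 512.
iree_other, iree_frac = non_matmul_breakdown(254, 193)
eager_other, eager_frac = non_matmul_breakdown(183, 144)

gap_total = 254 - 183               # how far IREE is behind PyTorch eager overall
gap_matmul = 193 - 144              # part of the gap explained by slower matmuls
gap_other = gap_total - gap_matmul  # part of the gap from non-matmul work

print(f"IREE non-matmul: {iree_other} us ({iree_frac:.0%})")
print(f"eager non-matmul: {eager_other} us ({eager_frac:.0%})")
print(f"gap: {gap_total} us = {gap_matmul} us matmul + {gap_other} us other")
```

Under that assumption, IREE spends 61 us (24%) outside matmuls versus 39 us (21%) for eager. Of the 71 us gap, 22 us is non-matmul.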

It sounds like the higher-D matmul workstream I alluded to in point 1 would potentially solve all 3 issues directly, and we should invest in that.

stellaraccident commented 1 year ago

If we come up with a viable plan for #1, that sounds good. Curious how much work we think it entails.

nicolasvasilache commented 1 year ago

The transposes/reshapes come from the fact that the first 3 matmuls contract over embedding_size, while the other two treat num_heads as a batch dimension, which requires reshapes/transposes (since we are calling into named ops with fixed dimension orderings). If we used linalg.generic ops with reduction dimensions to model higher-D matmul-like ops (incorporating transposes and multiple reduction dims), then we could absorb all the transposes/reshapes quite cleanly, since we could properly model embedding_size as two dimensions, num_heads and size_of_head.

This is spot on: one of the key observations from the TC days was that we should go higher-D whenever it makes sense, and this certainly seems to be one of those cases.

This is the inverse of the conventional wisdom of always trying to fit the high-level named op that runs fast. It does, however, require a good story for the codegen aspect. I believe it is within reach for the generic parts, as you mention in your post.

My biggest uncertainty relates to doing something good with the pack/unpack part, but I believe there is consensus within IREE that this is not a very hard problem?

MaheshRavishankar commented 1 year ago

Concretely I think we have the following action items:

  1. The named ops are really biting us in terms of the embedding_size = num_heads * size_of_head dimension split and forcing reshape/transposes. We already solve this for elementwise ops by using higher-D representations and indexing map permutations to absorb reshape/transposes, and I think we need a similar solution here. I think Nicolas and Thomas are iterating on some tech that could potentially pave the way to reliable codegen of arbitrary linalg.generic that boils down to stamping out an inner matmul primitive. This is probably not a "1 week project" but I think if we put our heads together we can solve it (and probably boost our matmul performance along the way too).

+1. Going to higher dimensions and propagating reshapes is definitely the fix here (Nicolas said the same above). This is not low-hanging fruit, though. I agree with Nicolas that its relationship with going to higher-dimensional ops for data layout is unclear.

    • Note: this also probably solves the other 2 issues and should improve our general einsum handling too.
    • Idea: Let's form up a small "tiger team" for "generalized matmul codegen" with minimally Thomas, Mahesh, and Nicolas for guiding us on how to best apply the structured codegen tech to this problem. We will want to agree on 1) what is the canonical target-independent higher-D flow-level representation we want for this code 2) given that representation, how do we codegen it efficiently across targets.
  2. The broadcast fusion issue -- we can't currently even benchmark the num_sentences > 1 case and get useful numbers due to the broadcast fusion issue (IREE's performance is so bad it isn't even meaningful to measure). I am rewriting my example in JAX and maybe we will get lucky and it won't have such broadcasts. @powderluv -- shouldn't we have hit this big performance issue when doing HF transformers through Torch-MLIR with >1 sentences batched?

We should drill down here a bit. Why is the performance so terrible? I am not expecting this to be ideal, but it shouldn't tank so much that we can't even get meaningful numbers out of this.

  3. I think we need to find a way to get the biasadd fusions. I ran my example with sequence length 512 instead of 128 and the biasadd and transpose fusions are still 15% of the runtime (25% in the original example). I will try to gather numbers for larger attention layers, where I suspect that the larger matmuls will dominate the time more, making the fusion less important.

    • That said, we are trying to build a state of the art ML compiler and biasadd fusion is table stakes. To be clear, the most straightforward PyTorch eager execution gets this biasadd fusion -- currently other than removing dispatching overhead we are still sub-PyTorch-eager on this workload (more dispatches AND slower dispatches). With sequence length 512 we are spending 254us compute time and PyTorch eager is 183us (matmul compute time is 193us for IREE vs 144us for PyTorch, so a significant fraction is still non-matmul cost). Once we get into training and large inference workloads IREE's dispatch efficiency won't hide the underlying performance characteristics here.

From my perspective, the biasadd fusion is a solved problem modulo bugs. This is the gist that I am seeing with iree-flow-enable-aggressive-fusion on. Looking at this again, the last bias-add and batch-matmul should actually be fused. There seems to be a bug here. I can take a look at this when I get some time. (@pzread FYI in case you get to this before me).

It sounds like the higher-D matmul workstream I alluded to in point 1 would potentially solve all 3 issues directly, and we should invest in that.

Agreed, but this is not low-hanging fruit.

MaheshRavishankar commented 1 year ago

FWIW, in Torch-MLIR we had this kind of gridlock requiring all backends to support something before one backend could make progress, and it was quite problematic. We solved that in Torch-MLIR by allowing backends some limited control over the target-independent lowering process. Over time we are going to grow more and more backends with different sets of maintainers and it is going to be increasingly problematic to atomically update all of them in order to make progress.

I don't think I subscribe to this. One of the big value-adds of IREE for me is that you should be able to switch from a CPU backend to a GPU backend without (a) miscompilations and (b) significant performance regressions. We are in pretty good shape AFAICS w.r.t. (a), but not w.r.t. (b). We need fewer special-cased paths and more generalized codegen paths to make (b) a reality. I think in the long run this is the killer feature that will set us apart. You can run and test your model on CPU (with reasonable performance), then switch to a GPU and deploy if you need to. That is a huge win in terms of ease of development and deployment. The issues we are having now are because we have invested more in specialized codegen paths and not as much in generalized codegen paths.

ThomasRaoux commented 1 year ago

Thanks for the detailed explanation @silvasean.

FWIW, in Torch-MLIR we had this kind of gridlock requiring all backends to support something before one backend could make progress, and it was quite problematic. We solved that in Torch-MLIR by allowing backends some limited control over the target-independent lowering process. Over time we are going to grow more and more backends with different sets of maintainers and it is going to be increasingly problematic to atomically update all of them in order to make progress.

This is a common problem in IREE from where I stand, and it often leaves us in deadlock situations. I think this is definitely something we should address.

We need fewer special-cased paths and more generalized codegen paths to make (b) a reality. I think in the long run this is the killer feature that will set us apart. You can run and test your model on CPU (with reasonable performance), then switch to a GPU and deploy if you need to. That is a huge win in terms of ease of development and deployment. The issues we are having now are because we have invested more in specialized codegen paths and not as much in generalized codegen paths.

We can get to functional portability with a generalized codegen algorithm, but I don't think we will get to full performance portability, where enabling transformations in the earlier layers always translates to the same performance benefits across backends. I also don't see how this will allow us to scale to more backends.

The named ops are really biting us in terms of the embedding_size = num_heads * size_of_head dimension split and forcing reshape/transposes. We already solve this for elementwise ops by using higher-D representations and indexing map permutations to absorb reshape/transposes, and I think we need a similar solution here. I think Nicolas and Thomas are iterating on some tech that could potentially pave the way to reliable codegen of arbitrary linalg.generic that boils down to stamping out an inner matmul primitive. This is probably not a "1 week project" but I think if we put our heads together we can solve it (and probably boost our matmul performance along the way too).

  • Note: this also probably solves the other 2 issues and should improve our general einsum handling too.
  • Idea: Let's form up a small "tiger team" for "generalized matmul codegen" with minimally Thomas, Mahesh, and Nicolas for guiding us on how to best apply the structured codegen tech to this problem. We will want to agree on 1) what is the canonical target-independent higher-D flow-level representation we want for this code 2) given that representation, how do we codegen it efficiently across targets.

This is something I can prototype. Codegen already handles generic ops via the ContractionOpInterface the same way it handles matmul. One limitation is that IREE fusion handles named ops differently from generic ops; however, that can easily be worked around for this case. The main challenge with generic ops is that it is much harder to validate all the potential cases and figure out how to map them to standard matmul configurations. In this case there are two different things that can be tried: 1) generalizing the matmuls and fusing the broadcasts/transposes that way, after which codegen should be able to handle those generics using the matmul pipeline; 2) generalizing and propagating the reshapes to the edges of the graph. I'm not sure the latter will help significantly here, but it would be interesting to try as well.
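Here is a rough sketch of how a contraction's loop dimensions can be classified into batch/M/N/K from the dims each operand is indexed by. It mirrors the idea behind Linalg's ContractionOpInterface, not its actual API. A real check would also verify that the maps are projected permutations and that the body is a multiply-accumulate. The helper name is hypothetical.

```python
def classify_contraction_dims(lhs, rhs, out):
    """Classify loop dims of a contraction.

    lhs, rhs, out: sequences of loop-dim names indexing each operand
    (a simplified stand-in for indexing maps).
    """
    lhs_s, rhs_s, out_s = set(lhs), set(rhs), set(out)
    return {
        "batch": [d for d in out if d in lhs_s and d in rhs_s],
        "m": [d for d in out if d in lhs_s and d not in rhs_s],
        "n": [d for d in out if d in rhs_s and d not in lhs_s],
        "k": [d for d in lhs if d in rhs_s and d not in out_s],
    }


# Attention scores kept higher-D: q[b, i, h, d] * k[b, j, h, d] -> s[b, h, i, j]
print(classify_contraction_dims("bihd", "bjhd", "bhij"))
# Linear layer with an un-materialized weight broadcast: x[b, m, r] * w[r, n] -> y[b, m, n]
print(classify_contraction_dims("bmr", "rn", "bmn"))
```

In the second case the would-be batch dim `b` becomes an extra M dim. The op is then really a single [b*m, r] x [r, n] matmul, which is why absorbing the broadcast into a generic contraction avoids replicating the weight.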

MaheshRavishankar commented 1 year ago

We can get to functional portability with a generalized codegen algorithm, but I don't think we will get to full performance portability, where enabling transformations in the earlier layers always translates to the same performance benefits across backends. I also don't see how this will allow us to scale to more backends.

What other backends? Having an LLVM-based CPU backend, an LLVM-based GPU backend, and a SPIR-V backend should cover most cases. We won't get full performance portability, but you start with a higher floor. For someone who wants to deploy ML models on their hardware, one of these paths should give them a better starting position.

ThomasRaoux commented 1 year ago

What other backends? Having an LLVM-based CPU backend, an LLVM-based GPU backend, and a SPIR-V backend should cover most cases. We won't get full performance portability, but you start with a higher floor. For someone who wants to deploy ML models on their hardware, one of these paths should give them a better starting position.

I assume that if IREE/OpenXLA wants to become a platform, it needs to support adding backends for other hardware (like accelerators) with limited dependencies, but I could be wrong. Even with those 3 backends there are quite a few targets, and moving in lockstep without regressing performance, even if desirable, is going to be hard in practice.

silvasean commented 1 year ago

We should drill down here a bit. Why is the performance so terrible? I am not expecting this to be ideal, but it shouldn't tank so much that we can't even get meaningful numbers out of this.

I looked at this in Nsight --

This really is not a fair comparison because IREE broadcasts the weight matrix of the first 3 linear layers by a factor of num_sentences, so the resulting linalg.batch_matmul ops have num_sentences (e.g. 16, 128) times less arithmetic intensity (and those broadcasts are expensive too). I tried with --iree-flow-enable-aggressive-fusion but it didn't meaningfully change the profile (I don't see the addf's fused in the IR either) -- perhaps a more recent IREE is needed? (my version is 20230213.429)
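A back-of-the-envelope traffic model shows how the replicated weight changes arithmetic intensity for one linear layer. It is a hypothetical helper and makes these assumptions:
- every operand crosses memory exactly once;
- caches and tiling are ignored;
- the cost of the broadcast dispatch itself is ignored.

```python
def linear_layer_traffic(num_sentences, seq_len, emb, broadcast_weight, bytes_per_elem=4):
    """FLOPs, bytes moved and FLOPs/byte for y[b, s, :] = x[b, s, :] @ W[emb, emb]."""
    flops = 2 * num_sentences * seq_len * emb * emb
    activation_elems = 2 * num_sentences * seq_len * emb  # read x, write y
    weight_elems = emb * emb * (num_sentences if broadcast_weight else 1)
    bytes_moved = (activation_elems + weight_elems) * bytes_per_elem
    return flops, bytes_moved, flops / bytes_moved


for bcast in (False, True):
    flops, nbytes, intensity = linear_layer_traffic(16, 128, 768, bcast)
    print(f"broadcast={bcast}: {flops} FLOPs, {nbytes} bytes, {intensity:.1f} FLOPs/byte")
```

Take num_sentences = 16, seq_len = 128, and emb = 768. This idealized model gives about 161.7 vs 48.0 FLOPs/byte, roughly a 3.4x drop. The drop approaches num_sentences when weight traffic dominates, that is, when emb is much larger than 2 * num_sentences * seq_len.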

I don't think I subscribe to this. One of the big value-adds of IREE for me is that you should be able to switch from a CPU backend to a GPU backend without (a) miscompilations and (b) significant performance regressions. We are in pretty good shape AFAICS w.r.t. (a), but not w.r.t. (b). We need fewer special-cased paths and more generalized codegen paths to make (b) a reality. I think in the long run this is the killer feature that will set us apart. You can run and test your model on CPU (with reasonable performance), then switch to a GPU and deploy if you need to. That is a huge win in terms of ease of development and deployment. The issues we are having now are because we have invested more in specialized codegen paths and not as much in generalized codegen paths.

This is definitely valuable but I think it is also valuable to have a "glide path" for new backends to ramp up, without having to atomically handle all the possibilities from the get-go. Speaking for Torch-MLIR, we have had to accept that the backends are always going to be at different levels of completeness (op support, static vs dynamic shape support, etc.). Torch-MLIR only has 3 backends but I think IREE is destined to have even more (especially counting out-of-tree backends).

In LLVM, if your backend hasn't yet implemented a native popcnt instruction, you just tell the instruction selector to decompose it into more primitive ops. When you eventually have the engineering resources to implement native support for popcnt, you tell it to give you popcnt natively. In fact, this type of "which decompositions do you want" knob is the only knob we let backends control in Torch-MLIR. "Which fusions can you codegen" would be a potential analog in IREE.
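To make the LLVM analogy concrete: when a target lacks a native population count, the operation can be expanded into shifts, masks, adds and a multiply. Below is a Python sketch of the classic SWAR expansion for 32-bit values. It is illustrative only and is not LLVM's code.

```python
def popcount32(x):
    """Population count of a 32-bit value using only shifts, masks, adds and a multiply."""
    x &= 0xFFFFFFFF
    x = x - ((x >> 1) & 0x55555555)                  # per-2-bit-field counts
    x = (x & 0x33333333) + ((x >> 2) & 0x33333333)   # per-4-bit-field counts
    x = (x + (x >> 4)) & 0x0F0F0F0F                  # per-byte counts
    return ((x * 0x01010101) & 0xFFFFFFFF) >> 24     # sum of the four bytes lands in the top byte


print(popcount32(0xDEADBEEF), bin(0xDEADBEEF).count("1"))
```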

benvanik commented 1 year ago

"which fusions can you codegen" would be a potential analog in IREE.

Strongly disagree with that - "which you can codegen efficiently" maybe, but not which you accept. An IREE codegen backend must be able to generate code for any input but is not required to generate fast code - it's like clang/llvm needing to accept all C input even if some backends can't generate optimal code. If needed we can surface this better to users via performance warnings.

Random people implementing special cases for special hardware may be able to make calls around what they can/can't generate but that's not something that we should be assuming in the core compiler for targets that can absolutely run all of this code. If it's possible to emulate a RISC-V CPU running Linux in a Unity shader (https://blog.pimaker.at/texts/rvc1/) we should be able to generate some loops - if we can't then we aren't a compiler but just a janky templating mechanism like most other ML frameworks :)

mfuntowicz commented 1 year ago

Thanks @silvasean for all the insights on this specific BERT case. We (at 🤗) are also very interested in contributing to making these optimizations available. To give you some insight from our side too: we have roughly 30-40k "BERT-ish" models available on our hub, so this could benefit a lot of people, especially since BERT is still our most popular model 🙂.

From the script you shared to reproduce the attention layer, I wanted to point out that you're not using the attention_mask input to the model/attention: https://github.com/silvasean/iree-benchmarking/blob/main/attention_layer/attention_layer.py#L102

In the case of batch = 1 it doesn't have any impact, but for batch > 1 it is important to include this input. The sequences in a batch will most likely not all have the same length (in tokens), so you need to keep track of the padding to avoid including padded tokens in the attention softmax(es).

Here, in the script, the attention_mask will be filled by the BERT forward call with torch.ones, which might be treated as a constant and folded away (constant + matmul with ones) (ref in transformers' BERT model).

This also makes it possible to have "smarter" softmax implementations such as masked_softmax(x, mask), which avoids extra expensive exp calls (I have seen multiple implementations of this in various HW compilers, for instance).

FWIW, at 🤗 we have quite a lot of customers of our inference products who send a single input (batch = 1), especially for generation tasks (GPT-like, Whisper, etc.) that sequentially generate multiple tokens; batching these requests is often a tricky challenge.
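Here is a minimal NumPy sketch of the masked softmax idea. The function is hypothetical and assumes every row has at least one unmasked position. Padded positions get exactly zero probability. NumPy still evaluates exp everywhere, whereas a dedicated kernel could skip those exp calls.

A common alternative is an additive mask: add a large negative value to the padded scores before a regular softmax. This gives essentially the same probabilities.

```python
import numpy as np


def masked_softmax(scores, mask):
    """Softmax over the last axis that ignores positions where mask == 0 (padding).

    `mask` must be broadcastable to `scores`, e.g. [batch, 1, 1, seq] for attention.
    """
    mask = mask.astype(bool)
    s = np.where(mask, scores, -np.inf)           # padded scores can never win the max
    s = s - s.max(axis=-1, keepdims=True)         # stable softmax over valid entries only
    e = np.where(mask, np.exp(s), 0.0)            # exp(-inf) is 0; zero padded slots explicitly
    return e / e.sum(axis=-1, keepdims=True)


scores = np.array([[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0]])
mask = np.array([[1, 1, 0, 0], [1, 1, 1, 1]])     # first sequence has 2 real tokens
print(masked_softmax(scores, mask))
```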

allieculp commented 1 year ago

@silvasean @MaheshRavishankar @benvanik This P1 has gone a bit stale. Just checking if we should downgrade in priority - if not we can leave as is.

MaheshRavishankar commented 1 year ago

It's on my radar... but not something I can get to right now (maybe in a couple of weeks). So your call.

allieculp commented 1 year ago

Thanks, leaving as-is and we can circle back.

MaheshRavishankar commented 1 year ago

Got some time to look into this... TL;DR: the first step here is to evaluate whether backends can handle fusion of transposes/broadcasts with their consumers effectively. This changes the input access patterns and will need changes to the lowering strategy to be handled well. So ideally, before we think about adjusting the fusion heuristics, we need to evaluate: 1) Is there a variant of the GEMM (A B, A transpose(B), transpose(A) B, or transpose(A) transpose(B)) that is preferred on a backend? If there is a clear winner, we need to ensure that each GEMM (or batch GEMM) is in that form by introducing transposes and then propagating/fusing the transposes. 2) Can backend code generation handle fusing broadcasts with consumer GEMMs/batched GEMMs?
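To make question 1) concrete: the four GEMM variants compute the same product and differ only in how each operand is indexed, which is why fusing a producer transpose into a GEMM amounts to a change of access pattern rather than extra arithmetic. A plain-Python sketch (names are illustrative, not IREE APIs); which variant is fastest on a given backend is exactly what needs to be measured.

```python
def matmul_variant(a, b, trans_a=False, trans_b=False):
    """C = op(A) @ op(B), where op() optionally transposes its argument.

    The transpose is never materialized: it is folded into how elements
    of a and b are indexed inside the loop nest.
    """
    def elem_a(i, k):
        return a[k][i] if trans_a else a[i][k]

    def elem_b(k, j):
        return b[j][k] if trans_b else b[k][j]

    m = len(a[0]) if trans_a else len(a)
    kdim = len(a) if trans_a else len(a[0])
    n = len(b) if trans_b else len(b[0])
    return [[sum(elem_a(i, k) * elem_b(k, j) for k in range(kdim))
             for j in range(n)]
            for i in range(m)]


def transpose(x):
    """Materialized transpose, used only to build inputs for the variants."""
    return [list(row) for row in zip(*x)]


A = [[1, 2], [3, 4], [5, 6]]   # 3x2
B = [[1, 0, 2], [0, 1, 3]]     # 2x3
# All four variants agree once the stored operands are laid out accordingly.
ab = matmul_variant(A, B)
atb = matmul_variant(transpose(A), B, trans_a=True)
abt = matmul_variant(A, transpose(B), trans_b=True)
atbt = matmul_variant(transpose(A), transpose(B), trans_a=True, trans_b=True)
```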

Once we have this information, it can feed into what fusion decisions we want to make (and that part is fairly simple). To demonstrate how easily we can add these to the fusion heuristics, https://github.com/MaheshRavishankar/iree/commit/73a3889571310a1f73545747840f782e3d481eb6 is the change that fuses transpose/broadcast with its consumers. The dispatch regions formed after that are:

module {
  func.func @forward(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    %cst = arith.constant dense_resource<__elided__> : tensor<768x768xf32>
    %cst_0 = arith.constant dense_resource<__elided__> : tensor<768xf32>
    %cst_1 = arith.constant 0.000000e+00 : f32
    %cst_2 = arith.constant 8.000000e+00 : f32
    %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<1x12x768xf32>
    %1 = tensor.empty() : tensor<768x768xf32>
    %2 = tensor.empty() : tensor<12x768xf32>
    %collapsed = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<1x12x768xf32> into tensor<12x768xf32>
    %3 = linalg.fill ins(%cst_1 : f32) outs(%2 : tensor<12x768xf32>) -> tensor<12x768xf32>
    %4 = flow.dispatch.region -> (tensor<12x768xf32>) {
      %15 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : tensor<768x768xf32>) outs(%1 : tensor<768x768xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
      } -> tensor<768x768xf32>
      %16 = linalg.matmul ins(%collapsed, %15 : tensor<12x768xf32>, tensor<768x768xf32>) outs(%3 : tensor<12x768xf32>) -> tensor<12x768xf32>
      %17 = linalg.generic {indexing_maps = [#map, #map2, #map], iterator_types = ["parallel", "parallel"]} ins(%16, %cst_0 : tensor<12x768xf32>, tensor<768xf32>) outs(%2 : tensor<12x768xf32>) {
      ^bb0(%in: f32, %in_6: f32, %out: f32):
        %18 = arith.addf %in, %in_6 : f32
        linalg.yield %18 : f32
      } -> tensor<12x768xf32>
      flow.return %17 : tensor<12x768xf32>
    }
    %expanded = tensor.expand_shape %4 [[0], [1, 2]] : tensor<12x768xf32> into tensor<12x12x64xf32>
    %5 = tensor.empty() : tensor<12x12x64xf32>
    %6 = tensor.empty() : tensor<768x12xf32>
    %7 = flow.dispatch.region -> (tensor<768x12xf32>) {
      %15 = linalg.generic {indexing_maps = [#map1, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<12x768xf32>) outs(%6 : tensor<768x12xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
      } -> tensor<768x12xf32>
      flow.return %15 : tensor<768x12xf32>
    }
    %expanded_3 = tensor.expand_shape %7 [[0, 1], [2]] : tensor<768x12xf32> into tensor<12x64x12xf32>
    %8 = tensor.empty() : tensor<12x12x12xf32>
    %9 = linalg.fill ins(%cst_1 : f32) outs(%8 : tensor<12x12x12xf32>) -> tensor<12x12x12xf32>
    %10:2 = flow.dispatch.region -> (tensor<12x12x12xf32>, tensor<12x12x64xf32>) {
      %15 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded : tensor<12x12x64xf32>) outs(%5 : tensor<12x12x64xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
      } -> tensor<12x12x64xf32>
      %16 = linalg.batch_matmul ins(%15, %expanded_3 : tensor<12x12x64xf32>, tensor<12x64x12xf32>) outs(%9 : tensor<12x12x12xf32>) -> tensor<12x12x12xf32>
      %17 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%16 : tensor<12x12x12xf32>) outs(%8 : tensor<12x12x12xf32>) {
      ^bb0(%in: f32, %out: f32):
        %18 = arith.divf %in, %cst_2 : f32
        linalg.yield %18 : f32
      } -> tensor<12x12x12xf32>
      flow.return %17, %15 : tensor<12x12x12xf32>, tensor<12x12x64xf32>
    }
    %11 = flow.dispatch.region -> (tensor<12x12x12xf32>) {
      %15 = iree_linalg_ext.softmax dimension(2) ins(%10#0 : tensor<12x12x12xf32>) outs(%8 : tensor<12x12x12xf32>) -> tensor<12x12x12xf32>
      flow.return %15 : tensor<12x12x12xf32>
    }
    %12 = linalg.fill ins(%cst_1 : f32) outs(%5 : tensor<12x12x64xf32>) -> tensor<12x12x64xf32>
    %13 = flow.dispatch.region -> (tensor<12x12x64xf32>) {
      %15 = linalg.batch_matmul ins(%11, %10#1 : tensor<12x12x12xf32>, tensor<12x12x64xf32>) outs(%12 : tensor<12x12x64xf32>) -> tensor<12x12x64xf32>
      %16 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15 : tensor<12x12x64xf32>) outs(%5 : tensor<12x12x64xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
      } -> tensor<12x12x64xf32>
      flow.return %16 : tensor<12x12x64xf32>
    }
    %expanded_4 = tensor.expand_shape %13 [[0, 1], [2], [3]] : tensor<12x12x64xf32> into tensor<1x12x12x64xf32>
    %collapsed_5 = tensor.collapse_shape %expanded_4 [[0], [1], [2, 3]] : tensor<1x12x12x64xf32> into tensor<1x12x768xf32>
    %14 = hal.tensor.export %collapsed_5 : tensor<1x12x768xf32> -> !hal.buffer_view
    return %14 : !hal.buffer_view
  }
}
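The grouping in the first region above (a weight transpose producer, the matmul, and the bias-add consumer in one flow.dispatch.region) can be modeled with a toy greedy pass. This is a hypothetical Python sketch of the idea only: the op kinds, the single-user rule, and the function name are assumptions, and it is not how the linked commit or IREE's C++ dispatch region formation is actually implemented.

```python
from dataclasses import dataclass, field

# Hypothetical op kinds for this toy model only.
FUSABLE_PRODUCERS = {"transpose", "broadcast"}
GEMM_LIKE = {"matmul", "batch_matmul"}


@dataclass
class Op:
    name: str                                     # SSA-style result name, e.g. "%16"
    kind: str                                     # e.g. "transpose", "matmul", "elementwise"
    operands: list = field(default_factory=list)  # names of values this op reads


def form_dispatch_regions(ops, fuse_producers=True):
    """Greedy toy region formation rooted at GEMM-like ops.

    A transpose/broadcast producer joins its consumer GEMM's region when
    fuse_producers is set and the GEMM is its only user; elementwise users
    of the GEMM join the region too. Every other op gets its own region.
    """
    users = {}
    for op in ops:
        for v in op.operands:
            users.setdefault(v, []).append(op.name)
    by_name = {op.name: op for op in ops}
    claimed = set()
    regions = []
    for op in ops:
        if op.kind not in GEMM_LIKE:
            continue
        region = [op.name]
        if fuse_producers:
            for v in op.operands:
                producer = by_name.get(v)
                if (producer is not None and producer.kind in FUSABLE_PRODUCERS
                        and users.get(v) == [op.name] and v not in claimed):
                    region.insert(0, v)
        for u in users.get(op.name, []):
            if by_name[u].kind == "elementwise" and u not in claimed and u not in region:
                region.append(u)
        claimed.update(region)
        regions.append(region)
    # Ops not pulled into any GEMM region become standalone dispatches.
    regions.extend([op.name] for op in ops if op.name not in claimed)
    return regions


# First region of the IR above: weight transpose, matmul, bias add.
ops = [
    Op("%15", "transpose", ["%cst"]),
    Op("%16", "matmul", ["%collapsed", "%15"]),
    Op("%17", "elementwise", ["%16", "%cst_0"]),
]
```

With fuse_producers=False the transpose ends up in a dispatch of its own, which is the kind of standalone transpose discussed below.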

So that's getting closer to "the best you can do", except for a single transpose that sits in a dispatch by itself... I looked into that a bit more and it might actually be a missed opportunity here, but it needs a bit more work to see how to remove it. If that is a blocker we can look into it, but it seems we are reaching the point of diminishing returns (it's a single transpose, which should ideally run at memcpy speed, whereas the others are 3D batch matmuls). Also, to restate: any heuristic is going to have gaps, so it's a matter of weighing the complexity of the heuristics against the benefit that complexity gives.
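A rough back-of-envelope count for the shapes in the IR above (sequence length 12, f32), assuming 2 FLOPs per multiply-add and counting one read plus one write for the transpose, while ignoring caches and launch overhead:

```python
F32_BYTES = 4


def gemm_flops(batch, m, n, k):
    """FLOPs of a (batched) GEMM, counting each multiply-add as 2 FLOPs."""
    return 2 * batch * m * n * k


# Shapes taken from the dispatch regions above.
proj_flops = gemm_flops(1, 12, 768, 768)  # 12x768 @ 768x768 projection
bmm1_flops = gemm_flops(12, 12, 12, 64)   # 12x12x64 @ 12x64x12 batch_matmul
bmm2_flops = gemm_flops(12, 12, 64, 12)   # 12x12x12 @ 12x12x64 batch_matmul
# Standalone 12x768 -> 768x12 transpose: read once, write once.
transpose_bytes = 2 * 12 * 768 * F32_BYTES
```

The standalone transpose moves about 72 KiB in total and does no arithmetic at all.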

I think a good starting point is to try inputs of this form:

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
module {
  func.func @forward(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    %cst = arith.constant dense_resource<__elided__> : tensor<768x768xf32>
    %cst_0 = arith.constant dense_resource<__elided__> : tensor<768xf32>
    %cst_1 = arith.constant 0.000000e+00 : f32
    %cst_2 = arith.constant 8.000000e+00 : f32
    %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<1x12x768xf32>
    %1 = tensor.empty() : tensor<768x768xf32>
    %2 = tensor.empty() : tensor<12x768xf32>
    %collapsed = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<1x12x768xf32> into tensor<12x768xf32>
    %3 = linalg.fill ins(%cst_1 : f32) outs(%2 : tensor<12x768xf32>) -> tensor<12x768xf32>
    %4 = flow.dispatch.region -> (tensor<12x768xf32>) {
      %15 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : tensor<768x768xf32>) outs(%1 : tensor<768x768xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
      } -> tensor<768x768xf32>
      %16 = linalg.matmul ins(%collapsed, %15 : tensor<12x768xf32>, tensor<768x768xf32>) outs(%3 : tensor<12x768xf32>) -> tensor<12x768xf32>
      flow.return %16 : tensor<12x768xf32>
    }
    %5 = hal.tensor.export %4 : tensor<12x768xf32> -> !hal.buffer_view
    return %5 : !hal.buffer_view
  }
}

and other transpose variants, to find out whether there is a preferred layout. That will inform what we need to do next.

If this is not something to push on, then we can reduce the priority of this one until we pick it up again.

allieculp commented 1 year ago

@jpienaar @mattwalsh @julianwa From today's discussion, this potentially needs an epic/task list/next steps.

allieculp commented 1 year ago

@MaheshRavishankar Where does this issue land in the einsum discussion we had today?

pjannaty commented 1 year ago

@MaheshRavishankar, @Young768 (Donglin) is interested in working on this. Is anyone from your side actively working on this?

cc @julianwa @jpienaar

MaheshRavishankar commented 1 year ago

This one is going to be a bit hard. It looks like a fusion issue... but it is really about how general an einsum the backends want to handle. If we just want to handle fusion of transposes/broadcasts with batch_matmul, then https://github.com/MaheshRavishankar/iree/commit/73a3889571310a1f73545747840f782e3d481eb6 already does that at the Flow level... To restate what I said earlier, a better starting point is to use various fused dispatches as input, like:

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
module {
  func.func @forward(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    %cst = arith.constant dense_resource<__elided__> : tensor<768x768xf32>
    %cst_0 = arith.constant dense_resource<__elided__> : tensor<768xf32>
    %cst_1 = arith.constant 0.000000e+00 : f32
    %cst_2 = arith.constant 8.000000e+00 : f32
    %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<1x12x768xf32>
    %1 = tensor.empty() : tensor<768x768xf32>
    %2 = tensor.empty() : tensor<12x768xf32>
    %collapsed = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<1x12x768xf32> into tensor<12x768xf32>
    %3 = linalg.fill ins(%cst_1 : f32) outs(%2 : tensor<12x768xf32>) -> tensor<12x768xf32>
    %4 = flow.dispatch.region -> (tensor<12x768xf32>) {
      %15 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : tensor<768x768xf32>) outs(%1 : tensor<768x768xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
      } -> tensor<768x768xf32>
      %16 = linalg.matmul ins(%collapsed, %15 : tensor<12x768xf32>, tensor<768x768xf32>) outs(%3 : tensor<12x768xf32>) -> tensor<12x768xf32>
      flow.return %16 : tensor<12x768xf32>
    }
    %5 = hal.tensor.export %4 : tensor<12x768xf32> -> !hal.buffer_view
    return %5 : !hal.buffer_view
  }
}

where the linalg.generic is a transpose or broadcast; building out backend support for such dispatches would be useful. One place to start is to take the snippet above as input and compile it with the CUDA backend to see what it gives...

MaheshRavishankar commented 1 year ago

Expanding on this a bit more (and also adding some relevant details on the related issue of einsum handling), there are two ways we could approach this problem. 1) We generalize all our GEMM-like operations into linalg.generic and use elementwise operation fusion (here) to fuse transposes/broadcasts into the linalg.generic. This converts the problem into general handling of einsum-like operations in the backends. 2) Alternatively, we don't generalize the named ops (like linalg.matmul/linalg.batch_matmul) into linalg.generic, and instead extend dispatch region formation to put producer transpose/broadcast-like operations and consumer linalg.matmul/linalg.batch_matmul operations in the same dispatch. This would allow the backends to use tile-and-fuse on these operations without having to generalize the named ops. This is basically what https://github.com/MaheshRavishankar/iree/commit/73a3889571310a1f73545747840f782e3d481eb6 does. I can commit that and guard it behind a flag for exploration. I think on the CPU side this might be "easy" to adapt to (at least easier than on the GPU side).
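Approach 1) in miniature: in a linalg.generic, each operand has an indexing map from loop indices to element coordinates, and elementwise fusion of a producer transpose composes the transpose's permutation into the consumer's operand map. The Python sketch below uses plain functions as stand-ins for affine maps (all names hypothetical); after fusion the op no longer has the canonical matmul maps, which is the "general einsum" the backends would then need to handle.

```python
from itertools import product


def generic_matmul(a, b, m, n, k, a_map, b_map):
    """Toy linalg.generic-style matmul over loops (d0, d1, d2) = (i, j, k).

    a_map / b_map play the role of indexing_maps: they turn loop indices
    into (row, col) coordinates of each operand. C[i][j] += A[..] * B[..].
    """
    c = [[0] * n for _ in range(m)]
    for i, j, kk in product(range(m), range(n), range(k)):
        ar, ac = a_map(i, j, kk)
        br, bc = b_map(i, j, kk)
        c[i][j] += a[ar][ac] * b[br][bc]
    return c


def a_map(i, j, k):
    """Canonical matmul map for A: (d0, d1, d2) -> (d0, d2)."""
    return (i, k)


def b_map(i, j, k):
    """Canonical matmul map for B: (d0, d1, d2) -> (d2, d1)."""
    return (k, j)


def compose_transpose(operand_map):
    """Fold a producer transpose (d0, d1) -> (d1, d0) into an operand map."""
    def fused(*loops):
        return tuple(reversed(operand_map(*loops)))
    return fused


def transpose(x):
    return [list(row) for row in zip(*x)]


a = [[1, 2], [3, 4]]
bt = [[5, 6], [7, 8]]  # weights stored transposed, as in the IR above
# Unfused: materialize transpose(bt), then run the canonical matmul.
unfused = generic_matmul(a, transpose(bt), 2, 2, 2, a_map, b_map)
# Fused: read bt directly through the composed (permuted) map.
fused = generic_matmul(a, bt, 2, 2, 2, a_map, compose_transpose(b_map))
```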

So in terms of a task list here, I'd suggest this incremental approach:

- [ ] Use inputs with predefined `flow.dispatch.region`s that contain producer transpose/broadcast-like operations and consumer matmul/batch-matmul operations (like the example https://github.com/openxla/iree/issues/12214#issuecomment-1541018827) to flush out backend issues.
- [ ] Commit https://github.com/MaheshRavishankar/iree/commit/73a3889571310a1f73545747840f782e3d481eb6 and make the default dispatch region formation enable this path.

Young768 commented 1 year ago

After folding the empty tensors (which hold the tiled output), the backend issue seems to be fixed. I compared the performance of matmul, fused transpose+matmul, and unfused transpose+matmul. The IRs are:

matmul:

module {
  func.func @forward(%arg0: tensor<1x128x768xf32>, %arg1 : tensor<768x768xf32>) -> tensor<128x768xf32> {
    %cst_1 = arith.constant 0.000000e+00 : f32
    %cst_2 = arith.constant 8.000000e+00 : f32
    %2 = tensor.empty() : tensor<128x768xf32>
    %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x128x768xf32> into tensor<128x768xf32>
    %3 = linalg.fill ins(%cst_1 : f32) outs(%2 : tensor<128x768xf32>) -> tensor<128x768xf32>
    %16 = linalg.matmul ins(%collapsed, %arg1 : tensor<128x768xf32>, tensor<768x768xf32>) outs(%3 : tensor<128x768xf32>) -> tensor<128x768xf32>
    return %16 : tensor<128x768xf32>
  }
}

fused transpose+matmul:

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
module {
  func.func @forward(%arg0: tensor<1x128x768xf32>, %arg1 : tensor<768x768xf32>) -> tensor<128x768xf32> {
    %cst_1 = arith.constant 0.000000e+00 : f32
    %cst_2 = arith.constant 8.000000e+00 : f32
    %1 = tensor.empty() : tensor<768x768xf32>
    %2 = tensor.empty() : tensor<128x768xf32>
    %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x128x768xf32> into tensor<128x768xf32>
    %3 = linalg.fill ins(%cst_1 : f32) outs(%2 : tensor<128x768xf32>) -> tensor<128x768xf32>
    %4 = flow.dispatch.region -> (tensor<128x768xf32>) {
      %15 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<768x768xf32>) outs(%1 : tensor<768x768xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
      } -> tensor<768x768xf32>
      %16 = linalg.matmul ins(%collapsed, %15 : tensor<128x768xf32>, tensor<768x768xf32>) outs(%3 : tensor<128x768xf32>) -> tensor<128x768xf32>
      flow.return %16 : tensor<128x768xf32>
    }
    return %4 : tensor<128x768xf32>
  }
}

unfused transpose+matmul:

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
module {
  func.func @forward(%arg0: tensor<1x128x768xf32>, %arg1 : tensor<768x768xf32>) -> tensor<128x768xf32> {
    %cst_1 = arith.constant 0.000000e+00 : f32
    %cst_2 = arith.constant 8.000000e+00 : f32
    %1 = tensor.empty() : tensor<768x768xf32>
    %2 = tensor.empty() : tensor<128x768xf32>
    %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x128x768xf32> into tensor<128x768xf32>
    %3 = linalg.fill ins(%cst_1 : f32) outs(%2 : tensor<128x768xf32>) -> tensor<128x768xf32>
    %15 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<768x768xf32>) outs(%1 : tensor<768x768xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<768x768xf32>
    %16 = linalg.matmul ins(%collapsed, %15 : tensor<128x768xf32>, tensor<768x768xf32>) outs(%3 : tensor<128x768xf32>) -> tensor<128x768xf32>
    return %16 : tensor<128x768xf32>
  }
}

The performance numbers I got:

- matmul: 7855 items_per_second=12.5886k/s
- fused transpose+matmul: 7923 items_per_second=12.7482k/s
- unfused transpose+matmul: 1280 items_per_second=1.8357k/s

The fused version now performs as well as the plain matmul.
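For reference, the ratios implied by the items_per_second figures above:

```python
# items_per_second values reported above, in thousands of items per second.
matmul_ips = 12.5886
fused_ips = 12.7482
unfused_ips = 1.8357

fused_vs_unfused = fused_ips / unfused_ips
fused_vs_matmul = fused_ips / matmul_ips
print(f"fused vs unfused: {fused_vs_unfused:.2f}x")
print(f"fused vs plain matmul: {fused_vs_matmul:.3f}x")
```

That is, the fused dispatch has roughly 6.9x the throughput of the unfused pair here, and is within about 1% of the plain matmul.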

Please advise the next step. @MaheshRavishankar

MaheshRavishankar commented 1 year ago

Sorry for the delay @Young768. That's great! I think this means we should try to find ways to enable this fusion in IREE. This might be the first instance of a fusion that we enable only on CUDA. It should be usable on all backends, but the other backends might not handle it well today... So we can selectively enable it only for CUDA while the rest of the backends catch up. @benvanik recently added something for this here: https://github.com/openxla/iree/commit/c9419462fd4e6bd85098ed2afe6efa7104a20ea8. I have not looked at it in detail, but we can start using this mechanism to enable this fusion.
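A hypothetical sketch of the gating logic (not the mechanism from the linked commit, whose details aren't covered here): fuse a transpose/broadcast producer into a GEMM only when every requested target backend is on an allow-list. Requiring all targets to be on the list assumes that a single dispatch-region layout is shared by every target in one compilation; the allow-list and function names are illustrative.

```python
# Hypothetical allow-list: backends whose codegen is known to handle a
# transpose/broadcast fused into a GEMM dispatch (CUDA, per the numbers above).
FUSION_READY_BACKENDS = frozenset({"cuda"})
PRODUCER_KINDS = frozenset({"transpose", "broadcast"})
GEMM_KINDS = frozenset({"matmul", "batch_matmul"})


def should_fuse_into_gemm(producer_kind, consumer_kind, target_backends):
    """Return True if the producer should share a dispatch with the GEMM.

    Assumes one dispatch-region layout is shared by all targets in a
    compilation, so every requested backend must be on the allow-list.
    """
    if producer_kind not in PRODUCER_KINDS or consumer_kind not in GEMM_KINDS:
        return False
    backends = set(target_backends)
    return bool(backends) and backends <= FUSION_READY_BACKENDS
```

Once other backends catch up, enabling the fusion for them would only mean extending the allow-list.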