iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[CUDA] Optimize einsum-related meso-benchmark to parity #13271

Open silvasean opened 1 year ago

silvasean commented 1 year ago

What happened?

The IR below probably has a couple of things going wrong in it that we will need to peel apart piece by piece. This benchmark takes 30ms on XLA:GPU, but on IREE the einsum dispatch alone takes over 6 seconds, so the overall computation is >200x slower.

This is extracted from Python code which might be easier to read for some folks (link). In particular there is an einsum which seems to be the source of some problems (this config uses the "outgoing" einsum equation).

The IR snippet is here and in particular the einsum is here. If you want to run this snippet you can use a command line like:

iree-compile --iree-hal-target-backends=cuda --iree-hal-cuda-llvm-target-arch=sm_80 ir34.linalg.mlir -o ir34.vmfb
iree-benchmark-module --device=cuda --function=main --module=ir34.vmfb --flagfile=ir34.flagfile

With this flagfile and ir34.linalg.mlir taken from the link above.


The first thing I notice when I look at the slow dispatch that is taking 6 seconds is that it consists of this linalg.generic:

        %9 = tensor.empty() : tensor<128x702x702xf32>
        %10 = linalg.fill ins(%cst_0 : f32) outs(%9 : tensor<128x702x702xf32>) -> tensor<128x702x702xf32>
        %11 = linalg.generic {indexing_maps = [#map3, #map4, #map5, #map4, #map5, #map6, #map7, #map5, #map7, #map5, #map8], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%0, %1, %2, %3, %4, %0, %5, %6, %7, %8 : tensor<702x702xf32>, tensor<702x702x128xf32>, tensor<128xf32>, tensor<702x702x128xf32>, tensor<128xf32>, tensor<702x702xf32>, tensor<702x702x128xf32>, tensor<128xf32>, tensor<702x702x128xf32>, tensor<128xf32>) outs(%10 : tensor<128x702x702xf32>) {
        ^bb0(%in: f32, %in_1: f32, %in_2: f32, %in_3: f32, %in_4: f32, %in_5: f32, %in_6: f32, %in_7: f32, %in_8: f32, %in_9: f32, %out: f32):
          %12 = arith.addf %in_6, %in_7 : f32
          %13 = arith.addf %in_1, %in_2 : f32
          %14 = arith.addf %in_8, %in_9 : f32
          %15 = arith.negf %14 : f32
          %16 = math.exp %15 : f32
          %17 = arith.addf %16, %cst : f32
          %18 = arith.addf %in_3, %in_4 : f32
          %19 = arith.negf %18 : f32
          %20 = math.exp %19 : f32
          %21 = arith.divf %cst, %17 : f32
          %22 = arith.mulf %in_5, %12 : f32
          %23 = arith.addf %20, %cst : f32
          %24 = arith.divf %cst, %23 : f32
          %25 = arith.mulf %in, %13 : f32
          %26 = arith.mulf %22, %21 : f32
          %27 = arith.mulf %25, %24 : f32
          %28 = arith.mulf %27, %26 : f32
          %29 = arith.addf %out, %28 : f32
          linalg.yield %29 : f32
        } -> tensor<128x702x702xf32>

This looks like a really poor fusion decision. Looking at the dispatch graph dumped by --iree-flow-dump-dispatch-graph (dot), it looks like we have basically fused various things into the input of the einsum: biasadds from previous linear layers, sigmoids (the math.exp's), and some elementwise multiplications. This results in significant recomputation, since these elementwise ops are now re-evaluated at every point of the reduction's iteration space.
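To put a rough number on that recomputation, here is a back-of-the-envelope sketch in Python. The shapes are taken from the linalg.generic above; the assumption that the single reduction dimension has extent 702 is mine (the indexing maps are not shown here), as are all the variable names.

```python
# Shapes taken from the linalg.generic above.
C = 128  # extent of the 128-sized dimension
N = 702  # extent of each 702-sized dimension

# Assumption: the one "reduction" iterator also has extent N (702),
# so the fused op iterates over a C x N x N x N space.
iteration_points = C * N * N * N

# Fused: the body contains two math.exp ops, evaluated at every iteration point.
exp_fused = 2 * iteration_points

# Unfused: each sigmoid is computed once per element of its 702x702x128 operand.
exp_unfused = 2 * (N * N * C)

recompute_factor = exp_fused // exp_unfused

print(f"math.exp evaluations when fused:   {exp_fused:,}")
print(f"math.exp evaluations when unfused: {exp_unfused:,}")
print(f"recomputation factor:              {recompute_factor}x")
```

Under that assumption, every elementwise producer fused into the reduction is evaluated 702 times more often than it would be as a standalone elementwise dispatch.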

That's probably enough to get started on working on this. I will post updates as I dig into other aspects here.

Steps to reproduce your issue

See above

What component(s) does this issue relate to?

Compiler

Version information

iree.git @ ab37989652aed11f7f46498c09b9ac515c83eaa3

Additional context

No response

MaheshRavishankar commented 1 year ago

Makes sense. We shouldn't be fusing that. Instead of https://github.com/openxla/iree/blob/1fd449b7b55f87d335ec67666499bb09aedf10f9/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp#L86, we should be fusing in the consumer only if the indexing map in the use is a permutation (apart from the broadcast case).

Do you have the input IR? It should be a simple fix.
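The permutation condition above can be illustrated with a small Python sketch. This is not IREE's implementation: `is_permutation`, the tuple encoding of an indexing map, and the specific loop order and operand access are all hypothetical, and the broadcast exception mentioned above is not modeled.

```python
def is_permutation(map_results, num_loops):
    """Return True if the map uses every loop dimension exactly once.

    map_results: tuple of loop-dimension indices, one per operand dimension
                 (a hypothetical encoding of a linalg indexing map).
    num_loops:   number of loops in the consumer's iteration space.
    """
    return sorted(map_results) == list(range(num_loops))


# Einsum-like consumer: assumed loop order (c, i, j, k), with k the reduction.
EINSUM_LOOPS = 4
# Hypothetical access (i, k, c) for a 702x702x128 operand. It skips loop j,
# so the producer would be recomputed for every value of j if fused.
operand_map = (1, 3, 0)

# Plain elementwise consumer over a 702x702x128 tensor: identity access.
ELEMENTWISE_LOOPS = 3
identity_map = (0, 1, 2)

fuse_into_einsum = is_permutation(operand_map, EINSUM_LOOPS)
fuse_into_elementwise = is_permutation(identity_map, ELEMENTWISE_LOOPS)

print("fuse producer into einsum-like consumer:", fuse_into_einsum)
print("fuse producer into elementwise consumer:", fuse_into_elementwise)
```

The idea is that a permutation map touches each producer element exactly once per consumer iteration point, so fusing does not multiply the producer's work; a map that drops a loop dimension does.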

silvasean commented 1 year ago

Yep, the input IR is here: https://gist.github.com/silvasean/e18c111db26699a6f18acc6037d5a00a

MaheshRavishankar commented 1 year ago

#13308 fixes the fusion issue. We need to evaluate the impact of that change on existing models.

MaheshRavishankar commented 1 year ago

This is fixed by #13308, but landing that is blocked by #13189.

silvasean commented 1 year ago

Thanks @MaheshRavishankar. That looks like it avoids some of the "undesirable" fusions and improves the runtime to ~600ms from the ~6 seconds it was before. It looks like there is still 20x to go to reach XLA:GPU's 30ms baseline.

MaheshRavishankar commented 1 year ago

I'll be happy to land this... but it is blocked on downstream issues. I have tagged all of them here....

allieculp commented 1 year ago

Quick update: still blocked by #13189

MaheshRavishankar commented 1 year ago

The fusion issue is fixed by #13308. The performance issue should be fixed at ToT (by https://github.com/openxla/iree/pull/13468). Please verify and close.

silvasean commented 1 year ago

With those fixes, this meso-benchmark now takes 65ms on IREE, which is still over 2x off from XLA:GPU's 30ms. I think it makes sense to keep this issue open, as there is more work to be done to reach parity.
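For reference, a quick Python tally of the gap at each stage reported in this thread. The timings are the approximate figures quoted in the comments (the first one is for the einsum dispatch alone); the stage labels and variable names are just for illustration.

```python
XLA_GPU_MS = 30.0  # XLA:GPU baseline quoted in this thread

# Approximate IREE timings quoted in the comments above, in milliseconds.
iree_ms = {
    "original (einsum dispatch alone)": 6000.0,
    "with the fusion fix": 600.0,
    "with the fusion and performance fixes": 65.0,
}

# Slowdown relative to the XLA:GPU baseline at each stage.
slowdowns = {stage: ms / XLA_GPU_MS for stage, ms in iree_ms.items()}

for stage, factor in slowdowns.items():
    print(f"{stage}: {factor:.1f}x slower than XLA:GPU")
```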

MaheshRavishankar commented 1 year ago

Could we re-title it? I am also dropping myself from the assignee list and moving it to Sean.

silvasean commented 1 year ago

Changed title. I want to emphasize that the work to be done here is to optimize it to parity, rather than any particular fix. Next step is probably for me to dive in and do a first-order performance analysis to identify remaining gaps.

silvasean commented 1 year ago

TODO: Need to re-evaluate how large a fraction of the e2e workload this is now that these fixes have landed.