iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

error: pattern listener tracker fail; transform dialect interpreter failed #13419

Closed · silvasean closed this 1 year ago

silvasean commented 1 year ago

What happened?

The transform-dialect-based lowering seems to be hitting an issue with this test case.

Run with no_attention2.mlir from here

iree-compile --iree-hal-target-backends=cuda --iree-input-type=mhlo --iree-hal-cuda-llvm-target-arch=sm_80 no_attention2.mlir

Sorry for the large test case. The failing dispatch looks like this but for some reason isolating it into a small repro doesn't reproduce the failure for me locally.

    builtin.module {
      func.func @_main_dispatch_342_generic_2048x2048_f32(%arg0: !flow.dispatch.tensor<readonly:tensor<2048x2048xf32>>, %arg1: !flow.dispatch.tensor<readonly:tensor<2048xf32>>, %arg2: !flow.dispatch.tensor<writeonly:tensor<2048xf32>>) {
        %cst = arith.constant 0.000000e+00 : f32
        %cst_0 = arith.constant 2.048000e+03 : f32
        %0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2048x2048xf32>> -> tensor<2048x2048xf32>
        %1 = flow.dispatch.tensor.load %arg1, offsets = [0], sizes = [2048], strides = [1] : !flow.dispatch.tensor<readonly:tensor<2048xf32>> -> tensor<2048xf32>
        %2 = tensor.empty() : tensor<2048xf32>
        %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<2048xf32>) -> tensor<2048xf32>
        %4 = linalg.generic {indexing_maps = [#map4, #map8], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<2048x2048xf32>) outs(%3 : tensor<2048xf32>) {
        ^bb0(%in: f32, %out: f32):
          %6 = arith.negf %in : f32
          %7 = arith.addf %out, %6 : f32
          linalg.yield %7 : f32
        } -> tensor<2048xf32>
        %5 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%4, %1, %3 : tensor<2048xf32>, tensor<2048xf32>, tensor<2048xf32>) outs(%2 : tensor<2048xf32>) {
        ^bb0(%in: f32, %in_1: f32, %in_2: f32, %out: f32):
          %6 = arith.addf %in, %in_1 : f32
          %7 = arith.divf %6, %cst_0 : f32
          %8 = arith.addf %in_2, %7 : f32
          linalg.yield %8 : f32
        } -> tensor<2048xf32>
        flow.dispatch.tensor.store %5, %arg2, offsets = [0], sizes = [2048], strides = [1] : tensor<2048xf32> -> !flow.dispatch.tensor<writeonly:tensor<2048xf32>>
        return
      }
    }
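For intuition about what this dispatch computes, here is a plain-Python model (a sketch only: `#map4`/`#map8` are not defined in the snippet above, so a row-wise reduction is assumed, and the function name is illustrative):

```python
def dispatch_342(a, b, divisor=2048.0):
    """Pure-Python model of _main_dispatch_342_generic_2048x2048_f32.

    `a` is the 2048x2048 input, `b` the 2048-element input, and `divisor`
    the 2.048e3 constant %cst_0 from the IR. The indexing maps are not
    shown in the snippet, so a row-wise reduction is assumed.
    """
    out = []
    for row, bias in zip(a, b):
        # First linalg.generic: accumulate the negated elements
        # (arith.negf followed by arith.addf) into the zero fill.
        neg_sum = sum(-x for x in row)
        # Second linalg.generic: (neg_sum + bias) / 2048, added to the
        # zero-filled tensor %3.
        out.append((neg_sum + bias) / divisor + 0.0)
    return out
```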

Full error log link

Steps to reproduce your issue

See above

What component(s) does this issue relate to?

Compiler

Version information

iree.git @ ab37989652aed11f7f46498c09b9ac515c83eaa3

Additional context

No response

stellaraccident commented 1 year ago

@ftynse can you have a look?

ftynse commented 1 year ago

I don't seem to be able to reproduce this; I'm getting a different error about func.func @_main having mismatched argument types and signature.

<unknown>:0: error: 'func.func' op type of entry block argument #0('i32') must match the type of the corresponding argument in function signature('tensor<i32>')
ftynse commented 1 year ago

If I disable the verifier, I can get slightly further to error: failed to legalize unresolved materialization from '!stream.resource<*>', 'index' to 'i32' that remained live after conversion, but this comes from the Stream level. Transform dialect interpreter runs later than that and I cannot reach it.

silvasean commented 1 year ago

That seems like the error from https://github.com/llvm/llvm-project/issues/62249 (I have a workaround patched locally). (I haven't synced to head due to https://github.com/openxla/iree/issues/13189). To work around it you can disable the detensorize pass as described in step 1 here: https://github.com/openxla/iree/issues/13202#issue-1677226598

silvasean commented 1 year ago

FYI if it is useful, I worked around this particular issue in the program and hit another instance of this on another dispatch region. Seems to be another reduction.

  flow.executable private @_main_dispatch_949 {
    flow.executable.export public @_main_dispatch_949_generic_8192x2048_f32 workgroups(%arg0: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
      flow.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @_main_dispatch_949_generic_8192x2048_f32(%arg0: !flow.dispatch.tensor<readonly:tensor<2048x8192xf32>>, %arg1: !flow.dispatch.tensor<readonly:tensor<8192xf32>>, %arg2: !flow.dispatch.tensor<readonly:tensor<8192xi1>>, %arg3: !flow.dispatch.tensor<writeonly:tensor<8192xf32>>) {
        %cst = arith.constant 0.000000e+00 : f32
        %0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [2048, 8192], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2048x8192xf32>> -> tensor<2048x8192xf32>
        %1 = flow.dispatch.tensor.load %arg1, offsets = [0], sizes = [8192], strides = [1] : !flow.dispatch.tensor<readonly:tensor<8192xf32>> -> tensor<8192xf32>
        %2 = flow.dispatch.tensor.load %arg2, offsets = [0], sizes = [8192], strides = [1] : !flow.dispatch.tensor<readonly:tensor<8192xi1>> -> tensor<8192xi1>
        %3 = tensor.empty() : tensor<8192xf32>
        %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<8192xf32>) -> tensor<8192xf32>
        %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<2048x8192xf32>) outs(%4 : tensor<8192xf32>) {
        ^bb0(%in: f32, %out: f32):
          %7 = arith.addf %out, %in : f32
          linalg.yield %7 : f32
        } -> tensor<8192xf32>
        %6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2, %5, %4 : tensor<8192xf32>, tensor<8192xi1>, tensor<8192xf32>, tensor<8192xf32>) outs(%3 : tensor<8192xf32>) {
        ^bb0(%in: f32, %in_0: i1, %in_1: f32, %in_2: f32, %out: f32):
          %7 = arith.addf %in_2, %in_1 : f32
          %8 = arith.select %in_0, %7, %cst : f32
          %9 = arith.addf %in, %8 : f32
          linalg.yield %9 : f32
        } -> tensor<8192xf32>
        flow.dispatch.tensor.store %6, %arg3, offsets = [0], sizes = [8192], strides = [1] : tensor<8192xf32> -> !flow.dispatch.tensor<writeonly:tensor<8192xf32>>
        return
      }
    }
  }
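For reference, the computation in this dispatch can also be modeled in plain Python (a sketch; the affine maps are read off the IR above, the function name is illustrative):

```python
def dispatch_949(a, bias, mask):
    """Pure-Python model of _main_dispatch_949_generic_8192x2048_f32.

    Per affine_map<(d0, d1) -> (d1, d0)>, the first linalg.generic is a
    column-wise sum of `a`; the second adds `bias` to that sum where
    `mask` is set (arith.select against the zero fill).
    """
    cols = len(a[0])
    out = []
    for j in range(cols):
        col_sum = sum(row[j] for row in a)    # reduction generic
        picked = col_sum if mask[j] else 0.0  # arith.select against %cst
        out.append(bias[j] + picked)          # final arith.addf
    return out
```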
ftynse commented 1 year ago

The issue comes from upstream, which for some reason fails the tracking when an op is "replaced" with another pre-existing op. That kind of replacement is perfectly valid in, e.g., CSE. Shortcutting that with

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 6284b41302e2..d43234bb6e48 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -181,8 +181,12 @@ transform::TrackingListener::findReplacementOp(Operation *op,
     return nullptr;

   // If the replacement op is not a new op, drop the mapping.
-  if (!isNewOp(defOp))
+#if 0
+  if (!isNewOp(defOp)) {
+    llvm::errs() << "is not a new op " << *defOp << "\n";
     return nullptr;
+  }
+#endif

   return defOp;
 }

removes this particular issue. @matthias-springer any reason why the replacement must be a new op?
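For intuition, the check that the `#if 0` above disables can be modeled with a toy listener (illustrative Python, not the actual MLIR C++ API; ops are plain strings here):

```python
class ToyTrackingListener:
    """Toy model of transform::TrackingListener::findReplacementOp."""

    def __init__(self):
        self.new_ops = set()

    def notify_created(self, op):
        # Stands in for the listener observing newly created ops.
        self.new_ops.add(op)

    def find_replacement(self, op, replacement, require_new=True):
        # Upstream behavior: if the replacement is a pre-existing op
        # (as happens when CSE replaces `op` with an equivalent earlier
        # op), the mapping is dropped and tracking fails.
        if require_new and replacement not in self.new_ops:
            return None  # "is not a new op"
        return replacement
```

With `require_new=False` (the effect of the `#if 0` shortcut), a pre-existing op is accepted as the replacement instead of invalidating the handle.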

ftynse commented 1 year ago

Also with that shortcut, we start hitting another issue:

/home/silvasean/paxml/praxis/praxis/layers/normalizations.py:367:0: error: a handle passed as operand #0 and consumed by this operation points to a payload entity more than once
/home/silvasean/paxml/praxis/praxis/layers/normalizations.py:367:0: note: repeated target op

Where the problematic transform op is

%5 = transform.structured.fuse_into_containing_op %2 into %forall_op_3

and the payload op is

%10 = linalg.fill ins(%cst : f32) outs(%extracted_slice_2 : tensor<1xf32>) -> tensor<1xf32>

For some reason, %2 points to two operations. It is produced by the chain

%0:4 = transform.iree.match_callback failures(propagate) "reduction"(%arg0) : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
%2 = transform.structured.fuse_into_containing_op %0#1 into %forall_op

so it must be only linalg.fill. Further investigation shows that %2 is being assigned two payloads by fuse_into_containing_op. I suspect we may be hitting this TODO:

https://github.com/llvm/llvm-project/blob/3fb067f7ba8e7fee66b0740705f4bc767638ccc7/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp#L641-L645

but I ran out of time. Will pick up when I get meeting-free time.
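The "points to a payload entity more than once" diagnostic corresponds to a uniqueness check on consumed handles, roughly like this (a hypothetical helper for illustration, not the actual transform-dialect code):

```python
def check_consumed_handle(payload_ops):
    """Reject a consumed handle that maps to the same payload op twice,
    mirroring the diagnostic quoted above (illustrative only)."""
    seen = set()
    for op in payload_ops:
        if op in seen:
            raise ValueError(
                "a handle passed as operand #0 and consumed by this "
                "operation points to a payload entity more than once")
        seen.add(op)
    return payload_ops
```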

allieculp commented 1 year ago

@ftynse Setting this to P1, in progress. Let us know when you have an update!

allieculp commented 1 year ago

@nicolasvasilache for visibility

nicolasvasilache commented 1 year ago

Started looking at this at IREE ToT but I get:

LLVM ERROR: operation destroyed but still has uses

Please report issues to https://github.com/openxla/iree/issues and include the crash backtrace.
...

#15 0x00007fe6267b32ad mlir::iree_compiler::IREE::Stream::(anonymous namespace)::ScheduleExecutionPass::runOnOperation() /usr/local/google/home/ntv/github/iree/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp:304:25

@silvasean any clue how to get past this ?

nicolasvasilache commented 1 year ago

In the absence of a repro at ToT, I took @silvasean's latest example and turned it into:

func.func @_main_dispatch_949_generic_8192x2048_f32(%0: tensor<2048x8192xf32>, %1: tensor<8192xf32>, %2: tensor<8192xi1>) -> tensor<8192xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %3 = tensor.empty() : tensor<8192xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<8192xf32>) -> tensor<8192xf32>
  %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<2048x8192xf32>) outs(%4 : tensor<8192xf32>) {
  ^bb0(%in: f32, %out: f32):
    %7 = arith.addf %out, %in : f32
    linalg.yield %7 : f32
  } -> tensor<8192xf32>
  %6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2, %5, %4 : tensor<8192xf32>, tensor<8192xi1>, tensor<8192xf32>, tensor<8192xf32>) outs(%3 : tensor<8192xf32>) {
  ^bb0(%in: f32, %in_0: i1, %in_1: f32, %in_2: f32, %out: f32):
    %7 = arith.addf %in_2, %in_1 : f32
    %8 = arith.select %in_0, %7, %cst : f32
    %9 = arith.addf %in, %8 : f32
    linalg.yield %9 : f32
  } -> tensor<8192xf32>
  return %6: tensor<8192xf32>
}
iree-compile --iree-hal-target-backends=cuda --iree-input-type=mhlo --iree-hal-cuda-llvm-target-arch=sm_80

runs properly on that input (this does not help us much, as Sean reported that the original issue did not trigger on a single dispatch/function).

silvasean commented 1 year ago

@nicolasvasilache The LLVM ERROR: operation destroyed but still has uses looks similar to the error reported here: https://github.com/openxla/iree/issues/13370#issuecomment-1535308074 And tracked here: https://github.com/openxla/iree/issues/13459

silvasean commented 1 year ago

@ftynse -- just to confirm, were you able to reproduce the issue, and did it seem like a "real issue"? I.e. not some weirdness due to memory corruption or something coming from the underlying issue of LLVM ERROR: operation destroyed but still has uses. I find it curious that this doesn't reproduce with a single dispatch region.

silvasean commented 1 year ago

I found a workaround: --iree-codegen-llvmgpu-use-transform-dialect= --iree-codegen-llvmgpu-enable-transform-dialect-jit=false on the iree-compile command line.
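Combined with the repro command from the top of this issue, the full workaround invocation would look like this (the flags are verbatim from this thread; the input file name is from the original report):

```
iree-compile \
  --iree-hal-target-backends=cuda \
  --iree-input-type=mhlo \
  --iree-hal-cuda-llvm-target-arch=sm_80 \
  --iree-codegen-llvmgpu-use-transform-dialect= \
  --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
  no_attention2.mlir
```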

ftynse commented 1 year ago

@ftynse -- just to confirm, were you able to reproduce the issue, and did it seem like a "real issue"?

Yes, I was able to reproduce that. It's a real issue (see https://github.com/openxla/iree/issues/13419#issuecomment-1537142315) that needs an upstream fix and some design thinking. I suspect that the upstream behavior of ignoring an operation being replaced with itself is motivated by what happens here when it is disabled. I'm OOO right now and will get back to this when I'm back next week.

allieculp commented 1 year ago

@silvasean Sounds like @ftynse is back next week, is this okay timing-wise? Or else we can have @nicolasvasilache take a look.

silvasean commented 1 year ago

That is probably fine (~1 day left in Europe for the week anyway)

matthias-springer commented 1 year ago

@matthias-springer any reason why the replacement must be a new op?

This is overly strict. https://reviews.llvm.org/D150429 removes this restriction.

nicolasvasilache commented 1 year ago

thanks, approved!


ftynse commented 1 year ago

The first issue (pattern listener tracker fail) should be fixed by https://reviews.llvm.org/D150429.

The second issue that appears after that fix is due to the following situation:

This can be fixed upstream by relaxing fusion to accept repeated pointees in handles. After this fix, I'm running into what appears to be an instance of #13202 (log: https://gist.github.com/ftynse/46171ece65d4e9a139cabf5549d3b5cb), which suggests that the transform dialect-related issues are fixed.

ftynse commented 1 year ago

https://reviews.llvm.org/D150561 should fix the second issue. @silvasean could you please cherry-pick and check if it is indeed the case?

silvasean commented 1 year ago

Interesting. https://github.com/openxla/iree/issues/13202 should be fixed at head. (I believe I was seeing this error or that one depending on func.func visitation order in the pass manager, so they may mask each other)

I tried applying D150561 but the issue persists (error log - seems the same)

ftynse commented 1 year ago

Did you also apply https://reviews.llvm.org/D150429?

I have not tried that at a recent head; the head from ~a week ago was giving me a registration error.

silvasean commented 1 year ago

With D150429 additionally applied, the error appears to be gone! Do we know when all the necessary fixes will land in IREE head?

ftynse commented 1 year ago

Nice! I don't know how IREE handles upstream integrates and couldn't find any relevant documentation. https://github.com/openxla/iree/pull/13651 attempts to integrate past the relevant upstream changes, but is currently blocked by mlir-hlo cmake build being broken.

ScottTodd commented 1 year ago

I don't know how IREE handles upstream integrates and couldn't find any relevant documentation.

https://github.com/openxla/iree/tree/main/build_tools/scripts/integrate

ftynse commented 1 year ago

#13651 attempts to integrate past the relevant upstream changes, ~but is currently blocked by mlir-hlo cmake build being broken.~

Solved that, now blocked by end-to-end tests stored outside the repo (?) that fail because of TOSA upstream changes.

allieculp commented 1 year ago

@ftynse @silvasean Is this fixed?

ftynse commented 1 year ago

Still waiting for the upstream integrate (#13666), otherwise should be fine.


MaheshRavishankar commented 1 year ago

@silvasean the required commit has made it into IREE.

silvasean commented 1 year ago

Okay let's close this then (I verified it with the patch applied locally in https://github.com/openxla/iree/issues/13419#issuecomment-1550428248)