iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Crash when compiling mhlo.dynamic_iota. #7897

Closed gnecula closed 2 years ago

gnecula commented 2 years ago

This is from experiments with JAX+IREE and dynamic shapes. This bug is about mhlo.dynamic_iota.

This may be related to #7888 about mhlo.dynamic_reshape.

Repro (put the code below in iree_repro.py and run python iree_repro.py):

from iree.compiler import compile_str

CODE = """
module @jit_f.4  {
  func public @main(%arg0: tensor<?xf32> {mhlo.is_same_data_across_replicas} loc(unknown)) -> tensor<?xi32> {
    %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
    %3 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %4 = "mhlo.concatenate"(%3) {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32>
    %5 = "mhlo.dynamic_iota"(%4) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor<?xi32>
    return %5 : tensor<?xi32>
  }
}
"""

Error:

$ python tests/iree_repro.py
Traceback (most recent call last):
  File "/Users/necula/Source/jax/tests/iree_repro.py", line 15, in <module>
    compiled_flatbuffer = compile_str(CODE, target_backends=["dylib"], input_type="mhlo")
  File "/Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/core.py", line 262, in compile_str
    result = invoke_immediate(cl, immediate_input=input_bytes)
  File "/Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/binaries.py", line 201, in invoke_immediate
    raise CompilerToolError(process)
iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool ireec
Diagnostics:
LLVM ERROR: SmallVector unable to grow. Requested capacity (140704432398336) is larger than maximum value for size type (4294967295)
PLEASE submit a bug report to https://bugs.llvm.org/ and include the crash backtrace.
Stack dump:
0.  Program arguments: /Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/ireec - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=dylib --iree-mlir-to-vm-bytecode-module --iree-llvm-embedded-linker-path=/Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  libIREECompilerAggregateCAPI.dylib 0x00000001116a29d7 mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::VM::YieldOp>() + 4715143
1  libIREECompilerAggregateCAPI.dylib 0x00000001116a18f8 mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::VM::YieldOp>() + 4710824
2  libIREECompilerAggregateCAPI.dylib 0x00000001116a2ff0 mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::VM::YieldOp>() + 4716704
3  libsystem_platform.dylib           0x00007ff80c349e2d _sigtramp + 29
4  libIREECompilerAggregateCAPI.dylib 0x0000000114c6eea8 llvm::Module::getDarwinTargetVariantSDKVersion() const + 722760
5  libsystem_c.dylib                  0x00007ff80c280d10 abort + 123
6  libIREECompilerAggregateCAPI.dylib 0x000000011160a2fd mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::VM::YieldOp>() + 4090797
7  libIREECompilerAggregateCAPI.dylib 0x0000000111654ae2 mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::VM::YieldOp>() + 4395922
8  libIREECompilerAggregateCAPI.dylib 0x0000000111654911 mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::VM::YieldOp>() + 4395457
9  libIREECompilerAggregateCAPI.dylib 0x00000001125a1c61 mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::Check::ExpectTrueOp>() + 1613489
10 libIREECompilerAggregateCAPI.dylib 0x00000001125a158d mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::Check::ExpectTrueOp>() + 1611741
11 libIREECompilerAggregateCAPI.dylib 0x0000000114a2519f llvm::Loop::getLocRange() const + 3376191
12 libIREECompilerAggregateCAPI.dylib 0x0000000114c54f31 llvm::Module::getDarwinTargetVariantSDKVersion() const + 616401
13 libIREECompilerAggregateCAPI.dylib 0x0000000114a2fe01 llvm::Loop::getLocRange() const + 3420321
14 libIREECompilerAggregateCAPI.dylib 0x0000000114a28e4d llvm::Loop::getLocRange() const + 3391725
15 libIREECompilerAggregateCAPI.dylib 0x0000000114a2bfab llvm::Loop::getLocRange() const + 3404363
16 libIREECompilerAggregateCAPI.dylib 0x000000011253e4b2 mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::Check::ExpectTrueOp>() + 1206018
17 libIREECompilerAggregateCAPI.dylib 0x00000001116d2f1f mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::VM::YieldOp>() + 4913103
18 libIREECompilerAggregateCAPI.dylib 0x00000001116d3273 mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::VM::YieldOp>() + 4913955
19 libIREECompilerAggregateCAPI.dylib 0x00000001116d4814 mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::VM::YieldOp>() + 4919492
20 libIREECompilerAggregateCAPI.dylib 0x00000001112383f3 mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::VM::YieldOp>() + 85155
21 libIREECompilerAggregateCAPI.dylib 0x00000001112364b0 mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::VM::YieldOp>() + 77152
22 libIREECompilerAggregateCAPI.dylib 0x00000001145682ef llvm::MachineFunction::verify(llvm::Pass*, char const*, bool) const + 10589887
23 libIREECompilerAggregateCAPI.dylib 0x00000001112396ff mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::VM::YieldOp>() + 90031
24 libIREECompilerAggregateCAPI.dylib 0x00000001112390d3 mlir::TypeID mlir::detail::TypeIDExported::get<mlir::iree_compiler::IREE::VM::YieldOp>() + 88451
25 dyld                               0x000000011b4834fe start + 462

Invoked with:
 ireec /Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/ireec - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=dylib --iree-mlir-to-vm-bytecode-module --iree-llvm-embedded-linker-path=/Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false
benvanik commented 2 years ago

This ends up getting pretty far for me and dying on a TODO - it looks like some additional data-dependent shape handling is required. We really need to be detensoring these things - this is still computing the shape on device and reading it back. @ScottTodd did detensoring die?

That's a 😢 backtrace - no clue why it has some symbols but they are all wrong - if it were stripped I'd not expect to see any of those. @stellaraccident any ideas?

benvanik commented 2 years ago

Actually, there's a root bug here with mhlo.get_dimension_size:

// -----// IR Dump After ConvertShapeToStandard //----- //
module @jit_f.4  {
  func public @main(%arg0: tensor<?xf32> {mhlo.is_same_data_across_replicas}) -> tensor<?xi32> {
    %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
    %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %2 = "mhlo.dynamic_iota"(%1) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor<?xi32>
    return %2 : tensor<?xi32>
  }
}

// -----// IR Dump After Canonicalizer //----- //
func public @main(%arg0: tensor<?xf32> {mhlo.is_same_data_across_replicas}) -> tensor<?xi32> {
  %0 = mhlo.constant dense<-1> : tensor<1xi32>
  %1 = "mhlo.dynamic_iota"(%0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor<?xi32>
  return %1 : tensor<?xi32>
}

the canonicalizer there is folding the dimension query to -1, the sentinel for dynamic dimensions 🤦
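The bad fold is easy to model: MLIR's ShapedType uses -1 as the sentinel for dynamic dimensions, and the broken canonicalization hands that sentinel back as if it were a real size. A minimal sketch in plain Python (toy functions, not real MLIR APIs):

```python
DYNAMIC = -1  # MLIR ShapedType's sentinel for a dynamic dimension

def buggy_fold(shape, dim):
    # The broken canonicalization: folds the dimension query to the static
    # shape entry unconditionally, even when that entry is the sentinel.
    return shape[dim]

def correct_fold(shape, dim):
    # A valid fold may only fire for statically known dimensions; otherwise
    # the query must be left in place (None here means "don't fold").
    return shape[dim] if shape[dim] != DYNAMIC else None
```

With `shape = [DYNAMIC]` the buggy version produces the constant -1 seen in the canonicalized IR above, while the correct version declines to fold.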

There are still potential IREE issues after that, but we also hit some silly stuff with the index_cast:

    %cst = arith.constant dense<-1> : tensor<1xi32>
    %0 = arith.index_cast %cst : tensor<1xi32> to tensor<1xindex>
    %1 = tensor.extract %0[%c0] : tensor<1xindex>
->
  %cst = arith.constant dense<-1> : tensor<1xi32>
  %0 = linalg.init_tensor [1] : tensor<1xindex>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%cst : tensor<1xi32>) outs(%0 : tensor<1xindex>) {
  ^bb0(%arg1: i32, %arg2: index):  // no predecessors
    %7 = arith.index_cast %arg1 : i32 to index
    linalg.yield %7 : index
  } -> tensor<1xindex>
  %2 = tensor.extract %1[%c0] : tensor<1xindex>

This is because of the use of i32 instead of index, but also we really shouldn't be doing index casts as dispatches - detensoring would solve that (in this case, hopefully).
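To make "detensoring would solve that" concrete, here is a hedged plain-Python model (toy functions, not compiler code) of the two forms: the tensor-level cast that the pipeline above turns into a whole linalg.generic dispatch, versus the scalar host-side cast that detensoring would leave behind:

```python
def index_cast_as_dispatch(t):
    # Tensor form: an elementwise cast over a 1-element tensor, which
    # becomes linalg.init_tensor + linalg.generic (a dispatch).
    return [int(x) for x in t]

def index_cast_detensored(t):
    # Detensored form: extract the scalar once and cast it on the host;
    # no dispatch, and no tensor<1xindex> kept live.
    return int(t[0])
```

Both compute the same value; only the second avoids materializing a tensor and a dispatch for what is really one scalar cast.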

benvanik commented 2 years ago

(I'll fix the late-stage stream dialect bug around data-dependent shape query materialization, but the MHLO issue should be routed to https://github.com/tensorflow/mlir-hlo for visibility)

ScottTodd commented 2 years ago

> This ends up getting pretty far for me and dying on a TODO - it looks like some additional data-dependent shape handling is required. We really need to be detensoring these things - this is still computing the shape on device and reading it back. @ScottTodd did detensoring die?

It should still be integrated behind the -iree-flow-enable-linalg-detensorize flag, which defaults to off, but I/we need to spend a bit of time thinking through the edge cases and adding test coverage. It's been in a weird spot where it hasn't been necessary for correctness, and I'm not sure if we'd want it to be. In many cases we think it will be helpful for performance, but there are a few cases where that is less certain (this thread on Discord comes to mind: "today if we turned it on we'd create more of these [code patterns that are not well optimized] in models that didn't otherwise have them").

benvanik commented 2 years ago

If people start doing dynamic shape stuff we'll need it in some fashion - all these silly frontends put shapes in tensors. Being correct does not imply being usable and we won't be able to hide behind that next year :)

stellaraccident commented 2 years ago

(pytorch doesn't)

stellaraccident commented 2 years ago

> (I'll fix the late-stage stream dialect bug around data-dependent shape query materialization, but the MHLO issue should be routed to https://github.com/tensorflow/mlir-hlo for visibility)

Yeah, I've fixed a few of these shoddy canonicalizers. Will need an upstream patch.

stellaraccident commented 2 years ago

> This ends up getting pretty far for me and dying on a TODO - it looks like some additional data-dependent shape handling is required. We really need to be detensoring these things - this is still computing the shape on device and reading it back. @ScottTodd did detensoring die?
>
> That's a 😢 backtrace - no clue why it has some symbols but they are all wrong - if it were stripped I'd not expect to see any of those. @stellaraccident any ideas?

First time I'm seeing the symbol issue, but I note that this is on OSX without the symbolizer. It happens that those symbols are some of the only ones being emitted with default visibility, so the last-chance stack dumper is likely just choosing a close match, and those are the closest.

Should make a separate issue for generating release binaries with at least some minimal symbols. Probably worth the cost.

benvanik commented 2 years ago

The issue deeper down stems from these shape tensors getting pulled into dispatch regions weirdly:

func @main(%arg0: !hal.buffer_view {mhlo.is_same_data_across_replicas}) -> !hal.buffer_view attributes {iree.abi.stub} {
  %c-1 = arith.constant -1 : index
  %c0 = arith.constant 0 : index
  %0 = linalg.init_tensor [] : tensor<index>
  %1 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%0 : tensor<index>) {
  ^bb0(%arg1: index):  // no predecessors
    linalg.yield %c-1 : index
  } -> tensor<index>
  %2 = flow.tensor.reshape %1 : tensor<index> -> tensor<1xindex>
  %3 = tensor.extract %2[%c0] : tensor<1xindex>
  %4 = linalg.init_tensor [%3] : tensor<?xi32>
  %5 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%4 : tensor<?xi32>) {
  ^bb0(%arg1: i32):  // no predecessors
    %7 = linalg.index 0 : index
    %8 = arith.index_cast %7 : index to i32
    linalg.yield %8 : i32
  } -> tensor<?xi32>
  %6 = hal.tensor.export %5 : tensor<?xi32>{%3} -> !hal.buffer_view
  return %6 : !hal.buffer_view
}

I can surmise why DispatchLinalgOnTensors pulls in the shape tensor %2 instead of the extracted shape dimension %3:

...
  %1 = flow.tensor.reshape %0 : tensor<index> -> tensor<1xindex>
  %2 = tensor.extract %1[%c0] : tensor<1xindex>
  %3 = flow.dispatch.workgroups[%2, %c1, %c1](%1) : (tensor<1xindex>) -> tensor<?xi32>{%2} =
      (%arg1: !flow.dispatch.tensor<readonly:1xindex>, %arg2: !flow.dispatch.tensor<writeonly:?xi32>) {
    %c0_0 = arith.constant 0 : index
    %5 = flow.dispatch.tensor.load %arg1, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:1xindex> -> tensor<1xindex>
    %6 = tensor.extract %5[%c0_0] : tensor<1xindex>
    %7 = linalg.init_tensor [%6] : tensor<?xi32>
    %workgroup_size_0 = flow.dispatch.workgroup.size[0] : index
    %workgroup_id_0 = flow.dispatch.workgroup.id[0] : index
    %workgroup_count_0 = flow.dispatch.workgroup.count[0] : index
    %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_0, %workgroup_size_0]
    %9 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_0, %workgroup_size_0]
    scf.for %arg3 = %8 to %6 step %9 {
      %10 = affine.min affine_map<(d0, d1)[s0] -> (d0, -d1 + s0)>(%workgroup_size_0, %arg3)[%6]
      %11 = tensor.extract_slice %7[%arg3] [%10] [1] : tensor<?xi32> to tensor<?xi32>
      %12 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%11 : tensor<?xi32>) {
      ^bb0(%arg4: i32):  // no predecessors
        %13 = linalg.index 0 : index
        %14 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%13, %arg3)
        %15 = arith.index_cast %14 : index to i32
        linalg.yield %15 : i32
      } -> tensor<?xi32>
      flow.dispatch.tensor.store %12, %arg2, offsets = [%arg3], sizes = [%10], strides = [1] : tensor<?xi32> -> !flow.dispatch.tensor<writeonly:?xi32>{%6}
    }
    flow.return
  }
  %4 = hal.tensor.export %3 : tensor<?xi32>{%2} -> !hal.buffer_view
  return %4 : !hal.buffer_view
}

I think whether we want the load/extract moved inside of the region depends on the case - here it's bad because it keeps the shape tensor live, but in other cases it could remove host/device hazards. Detensoring would make this all moot.

The issue is that flow->stream conversion needs to get the shapes for all the bindings, and here, because of the internal shape data dependency within the region, my simple analysis can't find them. I'll see if I can't find a workaround to make this all work, but I'm not proud of what is happening here 😢

benvanik commented 2 years ago

☕ still kicking in, but I'm not sure I'll be able to make this work - we can't have that index tensor crossing the dispatch boundary - index bit width is undefined in the case of multiple devices. @MaheshRavishankar can we have the dispatch region formation logic not try to fold in things with tensor as a workaround? index arguments are fine (like the extracted shape dimension) but the tensors aren't.

benvanik commented 2 years ago

(it's fine to move index tensors inside, just so long as we never have I/O edges as index tensors - so if there was a tensor -> cast to tensor -> do stuff on index sequence we could pull all those in so long as the tensor is what crossed the boundary)
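As a hedged sketch of the boundary rule proposed here (toy types, not IREE's actual verifier), the check a flow.dispatch.workgroups verifier would enforce might look like:

```python
def ok_at_dispatch_boundary(elem_type: str, is_tensor: bool) -> bool:
    # Scalar index operands (e.g. an extracted shape dimension) are fine;
    # tensors of index must not cross the boundary, since index bit width
    # is undefined when multiple devices are involved.
    return not (is_tensor and elem_type == "index")
```

So a `tensor<1xindex>` operand like %2 above would be rejected, while the extracted `index` scalar %3 would pass.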

MaheshRavishankar commented 2 years ago

might be worth popping the stack up a bit. Having a linalg.generic operation operating on index types is strange. I am assuming all "index" operations of this sort should run on the host, but a linalg.generic operating on index would mean the index computation is done on the device. In this case

  %0 = linalg.init_tensor [] : tensor<index>
  %1 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%0 : tensor<index>) {
  ^bb0(%arg1: index):  // no predecessors
    linalg.yield %c-1 : index
  } -> tensor<index>
  %2 = flow.tensor.reshape %1 : tensor<index> -> tensor<1xindex>
  %3 = tensor.extract %2[%c0] : tensor<1xindex>

should just be

%c-1 = arith.constant -1 : index
benvanik commented 2 years ago

I think we may want them on device at some point depending on how the frontends/users use them - but am fine saying today they shouldn't be. Maybe that is a stepping stone for detensoring as well: detensor all the index tensors only. I'm adding a verifier to flow.dispatch.workgroups so we'll at least get good diagnostics on this there.

MaheshRavishankar commented 2 years ago

Looking further

%c-1 = arith.constant -1 : index
  %c0 = arith.constant 0 : index
  %0 = linalg.init_tensor [] : tensor<index>
  %1 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%0 : tensor<index>) {
  ^bb0(%arg1: index):  // no predecessors
    linalg.yield %c-1 : index
  } -> tensor<index>
  %2 = flow.tensor.reshape %1 : tensor<index> -> tensor<1xindex>
  %3 = tensor.extract %2[%c0] : tensor<1xindex>
  %4 = linalg.init_tensor [%3] : tensor<?xi32>

If I am reading this right, %3 has a value of -1 and then

%4 = linalg.init_tensor[%3] : tensor<?xi32>

is allocating a tensor of size -1. Something seems off here?

benvanik commented 2 years ago

Yeah, that's the busted shape dim query in mhlo returning the -1 sentinel for the dimension lol - lots of bugs in this bug :P

benvanik commented 2 years ago

Stella's magic const expr eval stuff will take care of that linalg generic returning -1, though I do think there'd be value in linalg being able to do simple ones like that (assuming tosa/mhlo are going to be coming in with things like this - I think we can wait and see a bit longer).

MaheshRavishankar commented 2 years ago

yeah totally. This is just detensoring though. So it already happens to some extent. I'd like to make sure this is valid IR before going down the path of "needs to be fixed in linalg".

burmako commented 2 years ago

After @rsuderman's work on mhlo.get_dimension_size support (fix for canonicalization and lowering to tensor.dim), the example above gets transformed by MHLO input conversion as follows:

$ iree-opt --iree-mhlo-input-transformation-pipeline dynamic_iota.mlir
#map = affine_map<(d0) -> (d0)>
module @jit_f.4 {
  func public @main(%arg0: tensor<?xf32> {mhlo.is_same_data_across_replicas}) -> tensor<?xi32> {
    %c0 = arith.constant 0 : index
    %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
    %1 = arith.index_cast %0 : index to i32
    %2 = tensor.from_elements %1 : tensor<i32>
    %3 = tensor.expand_shape %2 [] : tensor<i32> into tensor<1xi32>
    %4 = arith.index_cast %3 : tensor<1xi32> to tensor<1xindex>
    %5 = tensor.extract %4[%c0] : tensor<1xindex>
    %6 = linalg.init_tensor [%5] : tensor<?xi32>
    %7 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%6 : tensor<?xi32>) {
    ^bb0(%arg1: i32):  // no predecessors
      %8 = linalg.index 0 : index
      %9 = arith.index_cast %8 : index to i32
      linalg.yield %9 : i32
    } -> tensor<?xi32>
    return %7 : tensor<?xi32>
  }
}

Compiling that as follows:

from iree.compiler import compile_str

LINALG_CODE = """
#map = affine_map<(d0) -> (d0)>
module @jit_f.4 {
  func public @main(%arg0: tensor<?xf32> {mhlo.is_same_data_across_replicas}) -> tensor<?xi32> {
    %c0 = arith.constant 0 : index
    %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
    %1 = arith.index_cast %0 : index to i32
    %2 = tensor.from_elements %1 : tensor<i32>
    %3 = tensor.expand_shape %2 [] : tensor<i32> into tensor<1xi32>
    %4 = arith.index_cast %3 : tensor<1xi32> to tensor<1xindex>
    %5 = tensor.extract %4[%c0] : tensor<1xindex>
    %6 = linalg.init_tensor [%5] : tensor<?xi32>
    %7 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%6 : tensor<?xi32>) {
    ^bb0(%arg1: i32):  // no predecessors
      %8 = linalg.index 0 : index
      %9 = arith.index_cast %8 : index to i32
      linalg.yield %9 : i32
    } -> tensor<?xi32>
    return %7 : tensor<?xi32>
  }
}
"""

compiled_flatbuffer = compile_str(LINALG_CODE, target_backends=["dylib"], input_type="none")

Fails as follows:

iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool ireec
Diagnostics:
ireec: /home/burmako/iree-build/third_party/llvm-project/llvm/tools/mlir/include/mlir/IR/BuiltinTypeInterfaces.h.inc:264: int64_t mlir::detail::ShapedTypeTrait<mlir::TensorType>::getDimSize(unsigned int) const [ConcreteType = mlir::TensorType]: Assertion `idx < getRank() && "invalid index for shaped type"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.  Program arguments: /home/burmako/iree-build/compiler-api/python_package/iree/compiler/tools/../_mlir_libs/ireec - --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=dylib --iree-mlir-to-vm-bytecode-module --iree-llvm-embedded-linker-path=/home/burmako/iree-build/compiler-api/python_package/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false
 #0 0x00007fbd13e80d93 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /home/burmako/iree/third_party/llvm-project/llvm/lib/Support/Unix/Signals.inc:565:13
 #1 0x00007fbd13e7ef60 llvm::sys::RunSignalHandlers() /home/burmako/iree/third_party/llvm-project/llvm/lib/Support/Signals.cpp:97:18
 #2 0x00007fbd13e810fa SignalHandler(int) /home/burmako/iree/third_party/llvm-project/llvm/lib/Support/Unix/Signals.inc:407:1
 #3 0x00007fbd0f52c910 (/lib/x86_64-linux-gnu/libc.so.6+0x3c910)
 #4 0x00007fbd0f52c891 raise ./signal/../sysdeps/unix/sysv/linux/raise.c:50:1
 #5 0x00007fbd0f516536 abort ./stdlib/abort.c:81:7
 #6 0x00007fbd0f51641f get_sysdep_segment_value ./intl/loadmsgcat.c:509:8
 #7 0x00007fbd0f51641f _nl_load_domain ./intl/loadmsgcat.c:970:34
 #8 0x00007fbd0f525212 (/lib/x86_64-linux-gnu/libc.so.6+0x35212)
 #9 0x00007fbd14817cb6 llvm::ArrayRef<long>::operator[](unsigned long) const /home/burmako/iree/third_party/llvm-project/llvm/include/llvm/ADT/ArrayRef.h:255:7
#10 0x00007fbd14817cb6 mlir::detail::ShapedTypeTrait<mlir::TensorType>::getDimSize(unsigned int) const /home/burmako/iree-build/third_party/llvm-project/llvm/tools/mlir/include/mlir/IR/BuiltinTypeInterfaces.h.inc:265:14
#11 0x00007fbd150eede2 mlir::iree_compiler::IREE::Flow::(anonymous namespace)::ConvertTensorFromElementsPattern::matchAndRewrite(mlir::tensor::FromElementsOp, mlir::PatternRewriter&) const /home/burmako/iree/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.cpp:309:36
#12 0x00007fbd17d8ff82 mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) /home/burmako/iree/third_party/llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp:201:25
#13 0x00007fbd17b3d9a2 (anonymous namespace)::GreedyPatternRewriteDriver::simplify(llvm::MutableArrayRef<mlir::Region>) /home/burmako/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:250:19
#14 0x00007fbd17b3d9a2 mlir::applyPatternsAndFoldGreedily(llvm::MutableArrayRef<mlir::Region>, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig) /home/burmako/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:385:27
#15 0x00007fbd150e4a41 mlir::applyPatternsAndFoldGreedily(mlir::Operation*, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig) /home/burmako/iree/third_party/llvm-project/llvm/../mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h:71:10
#16 0x00007fbd150e4a41 mlir::iree_compiler::IREE::Flow::(anonymous namespace)::ConvertToFlowBeforeDispatchFormation::runOnOperation() /home/burmako/iree/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp:118:16
#17 0x00007fbd13eb4a9f mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) /home/burmako/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:0:11
#18 0x00007fbd13eb5154 mlir::LogicalResult::succeeded() const /home/burmako/iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:41:35
#19 0x00007fbd13eb5154 mlir::LogicalResult::failed() const /home/burmako/iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:44:33
#20 0x00007fbd13eb5154 mlir::failed(mlir::LogicalResult) /home/burmako/iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:72:58
#21 0x00007fbd13eb5154 mlir::detail::OpToOpPassAdaptor::runPipeline(llvm::iterator_range<llvm::pointee_iterator<std::unique_ptr<mlir::Pass, std::default_delete<mlir::Pass> >*, mlir::Pass> >, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) /home/burmako/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:452:9
#22 0x00007fbd13eb9df8 auto mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_8::operator()<std::pair<mlir::Operation*, mlir::AnalysisManager> >(std::pair<mlir::Operation*, mlir::AnalysisManager>&) const /home/burmako/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:614:9
#23 0x00007fbd13eb98cb mlir::LogicalResult::succeeded() const /home/burmako/iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:41:35
#24 0x00007fbd13eb98cb mlir::LogicalResult::failed() const /home/burmako/iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:44:33
#25 0x00007fbd13eb98cb mlir::failed(mlir::LogicalResult) /home/burmako/iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:72:58
#26 0x00007fbd13eb98cb mlir::LogicalResult mlir::failableParallelForEach<__gnu_cxx::__normal_iterator<std::pair<mlir::Operation*, mlir::AnalysisManager>*, std::vector<std::pair<mlir::Operation*, mlir::AnalysisManager>, std::allocator<std::pair<mlir::Operation*, mlir::AnalysisManager> > > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_8&>(mlir::MLIRContext*, __gnu_cxx::__normal_iterator<std::pair<mlir::Operation*, mlir::AnalysisManager>*, std::vector<std::pair<mlir::Operation*, mlir::AnalysisManager>, std::allocator<std::pair<mlir::Operation*, mlir::AnalysisManager> > > >, __gnu_cxx::__normal_iterator<std::pair<mlir::Operation*, mlir::AnalysisManager>*, std::vector<std::pair<mlir::Operation*, mlir::AnalysisManager>, std::allocator<std::pair<mlir::Operation*, mlir::AnalysisManager> > > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_8&) /home/burmako/iree/third_party/llvm-project/mlir/include/mlir/IR/Threading.h:49:11
#27 0x00007fbd13eb5f52 mlir::LogicalResult::succeeded() const /home/burmako/iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:41:35
#28 0x00007fbd13eb5f52 mlir::LogicalResult::failed() const /home/burmako/iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:44:33
#29 0x00007fbd13eb5f52 mlir::failed(mlir::LogicalResult) /home/burmako/iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:72:58
#30 0x00007fbd13eb5f52 mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool) /home/burmako/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:623:7
#31 0x00007fbd13eb4bd7 mlir::detail::OpToOpPassAdaptor::runOnOperation(bool) /home/burmako/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:0:5
#32 0x00007fbd13eb4bd7 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) /home/burmako/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:390:14
#33 0x00007fbd13eb6d06 mlir::LogicalResult::succeeded() const /home/burmako/iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:41:35
#34 0x00007fbd13eb6d06 mlir::LogicalResult::failed() const /home/burmako/iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:44:33
#35 0x00007fbd13eb6d06 mlir::failed(mlir::LogicalResult) /home/burmako/iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:72:58
#36 0x00007fbd13eb6d06 mlir::detail::OpToOpPassAdaptor::runPipeline(llvm::iterator_range<llvm::pointee_iterator<std::unique_ptr<mlir::Pass, std::default_delete<mlir::Pass> >*, mlir::Pass> >, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) /home/burmako/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:452:9
#37 0x00007fbd13eb6d06 mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) /home/burmako/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:696:10
#38 0x00007fbd13eb6b4c mlir::PassManager::run(mlir::Operation*) /home/burmako/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:0:0
#39 0x00007fbd13a5936f mlir::iree_compiler::translateFromMLIRToVM(mlir::ModuleOp, mlir::iree_compiler::BindingOptions, mlir::iree_compiler::InputDialectOptions, mlir::iree_compiler::HighLevelOptimizationOptions, mlir::iree_compiler::IREE::HAL::TargetOptions, mlir::iree_compiler::IREE::VM::TargetOptions) /home/burmako/iree/iree/compiler/Translation/IREEVM.cpp:0:26
#40 0x00007fbd13a58918 mlir::iree_compiler::translateFromMLIRToVMBytecodeModuleWithFlags(mlir::ModuleOp, llvm::raw_ostream&) /home/burmako/iree/iree/compiler/Translation/IREEVM.cpp:209:17
#41 0x00007fbd1750513c std::function<mlir::LogicalResult (mlir::ModuleOp, llvm::raw_ostream&)>::operator()(mlir::ModuleOp, llvm::raw_ostream&) const /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:568:9
#42 0x00007fbd1750513c mlir::TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(llvm::StringRef, std::function<mlir::LogicalResult (mlir::ModuleOp, llvm::raw_ostream&)> const&, std::function<void (mlir::DialectRegistry&)> const&)::$_1::operator()(llvm::SourceMgr&, llvm::raw_ostream&, mlir::MLIRContext*) const /home/burmako/iree/third_party/llvm-project/mlir/lib/Translation/Translation.cpp:107:12
#43 0x00007fbd1750513c mlir::LogicalResult std::__invoke_impl<mlir::LogicalResult, mlir::TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(llvm::StringRef, std::function<mlir::LogicalResult (mlir::ModuleOp, llvm::raw_ostream&)> const&, std::function<void (mlir::DialectRegistry&)> const&)::$_1&, llvm::SourceMgr&, llvm::raw_ostream&, mlir::MLIRContext*>(std::__invoke_other, mlir::TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(llvm::StringRef, std::function<mlir::LogicalResult (mlir::ModuleOp, llvm::raw_ostream&)> const&, std::function<void (mlir::DialectRegistry&)> const&)::$_1&, llvm::SourceMgr&, llvm::raw_ostream&, mlir::MLIRContext*&&) /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:61:14
#44 0x00007fbd1750513c std::enable_if<__and_<std::__not_<std::is_void<mlir::LogicalResult> >, std::is_convertible<std::__invoke_result<mlir::TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(llvm::StringRef, std::function<mlir::LogicalResult (mlir::ModuleOp, llvm::raw_ostream&)> const&, std::function<void (mlir::DialectRegistry&)> const&)::$_1&, llvm::SourceMgr&, llvm::raw_ostream&, mlir::MLIRContext*>::type, mlir::LogicalResult> >::value, mlir::LogicalResult>::type std::__invoke_r<mlir::LogicalResult, mlir::TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(llvm::StringRef, std::function<mlir::LogicalResult (mlir::ModuleOp, llvm::raw_ostream&)> const&, std::function<void (mlir::DialectRegistry&)> const&)::$_1&, llvm::SourceMgr&, llvm::raw_ostream&, mlir::MLIRContext*>(mlir::TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(llvm::StringRef, std::function<mlir::LogicalResult (mlir::ModuleOp, llvm::raw_ostream&)> const&, std::function<void (mlir::DialectRegistry&)> const&)::$_1&, llvm::SourceMgr&, llvm::raw_ostream&, mlir::MLIRContext*&&) /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:142:14
#45 0x00007fbd1750513c std::_Function_handler<mlir::LogicalResult (llvm::SourceMgr&, llvm::raw_ostream&, mlir::MLIRContext*), mlir::TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(llvm::StringRef, std::function<mlir::LogicalResult (mlir::ModuleOp, llvm::raw_ostream&)> const&, std::function<void (mlir::DialectRegistry&)> const&)::$_1>::_M_invoke(std::_Any_data const&, llvm::SourceMgr&, llvm::raw_ostream&, mlir::MLIRContext*&&) /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:291:9
#46 0x00007fbd13a5bc48 std::function<mlir::LogicalResult (llvm::SourceMgr&, llvm::raw_ostream&, mlir::MLIRContext*)>::operator()(llvm::SourceMgr&, llvm::raw_ostream&, mlir::MLIRContext*) const /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:568:9
#47 0x00007fbd13a5bc48 mlir::iree_compiler::runIreeTranslateMain(int, char**)::$_0::operator()(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&) const /home/burmako/iree/iree/tools/iree_translate_lib.cc:112:12
#48 0x00007fbd13a5b69c mlir::iree_compiler::runIreeTranslateMain(int, char**) /home/burmako/iree/iree/tools/iree_translate_lib.cc:120:16
#49 0x00007fbd0f5177ed __libc_start_main ./csu/../csu/libc-start.c:332:16
#50 0x00000000002016aa _start (/home/burmako/iree-build/compiler-api/python_package/iree/compiler/tools/../_mlir_libs/ireec+0x2016aa)

Invoked with:
 ireec /home/burmako/iree-build/compiler-api/python_package/iree/compiler/tools/../_mlir_libs/ireec - --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=dylib --iree-mlir-to-vm-bytecode-module --iree-llvm-embedded-linker-path=/home/burmako/iree-build/compiler-api/python_package/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false

Need more information? Set IREE_SAVE_TEMPS=/some/dir in your environment to save all artifacts and reproducers.

However, replacing `%6 = linalg.init_tensor [%5] : tensor<?xi32>` with `%6 = linalg.init_tensor [%0] : tensor<?xi32>` (basically short-circuiting packing iota's dim into a tensor and then unpacking it from a tensor) makes compilation succeed, so I'm wondering if we have a pass/passes upstream or in IREE that take care of that. Do you folks have any thoughts about this?
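To make the short-circuit concrete, here is a minimal sketch (plain Python, not a real MLIR API — op tuples and names are illustrative only) of the fold being asked about: extracting element 0 of a one-element tensor that was just built from a scalar should yield the scalar itself, removing the pack/unpack round trip.

```python
# Hypothetical op representation: ("extract", source_op, index), where
# source_op may be ("from_elements", scalar). This simulates folding
# extract(from_elements(x), 0) -> x, i.e. the short-circuit described above.

def fold_pack_unpack(op):
    kind, source, index = op
    if kind == "extract" and source[0] == "from_elements" and index == 0:
        return source[1]  # use the original scalar directly
    return op  # no fold applies; leave the op unchanged

packed = ("from_elements", "%0")   # pack %0 into a tensor<1xi32>-like value
op = ("extract", packed, 0)        # immediately unpack it again

assert fold_pack_unpack(op) == "%0"  # round trip folds away
```

If a pass performs this fold before bufferization, the dynamic dimension feeds `linalg.init_tensor` directly, which is the rewrite that made compilation succeed above.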

MaheshRavishankar commented 2 years ago

I think this should be part of the detensoring pass. @ScottTodd maybe we should just turn it on. In any case there shouldn't be a crash. This seems related to something in the shape stuff that was put back in temporarily.

rsuderman commented 2 years ago

Oops, let me tweak this pass around. I was hoping to set up the passes such that the packing / unpacking / casting could all be handled by canonicalizers, but that depends on each lowering being consistent in whether to cast / extract first.
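The ordering dependence can be sketched as follows (again plain Python with illustrative op tuples, not a real MLIR API): a canonicalizer that matches `extract(from_elements(x))` will not fire if one lowering inserts a cast between the two ops, so every lowering must agree on whether the cast or the extract comes first.

```python
# The same fold as before: extract(from_elements(x), 0) -> x.
def fold(op):
    kind, source, index = op
    if kind == "extract" and source[0] == "from_elements" and index == 0:
        return source[1]
    return op

direct = ("extract", ("from_elements", "%0"), 0)
with_cast = ("extract", ("cast", ("from_elements", "%0")), 0)

assert fold(direct) == "%0"          # pattern fires on the direct form
assert fold(with_cast) == with_cast  # an interposed cast blocks the fold
```

This is why the comment above asks for consistency across lowerings: a single canonicalization pattern only sees one local shape of the IR.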

rsuderman commented 2 years ago

Alright, some fixes should be upstreamed to llvm-project to migrate casts, reshapes, and extractions around. If the short-circuiting works, we should be good to go once the merge happens.
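A hedged sketch of what "migrating casts around extractions" could look like (illustrative op tuples only, not the actual upstream patterns): sinking a cast below an extract rewrites `extract(cast(t))` into `cast(extract(t))`, after which a pack/unpack fold can fire on the inner extract.

```python
# Rewrite extract(cast(t), i) -> cast(extract(t, i)), moving the cast
# out of the way of later extract(from_elements(x)) folding.
def sink_cast_below_extract(op):
    kind, source, index = op
    if kind == "extract" and source[0] == "cast":
        return ("cast", ("extract", source[1], index))
    return op

op = ("extract", ("cast", ("from_elements", "%0")), 0)
rewritten = sink_cast_below_extract(op)
assert rewritten == ("cast", ("extract", ("from_elements", "%0"), 0))
```

Composing this migration with the extract/from_elements fold is the kind of interaction the linked reviews (https://reviews.llvm.org/D118201, https://reviews.llvm.org/D118204) address upstream.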

https://reviews.llvm.org/D118201 https://reviews.llvm.org/D118204

jpienaar commented 2 years ago

@rsuderman could you verify that this has been resolved?

rsuderman commented 2 years ago

Added Natasha to verify functional completeness.