iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

Crash: T = mlir::Value]: Assertion `isa<T>(*this) && "Invalid accessor called"' failed. #18562

Closed pdhirajkumarprasad closed 1 month ago

pdhirajkumarprasad commented 1 month ago

What happened?

For the given IR:

module {
  func.func @"torch-jit-export"(%arg0: !torch.vtensor<[35,1],si64>, %arg1: !torch.vtensor<[2,1,200],f32>, %arg2: !torch.vtensor<[2,1,200],f32>, %arg3: !torch.vtensor<[33278,200],f32>) -> !torch.vtensor<[35,1,200],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.1"} {
    %11 = torch.operator "onnx.Gather"(%arg3, %arg0) : (!torch.vtensor<[33278,200],f32>, !torch.vtensor<[35,1],si64>) -> !torch.vtensor<[35,1,200],f32> 
    return %11 : !torch.vtensor<[35,1,200],f32>
  }
}

Seeing assertion:

iree-compile: /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/ADT/PointerUnion.h:156: T llvm::PointerUnion<mlir::Attribute, mlir::Value>::get() const [PT = <mlir::Attribute, mlir::Value>, T = mlir::Value]: Assertion `isa<T>(*this) && "Invalid accessor called"' failed.
Please report issues to https://github.com/iree-org/iree/issues and include the crash backtrace.
Stack dump:
0.  Program arguments: iree-compile --iree-hal-target-backends=llvm-cpu -o out.vmfb --iree-input-demote-i64-to-i32 model.torch_onnx.mlir
 #0 0x00007f25b504c9c7 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/llvm/lib/Support/Unix/Signals.inc:723:13
 #1 0x00007f25b504ac00 llvm::sys::RunSignalHandlers() /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/llvm/lib/Support/Signals.cpp:106:18
 #2 0x00007f25b504d08a SignalHandler(int) /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/llvm/lib/Support/Unix/Signals.inc:413:1
 #3 0x00007f25ae7eb520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x00007f25ae83f9fc __pthread_kill_implementation ./nptl/pthread_kill.c:44:76
 #5 0x00007f25ae83f9fc __pthread_kill_internal ./nptl/pthread_kill.c:78:10
 #6 0x00007f25ae83f9fc pthread_kill ./nptl/pthread_kill.c:89:10
 #7 0x00007f25ae7eb476 gsignal ./signal/../sysdeps/posix/raise.c:27:6
 #8 0x00007f25ae7d17f3 abort ./stdlib/abort.c:81:7
 #9 0x00007f25ae7d171b _nl_load_domain ./intl/loadmsgcat.c:1177:9
#10 0x00007f25ae7e2e96 (/lib/x86_64-linux-gnu/libc.so.6+0x39e96)
#11 0x00007f25b974d9e7 mlir::Operation::getOpResultImpl(unsigned int) /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/mlir/include/mlir/IR/Operation.h:1006:5
#12 0x00007f25b974d9e7 mlir::Operation::getResult(unsigned int) /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/mlir/include/mlir/IR/Operation.h:402:54
#13 0x00007f25b974d9e7 mlir::OpTrait::OneTypedResult<mlir::Type>::Impl<mlir::arith::ConstantOp>::getResult() /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/mlir/include/mlir/IR/OpDefinition.h:699:33
#14 0x00007f25b974d9e7 mlir::OpTrait::OneTypedResult<mlir::Type>::Impl<mlir::arith::ConstantOp>::operator mlir::Value() /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/mlir/include/mlir/IR/OpDefinition.h:704:54
#15 0x00007f25b974d9e7 mlir::inferExpandShapeOutputShape(mlir::OpBuilder&, mlir::Location, mlir::ShapedType, llvm::ArrayRef<llvm::SmallVector<long, 2u>>, llvm::ArrayRef<mlir::OpFoldResult>) /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/mlir/lib/Dialect/Arith/Utils/Utils.cpp:71:9
#16 0x00007f25b957e71e std::_Optional_base_impl<llvm::SmallVector<mlir::OpFoldResult, 6u>, std::_Optional_base<llvm::SmallVector<mlir::OpFoldResult, 6u>, false, false>>::_M_is_engaged() const /bin/../lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/optional:471:58
#17 0x00007f25b957e71e std::optional<llvm::SmallVector<mlir::OpFoldResult, 6u>>::operator bool() const /bin/../lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/optional:985:22
#18 0x00007f25b957e71e mlir::tensor::ExpandShapeOp::inferOutputShape(mlir::OpBuilder&, mlir::Location, mlir::RankedTensorType, llvm::ArrayRef<llvm::SmallVector<long, 2u>>, llvm::ArrayRef<mlir::OpFoldResult>) /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp:1675:8
#19 0x00007f25b957eb93 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::getFirstEl() const /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/ADT/SmallVector.h:130:46
#20 0x00007f25b957eb93 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::SmallVectorTemplateCommon(unsigned long) /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/ADT/SmallVector.h:135:49
#21 0x00007f25b957eb93 llvm::SmallVectorTemplateBase<mlir::OpFoldResult, true>::SmallVectorTemplateBase(unsigned long) /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/ADT/SmallVector.h:487:42
#22 0x00007f25b957eb93 llvm::SmallVectorImpl<mlir::OpFoldResult>::SmallVectorImpl(unsigned int) /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/ADT/SmallVector.h:588:9
#23 0x00007f25b957eb93 llvm::SmallVector<mlir::OpFoldResult, 6u>::SmallVector() /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/ADT/SmallVector.h:1198:19
#24 0x00007f25b957eb93 mlir::tensor::ExpandShapeOp::build(mlir::OpBuilder&, mlir::OperationState&, mlir::Type, mlir::Value, llvm::ArrayRef<llvm::SmallVector<long, 2u>>) /proj/rdi/staff/dhirajp/localBuild/iree/third_party/llvm-project/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp:1699:29

Steps to reproduce your issue

command:

iree-compile --iree-hal-target-backends=llvm-cpu -o out.vmfb --iree-input-demote-i64-to-i32 model.torch_onnx.mlir

What component(s) does this issue relate to?

Compiler

Version information

No response

Additional context

No response

hanhanW commented 1 month ago

I usually add -mlir-print-ir-before-all -mlir-print-ir-after-all -mlir-disable-threading 2> ~/log to find the failing pass; here it shows that the failure happens in the ConvertTorchToLinalg pass. Here is the IR before the pass:

func.func @"torch-jit-export"(%arg0: !torch.vtensor<[35,1],si64>, %arg1: !torch.vtensor<[2,1,200],f32>, %arg2: !torch.vtensor<[2,1,200],f32>, %arg3: !torch.vtensor<[33278,200],f32>) -> !torch.vtensor<[35,1,200],f32> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.1"} {
  %int-1 = torch.constant.int -1
  %int33278 = torch.constant.int 33278
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %0 = torch.aten.lt.Scalar %arg0, %int0 : !torch.vtensor<[35,1],si64>, !torch.int -> !torch.vtensor<[35,1],i1>
  %1 = torch.aten.add.Scalar %arg0, %int33278, %int1 : !torch.vtensor<[35,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[35,1],si64>
  %2 = torch.aten.where.self %0, %1, %arg0 : !torch.vtensor<[35,1],i1>, !torch.vtensor<[35,1],si64>, !torch.vtensor<[35,1],si64> -> !torch.vtensor<[35,1],si64>
  %3 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
  %4 = torch.aten.view %2, %3 : !torch.vtensor<[35,1],si64>, !torch.list<int> -> !torch.vtensor<[35],si64>
  %5 = torch.aten.index_select %arg3, %int0, %4 : !torch.vtensor<[33278,200],f32>, !torch.int, !torch.vtensor<[35],si64> -> !torch.vtensor<[35,200],f32>
  %6 = torch.aten.unsqueeze %5, %int1 : !torch.vtensor<[35,200],f32>, !torch.int -> !torch.vtensor<[?,?,?],f32>
  %7 = torch.tensor_static_info_cast %6 : !torch.vtensor<[?,?,?],f32> to !torch.vtensor<[35,1,200],f32>
  return %7 : !torch.vtensor<[35,1,200],f32>
}
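
For reference, a minimal sketch of the full invocation with those debug flags, assuming the reproducer command from the report:

iree-compile --iree-hal-target-backends=llvm-cpu --iree-input-demote-i64-to-i32 -mlir-print-ir-before-all -mlir-print-ir-after-all -mlir-disable-threading model.torch_onnx.mlir -o out.vmfb 2> ~/log

The IR printed before the ConvertTorchToLinalg pass then lands in ~/log, which is where the snippet above comes from.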
vinayakdsci commented 1 month ago

This failure comes from the canonicalization of AtenUnflattenIntOp, which was introduced some time back and rewrites the op into a combination of Unsqueeze and StaticInfoCast. The Unsqueeze emits dynamic dims in its output type, which MLIR does not expect during the shape inference of ExpandShape that is invoked when lowering Unsqueeze to linalg; hence the assertion failure.

Here is the relevant part of the code: https://github.com/vinayakdsci/torch-mlir/blob/99848265c388099f500de9eac235bf0e2c9ccc0d/lib/Dialect/Torch/IR/TorchOps.cpp#L2173. Fixing this would require correcting the result type of the Unsqueeze op so that the canonicalization produces the correct IR.
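
To illustrate with the IR dump above, the canonicalization currently produces an unsqueeze with a fully dynamic result type followed by a static-info cast:

  %6 = torch.aten.unsqueeze %5, %int1 : !torch.vtensor<[35,200],f32>, !torch.int -> !torch.vtensor<[?,?,?],f32>
  %7 = torch.tensor_static_info_cast %6 : !torch.vtensor<[?,?,?],f32> to !torch.vtensor<[35,1,200],f32>

A minimal sketch of what a fixed canonicalization would presumably emit instead, with the static shape derived from the operand type and the unsqueeze dim (hypothetical output, not the current behavior):

  %6 = torch.aten.unsqueeze %5, %int1 : !torch.vtensor<[35,200],f32>, !torch.int -> !torch.vtensor<[35,1,200],f32>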