nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0
95 stars 48 forks source link

T llvm::PointerUnion<mlir::Attribute, mlir::Value>::get() const [PT = <mlir::Attribute, mlir::Value>, T = mlir::Value]: Assertion `isa<T>(*this) && "Invalid accessor called"' failed. #884

Open pdhirajkumarprasad opened 1 week ago

pdhirajkumarprasad commented 1 week ago

For the following IR:

module {
  func.func @torch_jit(%arg0: !torch.vtensor<[1,3,224,224],f32>, %arg1: !torch.vtensor<[1,196,128],f32>, %arg2: !torch.vtensor<[128,256],f32>) -> !torch.vtensor<[?,4,196,196],f32>    attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.12.1"} {
    %17 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<4x196xf32>} : () -> !torch.vtensor<[4,196],f32> 
    %22 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<128xf32>} : () -> !torch.vtensor<[128],f32> 
    %23 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<128xf32>} : () -> !torch.vtensor<[128],f32> 
    %24 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<128xf32>} : () -> !torch.vtensor<[128],f32> 
    %25 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<128xf32>} : () -> !torch.vtensor<[128],f32> 
    %26 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<256xf32>} : () -> !torch.vtensor<[256],f32> 
    %27 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<256xf32>} : () -> !torch.vtensor<[256],f32> 
    %28 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<256xf32>} : () -> !torch.vtensor<[256],f32> 
    %29 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<256xf32>} : () -> !torch.vtensor<[256],f32> 
    %213 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1.0> : tensor<128x256xf32>} : () -> !torch.vtensor<[128,256],f32> 
    %214 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Reshape_1235> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64> 
    %215 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Reshape_1239> : tensor<3xsi64>} : () -> !torch.vtensor<[3],si64> 
    %220 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1.0> : tensor<128x128xf32>} : () -> !torch.vtensor<[128,128],f32> 
    %221 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1.0> : tensor<128x256xf32>} : () -> !torch.vtensor<[128,256],f32> 
    %308 = torch.operator "onnx.MatMul"(%arg1, %arg2) : (!torch.vtensor<[1,196,128],f32>, !torch.vtensor<[128,256],f32>) -> !torch.vtensor<[1,196,256],f32> 
    %355 = torch.operator "onnx.Reshape"(%308, %214) : (!torch.vtensor<[1,196,256],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> 
    %356 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__7> : tensor<3xsi64>} : () -> !torch.vtensor<[3],si64> 
    %357:3 = torch.operator "onnx.Split"(%355, %356) {torch.onnx.axis = 3 : si64} : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[3],si64>) -> (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>) 
    %358 = torch.operator "onnx.Transpose"(%357#0) {torch.onnx.perm = [0 : si64, 2 : si64, 1 : si64, 3 : si64]} : (!torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    %359 = torch.operator "onnx.Transpose"(%357#1) {torch.onnx.perm = [0 : si64, 2 : si64, 3 : si64, 1 : si64]} : (!torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    %360 = torch.operator "onnx.Transpose"(%357#2) {torch.onnx.perm = [0 : si64, 2 : si64, 1 : si64, 3 : si64]} : (!torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    %361 = torch.operator "onnx.MatMul"(%358, %359) : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    %362 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__8> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %363 = torch.operator "onnx.Mul"(%361, %362) : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    %364 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<196x196xsi64>} : () -> !torch.vtensor<[196,196],si64> 
    %365 = torch.operator "onnx.Gather"(%17, %364) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[4,196],f32>, !torch.vtensor<[196,196],si64>) -> !torch.vtensor<[4,196,196],f32> 
    %366 = torch.operator "onnx.Add"(%363, %365) : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[4,196,196],f32>) -> !torch.vtensor<[?,4,196,196],f32> 
    %367 = torch.operator "onnx.Softmax"(%366) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[?,4,196,196],f32>) -> !torch.vtensor<[?,4,196,196],f32> 
    %368 = torch.operator "onnx.MatMul"(%367, %360) : (!torch.vtensor<[?,4,196,196],f32>, !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,4,196,?],f32> 
    %369 = torch.operator "onnx.Transpose"(%368) {torch.onnx.perm = [0 : si64, 2 : si64, 1 : si64, 3 : si64]} : (!torch.vtensor<[?,4,196,?],f32>) -> !torch.vtensor<[?,196,4,?],f32> 
    %370 = torch.operator "onnx.Reshape"(%369, %215) : (!torch.vtensor<[?,196,4,?],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> 
    %371 = torch.operator "onnx.HardSigmoid"(%370) {torch.onnx.alpha = 0.166666672 : f32} : (!torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> 
    %372 = torch.operator "onnx.Mul"(%370, %371) : (!torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> 
    %373 = torch.operator "onnx.MatMul"(%372, %220) : (!torch.vtensor<[?,?,?],f32>, !torch.vtensor<[128,128],f32>) -> !torch.vtensor<[?,?,128],f32> 
    %374 = torch.operator "onnx.Flatten"(%373) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[?,?,128],f32>) -> !torch.vtensor<[?,128],f32> 
    %375 = torch.operator "onnx.BatchNormalization"(%374, %22, %23, %24, %25) {torch.onnx.epsilon = 9.99999974E-6 : f32, torch.onnx.momentum = 0.899999976 : f32} : (!torch.vtensor<[?,128],f32>, !torch.vtensor<[128],f32>, !torch.vtensor<[128],f32>, !torch.vtensor<[128],f32>, !torch.vtensor<[128],f32>) -> !torch.vtensor<[?,128],f32> 
    %376 = torch.operator "onnx.Shape"(%373) : (!torch.vtensor<[?,?,128],f32>) -> !torch.vtensor<[3],si64> 
    %377 = torch.operator "onnx.Reshape"(%375, %376) : (!torch.vtensor<[?,128],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,128],f32> 
    %378 = torch.operator "onnx.Add"(%arg1, %377) : (!torch.vtensor<[1,196,128],f32>, !torch.vtensor<[?,?,128],f32>) -> !torch.vtensor<[?,196,128],f32> 
    %379 = torch.operator "onnx.MatMul"(%378, %221) : (!torch.vtensor<[?,196,128],f32>, !torch.vtensor<[128,256],f32>) -> !torch.vtensor<[?,196,256],f32> 
    %380 = torch.operator "onnx.Flatten"(%379) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[?,196,256],f32>) -> !torch.vtensor<[?,256],f32> 
    %381 = torch.operator "onnx.BatchNormalization"(%380, %26, %27, %28, %29) {torch.onnx.epsilon = 9.99999974E-6 : f32, torch.onnx.momentum = 0.899999976 : f32} : (!torch.vtensor<[?,256],f32>, !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32>) -> !torch.vtensor<[?,256],f32> 
    %382 = torch.operator "onnx.Shape"(%379) : (!torch.vtensor<[?,196,256],f32>) -> !torch.vtensor<[3],si64> 
    %383 = torch.operator "onnx.Reshape"(%380, %382) : (!torch.vtensor<[?,256],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,196,256],f32> 
    return %366 : !torch.vtensor<[?,4,196,196],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _onnx__Reshape_1235: "0x080000000100000000000000C4000000000000000400000000000000FFFFFFFFFFFFFFFF",
      _onnx__Reshape_1239: "0x080000000100000000000000C4000000000000008000000000000000",
      __7: "0x08000000100000000000000010000000000000002000000000000000",
      __8: "0x080000000000803E"
    }
  }
#-}

Compiling it with iree-compile fails with the following assertion:

iree-compile: /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/ADT/PointerUnion.h:156: T llvm::PointerUnion<mlir::Attribute, mlir::Value>::get() const [PT = <mlir::Attribute, mlir::Value>, T = mlir::Value]: Assertion `isa<T>(*this) && "Invalid accessor called"' failed.
Please report issues to https://github.com/iree-org/iree/issues and include the crash backtrace.
Stack dump:
0.  Program arguments: iree-compile --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host model.torch_onnx.mlir -o abc.vmfb
 #0 0x00007f01a07c11f7 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/llvm/lib/Support/Unix/Signals.inc:723:13
 #1 0x00007f01a07bf430 llvm::sys::RunSignalHandlers() /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/llvm/lib/Support/Signals.cpp:106:18
 #2 0x00007f01a07c18ba SignalHandler(int) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/llvm/lib/Support/Unix/Signals.inc:413:1
 #3 0x00007f019a9cb520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x00007f019aa1f9fc __pthread_kill_implementation ./nptl/./nptl/pthread_kill.c:44:76
 #5 0x00007f019aa1f9fc __pthread_kill_internal ./nptl/./nptl/pthread_kill.c:78:10
 #6 0x00007f019aa1f9fc pthread_kill ./nptl/./nptl/pthread_kill.c:89:10
 #7 0x00007f019a9cb476 gsignal ./signal/../sysdeps/posix/raise.c:27:6
 #8 0x00007f019a9b17f3 abort ./stdlib/./stdlib/abort.c:81:7
 #9 0x00007f019a9b171b _nl_load_domain ./intl/./intl/loadmsgcat.c:1177:9
#10 0x00007f019a9c2e96 (/lib/x86_64-linux-gnu/libc.so.6+0x39e96)
#11 0x00007f01a4fd6a57 mlir::Operation::getOpResultImpl(unsigned int) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/include/mlir/IR/Operation.h:1006:5
#12 0x00007f01a4fd6a57 mlir::Operation::getResult(unsigned int) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/include/mlir/IR/Operation.h:402:54
#13 0x00007f01a4fd6a57 mlir::OpTrait::OneTypedResult<mlir::Type>::Impl<mlir::arith::ConstantOp>::getResult() /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/include/mlir/IR/OpDefinition.h:699:33
#14 0x00007f01a4fd6a57 mlir::OpTrait::OneTypedResult<mlir::Type>::Impl<mlir::arith::ConstantOp>::operator mlir::Value() /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/include/mlir/IR/OpDefinition.h:704:54
#15 0x00007f01a4fd6a57 mlir::inferExpandShapeOutputShape(mlir::OpBuilder&, mlir::Location, mlir::ShapedType, llvm::ArrayRef<llvm::SmallVector<long, 2u> >, llvm::ArrayRef<mlir::OpFoldResult>) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/lib/Dialect/Arith/Utils/Utils.cpp:71:9
#16 0x00007f01a4ea59be std::_Optional_base_impl<llvm::SmallVector<mlir::OpFoldResult, 6u>, std::_Optional_base<llvm::SmallVector<mlir::OpFoldResult, 6u>, false, false> >::_M_is_engaged() const /bin/../lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/optional:471:58
#17 0x00007f01a4ea59be std::optional<llvm::SmallVector<mlir::OpFoldResult, 6u> >::operator bool() const /bin/../lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/optional:985:22
#18 0x00007f01a4ea59be mlir::tensor::ExpandShapeOp::inferOutputShape(mlir::OpBuilder&, mlir::Location, mlir::RankedTensorType, llvm::ArrayRef<llvm::SmallVector<long, 2u> >, llvm::ArrayRef<mlir::OpFoldResult>) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp:1683:8
#19 0x00007f01a4ea5e33 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::getFirstEl() const /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/ADT/SmallVector.h:130:46
#20 0x00007f01a4ea5e33 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::SmallVectorTemplateCommon(unsigned long) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/ADT/SmallVector.h:135:49
#21 0x00007f01a4ea5e33 llvm::SmallVectorTemplateBase<mlir::OpFoldResult, true>::SmallVectorTemplateBase(unsigned long) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/ADT/SmallVector.h:487:42
#22 0x00007f01a4ea5e33 llvm::SmallVectorImpl<mlir::OpFoldResult>::SmallVectorImpl(unsigned int) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/ADT/SmallVector.h:588:9
#23 0x00007f01a4ea5e33 llvm::SmallVector<mlir::OpFoldResult, 6u>::SmallVector() /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/ADT/SmallVector.h:1198:19
#24 0x00007f01a4ea5e33 mlir::tensor::ExpandShapeOp::build(mlir::OpBuilder&, mlir::OperationState&, mlir::Type, mlir::Value, llvm::ArrayRef<llvm::SmallVector<long, 2u> >) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp:1707:29
#25 0x00007f01a16ca543 mlir::tensor::ExpandShapeOp mlir::OpBuilder::create<mlir::tensor::ExpandShapeOp, mlir::RankedTensorType&, mlir::Value, llvm::SmallVector<llvm::SmallVector<long, 2u>, 1u>&>(mlir::Location, mlir::RankedTensorType&, mlir::Value&&, llvm::SmallVector<llvm::SmallVector<long, 2u>, 1u>&) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/include/mlir/IR/Builders.h:518:16
#26 0x00007f01a16c20a3 mlir::tensor::ExpandShapeOp mlir::RewriterBase::replaceOpWithNewOp<mlir::tensor::ExpandShapeOp, mlir::RankedTensorType&, mlir::Value, llvm::SmallVector<llvm::SmallVector<long, 2u>, 1u>&>(mlir::Operation*, mlir::RankedTensorType&, mlir::Value&&, llvm::SmallVector<llvm::SmallVector<long, 2u>, 1u>&) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/include/mlir/IR/PatternMatch.h:544:5
#27 0x00007f01a16c20a3 (anonymous namespace)::ConvertAtenUnsqueezeOp::matchAndRewrite(mlir::torch::Torch::AtenUnsqueezeOp, mlir::torch::Torch::AtenUnsqueezeOpAdaptor, mlir::ConversionPatternRewriter&) const /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/torch-mlir/lib/Conversion/TorchToLinalg/DataMovement.cpp:1753:14
#28 0x00007f01a16cd3ce mlir::OpConversionPattern<mlir::torch::Torch::AtenUnsqueezeOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h:615:3
#29 0x00007f01a4cec0c2 mlir::ConversionPattern::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:1682:10
#30 0x00007f01a4d2c4ee mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>)::$_2::operator()() const /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp:212:13
#31 0x00007f01a4d2c4ee void llvm::function_ref<void ()>::callback_fn<mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>)::$_2>(long) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46:12
#32 0x00007f01a4d2952f mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp:233:9
#33 0x00007f01a4cecfa9 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:0:0
#34 0x00007f01a4cec127 mlir::OperationConverter::convert(mlir::ConversionPatternRewriter&, mlir::Operation*) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:0:0
#35 0x00007f01a4ced1af llvm::LogicalResult::failed() const /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/Support/LogicalResult.h:43:43
#36 0x00007f01a4ced1af llvm::failed(llvm::LogicalResult) /proj/xhdhdstaff6/dhirajp/localBuild/iree/third_party/llvm-project/llvm/include/llvm/Support/LogicalResult.h:71:58

The failure is triggered by this op:

%378 = torch.operator "onnx.Add"(%arg1, %377) : (!torch.vtensor<[1,196,128],f32>, !torch.vtensor<[?,?,128],f32>) -> !torch.vtensor<[?,196,128],f32> 

Here the output type should have been `[1,196,128]` instead of `[?,196,128]`.

Command:

iree-compile --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host model.torch_onnx.mlir -o abc.vmfb
pdhirajkumarprasad commented 1 week ago

Affected models:

model--long-t5-tglobal-base-16384-book-summary--pszemraj
model--long-t5-tglobal-base-16384-booksum-V11-big_patent-V2--pszemraj
model--long-t5-tglobal-base-16384-booksum-V12--pszemraj
migraphx_bert__bertsquad-12