llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org

[mlir] [tosa] -tosa-to-tensor crashes in TosaToTensor.cpp:195: SmallVector<ReassociationExprs> (anonymous namespace)::createReassociationMapForCollapse(OpBuilder &, Type, Type): Assertion `currSrcDim == srcShape.size() && currDstDim == dstShape.size()' failed. #108151

Open axeabc opened 1 week ago

axeabc commented 1 week ago

git version: 761bf333e378b52614c

system: Ubuntu 18.04.6 LTS

reproduce with: mlir-opt -tosa-to-tensor a.mlir

a.mlir:

module {
  func.func @test_reshape_3d(%arg1: tensor<2x3x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) {
    %c0 = arith.constant 0 : index
    %0 = tosa.reshape %arg1 {new_shape = array<i64: 2, 4, 4>} : (tensor<2x3x4xf32>) -> tensor<2x?x4xf32>
    %1 = tosa.reshape %arg1 {new_shape = array<i64: 2, 1, 4>} : (tensor<2x3x4xf32>) -> tensor<2x?x4xf32>
    return %0, %1 : tensor<2x?x4xf32>, tensor<2x?x4xf32>
  }
}

stack trace:

mlir-opt: /data/szy/MLIR/llvm-release/llvm-project/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp:195: SmallVector<ReassociationExprs> (anonymous namespace)::createReassociationMapForCollapse(OpBuilder &, Type, Type): Assertion `currSrcDim == srcShape.size() && currDstDim == dstShape.size()' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.  Program arguments: /data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt -tosa-to-tensor a.mlir
 #0 0x000055b5adeb3128 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x10d7128)
 #1 0x000055b5adeb0c3e llvm::sys::RunSignalHandlers() (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x10d4c3e)
 #2 0x000055b5adeb3abd SignalHandler(int) Signals.cpp:0:0
 #3 0x00007f7864715420 __restore_rt (/lib/x86_64-linux-gnu/libpthread.so.0+0x14420)
 #4 0x00007f7863d5200b raise /build/glibc-LcI20x/glibc-2.31/signal/../sysdeps/unix/sysv/linux/raise.c:51:1
 #5 0x00007f7863d31859 abort /build/glibc-LcI20x/glibc-2.31/stdlib/abort.c:81:7
 #6 0x00007f7863d31729 get_sysdep_segment_value /build/glibc-LcI20x/glibc-2.31/intl/loadmsgcat.c:509:8
 #7 0x00007f7863d31729 _nl_load_domain /build/glibc-LcI20x/glibc-2.31/intl/loadmsgcat.c:970:34
 #8 0x00007f7863d42fd6 (/lib/x86_64-linux-gnu/libc.so.6+0x33fd6)
 #9 0x000055b5b09c4e77 (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x3be8e77)
#10 0x000055b5b09c3fb7 (anonymous namespace)::ReshapeConverter::matchAndRewrite(mlir::tosa::ReshapeOp, mlir::tosa::ReshapeOpAdaptor, mlir::ConversionPatternRewriter&) const TosaToTensor.cpp:0:0
#11 0x000055b5b09c3350 mlir::OpConversionPattern<mlir::tosa::ReshapeOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x3be7350)
#12 0x000055b5b1003b11 mlir::ConversionPattern::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x4227b11)
#13 0x000055b5b398ec91 void llvm::function_ref<void ()>::callback_fn<mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>)::$_0>(long) PatternApplicator.cpp:0:0
#14 0x000055b5b398b94b mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>) (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x6baf94b)
#15 0x000055b5b1004b53 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) DialectConversion.cpp:0:0
#16 0x000055b5b1003bb7 mlir::OperationConverter::convert(mlir::ConversionPatternRewriter&, mlir::Operation*) (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x4227bb7)
#17 0x000055b5b1004d7f mlir::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>) (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x4228d7f)
#18 0x000055b5b100c96b mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig) (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x423096b)
#19 0x000055b5b09bf6c8 (anonymous namespace)::TosaToTensor::runOnOperation() TosaToTensorPass.cpp:0:0
#20 0x000055b5b0fa52d6 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x41c92d6)
#21 0x000055b5b0fa5c40 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x41c9c40)
#22 0x000055b5b0fa8282 mlir::PassManager::run(mlir::Operation*) (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x41cc282)
#23 0x000055b5b0fa0ab1 performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) MlirOptMain.cpp:0:0
#24 0x000055b5b0fa070b llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) MlirOptMain.cpp:0:0
#25 0x000055b5b104d3a5 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x42713a5)
#26 0x000055b5b0f9bb35 mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x41bfb35)
#27 0x000055b5b0f9bddf mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x41bfddf)
#28 0x000055b5b0f9c10e mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x41c010e)
#29 0x000055b5ade93d67 main (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x10b7d67)
#30 0x00007f7863d33083 __libc_start_main /build/glibc-LcI20x/glibc-2.31/csu/../csu/libc-start.c:342:3
#31 0x000055b5ade938ee _start (/data/szy/MLIR/llvm-release/llvm-project/build/bin/mlir-opt+0x10b78ee)
sjarus commented 6 days ago

Thanks for filing this issue. I have a couple of thoughts:

sjarus commented 6 days ago

A small change to the test that makes it semantically correct in terms of element count appears to work fine:

module {
  func.func @test_reshape_3d(%arg1: tensor<2x3x4xf32>) -> (tensor<6x?x4xf32>, tensor<2x?x12xf32>) {
    %c0 = arith.constant 0 : index
    %0 = tosa.reshape %arg1 {new_shape = array<i64: 6, 1, 4>} : (tensor<2x3x4xf32>) -> tensor<6x?x4xf32>
    %1 = tosa.reshape %arg1 {new_shape = array<i64: 2, 1, 12>} : (tensor<2x3x4xf32>) -> tensor<2x?x12xf32>
    return %0, %1 : tensor<6x?x4xf32>, tensor<2x?x12xf32>
  }
}

With /bin/mlir-opt tmp2.mlir --tosa-to-tensor

module {
  func.func @test_reshape_3d(%arg0: tensor<2x3x4xf32>) -> (tensor<6x?x4xf32>, tensor<2x?x12xf32>) {
    %c0 = arith.constant 0 : index
    %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x3x4xf32> into tensor<6x4xf32>
    %expanded = tensor.expand_shape %collapsed [[0, 1], [2]] output_shape [6, 1, 4] : tensor<6x4xf32> into tensor<6x1x4xf32>
    %cast = tensor.cast %expanded : tensor<6x1x4xf32> to tensor<6x?x4xf32>
    %collapsed_0 = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor<2x3x4xf32> into tensor<2x12xf32>
    %expanded_1 = tensor.expand_shape %collapsed_0 [[0, 1], [2]] output_shape [2, 1, 12] : tensor<2x12xf32> into tensor<2x1x12xf32>
    %cast_2 = tensor.cast %expanded_1 : tensor<2x1x12xf32> to tensor<2x?x12xf32>
    return %cast, %cast_2 : tensor<6x?x4xf32>, tensor<2x?x12xf32>
  }
}
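
For readers unfamiliar with reassociation maps, the collapse groupings in the output ([[0, 1], [2]] and [[0], [1, 2]]) can be reproduced with a simple greedy grouping of source dimensions. The sketch below is a simplified illustration only, not the code in TosaToTensor.cpp: the names `groupDimsForCollapse` and `Reassociation` are hypothetical, and it assumes every dimension is static and positive. Where the assertion in the issue title requires both shape cursors to reach the end, this sketch returns `std::nullopt` instead.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical alias: one group of source dim indices per destination dim.
using Reassociation = std::vector<std::vector<int64_t>>;

// Hypothetical helper, not the MLIR implementation. Groups consecutive
// source dims so that each group's product equals one destination dim.
// Assumes all dims are static and positive.
std::optional<Reassociation>
groupDimsForCollapse(const std::vector<int64_t> &srcShape,
                     const std::vector<int64_t> &dstShape) {
  Reassociation groups(dstShape.size());
  std::size_t src = 0;
  for (std::size_t dst = 0; dst < dstShape.size(); ++dst) {
    assert(dstShape[dst] > 0 && "sketch handles static, positive dims only");
    int64_t product = 1;
    // Take at least one source dim, then keep going until the running
    // product reaches the destination dim.
    do {
      if (src == srcShape.size())
        return std::nullopt; // ran out of source dims
      product *= srcShape[src];
      groups[dst].push_back(static_cast<int64_t>(src));
      ++src;
    } while (product < dstShape[dst]);
    if (product != dstShape[dst])
      return std::nullopt; // overshot: no grouping of whole dims fits
  }
  // Fold trailing unit dims into the last group.
  while (src < srcShape.size() && srcShape[src] == 1 && !groups.empty())
    groups.back().push_back(static_cast<int64_t>(src++));
  // Leftover source dims mean the two shapes do not line up.
  if (src != srcShape.size())
    return std::nullopt;
  return groups;
}
```

Under these assumptions, collapsing 2x3x4 into 6x4 yields the groups {0, 1} and {2}, and collapsing into 2x12 yields {0} and {1, 2}. A target such as 8x4 has no valid grouping and is reported as a failure rather than tripping an assertion.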

GeorgeARM commented 6 days ago

@sjarus is right; this definitely shouldn't have led to a crash. The ReshapeOp verifier doesn't check for this case, but it is rather straightforward to handle, since new_shape and the input shape are both fully static.

CoTinker commented 1 day ago

This issue is a duplicate of #107969. Would it be correct to add a verifier for tosa.reshape like this?

if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();

    // Compute the number of elements in the new shape
    int64_t newShapeElementsNum = std::accumulate(
        getNewShape().begin(), getNewShape().end(), 1LL, 
        [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });

    // Check if the new shape is fully static
    bool isStaticNewShape = llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });

    // Validate the reshape operation
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
        return emitOpError() << "Cannot reshape " << inputElementsNum
                             << " elements into " << newShapeElementsNum;
    }
}

If it's correct, I'll submit a PR.
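
To try the check outside of MLIR, the same logic can be written as a standalone function. This is only a sketch of the snippet above under stated assumptions: `reshapeElementCountOk` is a hypothetical name, plain `std::vector<int64_t>` stands in for the tensor type and the new_shape attribute, and any non-positive entry in the new shape (for example -1) is treated as not static.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the proposed verifier check, not MLIR code.
// Returns true when the check would accept the reshape.
bool reshapeElementCountOk(const std::vector<int64_t> &inputShape,
                           const std::vector<int64_t> &newShape) {
  // Element count of a fully static input shape.
  int64_t inputElements = 1;
  for (int64_t dim : inputShape) {
    assert(dim >= 0 && "this sketch only handles static input shapes");
    inputElements *= dim;
  }

  // Product of the known (positive) dims, and whether every dim is known.
  int64_t newShapeElements = 1;
  bool isStaticNewShape = true;
  for (int64_t dim : newShape) {
    if (dim > 0)
      newShapeElements *= dim;
    else
      isStaticNewShape = false;
  }

  // Fully static: the counts must match exactly.
  if (isStaticNewShape)
    return inputElements == newShapeElements;
  // Otherwise: the known part must not exceed the input element count.
  return newShapeElements <= inputElements;
}
```

For the 2x3x4 input (24 elements), the two shapes from the report, [2, 4, 4] (32 elements) and [2, 1, 4] (8 elements), are rejected, while [6, 1, 4] and [2, 1, 12] are accepted. The non-static branch only bounds the known product and does not test divisibility, so a shape like [5, -1, 4] would also pass this check.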

GeorgeARM commented 1 day ago

Thanks @CoTinker for having a look. Please feel free to submit a PR and we can review.

CoTinker commented 1 day ago

Okay.