onnx / onnx-mlir

Representation and Reference Lowering of ONNX Models in MLIR Compiler Infrastructure
Apache License 2.0

Assertion error in the KrnlToLLVM conversion pass for krnl.global with 128+ string elements. #2045

Closed negiyas closed 1 year ago

negiyas commented 1 year ago

The following assertion error occurs in the KrnlToLLVM conversion pass with the attached krnl.global op. It seems that the large-array path, which uses denseAttr.getRawData(), cannot support the string type. We can avoid this issue by adding a condition that checks that the element type is not a string, i.e. changing

    if ((!denseAttr.isSplat()) && (sizeInBytes > 1024)) {

to

    if ((!denseAttr.getElementType().isa<StringType>()) && (!denseAttr.isSplat()) && (sizeInBytes > 1024)) {

However, this workaround increases compilation time, especially for models with large category map tables (e.g. Bidaf-9), because many string constant ops are generated.
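To make the effect of the extra condition concrete, here is a minimal standalone sketch of the branch selection in lowerDenseConstant. It does not use MLIR. The names ConstantInfo, LoweringPath, choosePath and kAssumedStringSlotBytes are hypothetical, and the 8 bytes per string element is an assumption, chosen because it matches the failure first appearing at 129 elements (129 * 8 = 1032 > 1024).

```cpp
#include <cstdint>

// Hypothetical stand-ins for the three branches of lowerDenseConstant.
enum class LoweringPath { RawBytes, StringLiteral, GenericAttr };

// Hypothetical summary of the properties the branch condition inspects.
struct ConstantInfo {
  bool isString;       // element type is a string type
  bool isSplat;        // all elements are identical
  int64_t sizeInBytes; // size computed for the constant
};

// Threshold taken from the condition quoted in the issue.
constexpr int64_t kRawDataThreshold = 1024;

// Assumption: each string element is counted as one 8-byte slot.
constexpr int64_t kAssumedStringSlotBytes = 8;

// excludeStrings = false models the original condition,
// excludeStrings = true models the proposed workaround.
constexpr LoweringPath choosePath(const ConstantInfo &c, bool excludeStrings) {
  const bool blockedByType = excludeStrings && c.isString;
  if (!blockedByType && !c.isSplat && c.sizeInBytes > kRawDataThreshold)
    return LoweringPath::RawBytes; // path that calls getRawData()
  if (c.isString)
    return LoweringPath::StringLiteral;
  return LoweringPath::GenericAttr;
}

// 129 strings exceed the threshold, so the original condition picks the
// raw-bytes path, where the assertion fires.
static_assert(choosePath({true, false, 129 * kAssumedStringSlotBytes}, false) ==
                  LoweringPath::RawBytes,
              "original condition sends large string arrays to raw bytes");

// With the workaround, the same constant goes to the string-literal path.
static_assert(choosePath({true, false, 129 * kAssumedStringSlotBytes}, true) ==
                  LoweringPath::StringLiteral,
              "workaround keeps strings off the raw-bytes path");
```

Under the same assumption, 128 strings give exactly 1024 bytes, which is not greater than the threshold, so they never reach the raw-bytes path. This would explain why the sample input needs 129 elements to reproduce the assertion.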

Assertion error

onnx-mlir-opt: /home1/negishi/src/dlc.git/onnx-mlir.opt/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp:161: mlir::LLVM::GlobalOp onnx_mlir::krnl::KrnlGlobalOpLowering::lowerDenseConstant(mlir::KrnlGlobalOp&, mlir::Type, mlir::ConversionPatternRewriter&) const: Assertion `((int64_t)rawData.size() == sizeInBytes) && "Data size mismatch."' failed.

Sample input

// Test CategoryMapper lowering when the input is a list of strings.
func.func private @test_category_mapper3_string_to_int64() -> memref<129x!krnl.string> {
  %4 = "krnl.global"() {name = "cats_strings", shape = [129], value = dense<["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23", "24", "25", "26", "27", "28", "29", "30", "31", "32", "33", "34", "35", "36", "37", "38", "39", "40", "41", "42", "43", "44", "45", "46", "47", "48", "49", "50", "51", "52", "53", "54", "55", "56", "57", "58", "59", "60", "61", "62", "63", "64", "65", "66", "67", "68", "69", "70", "71", "72", "73", "74", "75", "76", "77", "78", "79", "80", "81", "82", "83", "84", "85", "86", "87", "88", "89", "90", "91", "92", "93", "94", "95", "96", "97", "98", "99", "100", "101", "102", "103", "104", "105", "106", "107", "108", "109", "110", "111", "112", "113", "114", "115", "116", "117", "118", "119", "120", "121", "122", "123", "124", "125", "126", "127", "128", "129"]> : tensor<129x!krnl.string>} : () -> memref<129x!krnl.string>
  return %4 : memref<129x!krnl.string>
}

Conversion code

  LLVM::GlobalOp lowerDenseConstant(KrnlGlobalOp &krnlGlobalOp, Type globalType,
      ConversionPatternRewriter &rewriter) const {
    assert(krnlGlobalOp.getValue().has_value() &&
           "Expecting KrnlGlobalOp with a valid value");
    assert(krnlGlobalOp.getValue().value().isa<DenseElementsAttr>() &&
           "Expecting a global with a dense elements attribute");

    MLIRContext *context = krnlGlobalOp.getContext();
    Location loc = krnlGlobalOp.getLoc();
    ModuleOp module = krnlGlobalOp->getParentOfType<ModuleOp>();
    MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);

    OpBuilder::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());

    DenseElementsAttr denseAttr =
        krnlGlobalOp.getValue().value().cast<DenseElementsAttr>();

    int64_t sizeInBytes = computeSizeInBytes(krnlGlobalOp);
    LLVM::GlobalOp global;
    if ((!denseAttr.isSplat()) && (sizeInBytes > 1024)) {
      ArrayRef<char> rawData = denseAttr.getRawData();
      assert(((int64_t)rawData.size() == sizeInBytes) && "Data size mismatch.");

      StringRef data(rawData.data(), rawData.size());
      StringAttr llvmStringAttr = StringAttr::get(context, data);
      auto llvmArrayI8Ty =
          LLVM::LLVMArrayType::get(IntegerType::get(context, 8), sizeInBytes);
      global = create.llvm.globalOp(llvmArrayI8Ty,
          /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.getName(),
          llvmStringAttr);
    } else {
      if (denseAttr.getElementType().isa<StringType>())
        global = lowerStringLiteral(krnlGlobalOp, globalType, rewriter);
      else
        global = create.llvm.globalOp(globalType,
            /*isConstant=*/true, LLVM::Linkage::Internal,
            krnlGlobalOp.getName(), krnlGlobalOp.getValue().value());
    }

    LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";);
    return global;
  }

negiyas commented 1 year ago

Fixed by PR#2055.