google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

heir_translate: `Dialect 'memref' not found for custom op 'memref.alloc'` #954

Open johnmatter opened 2 months ago

johnmatter commented 2 months ago

I am trying to create a pipeline to get from StableHLO to OpenFHE by way of TOSA. It's possible I'm making a newbie mistake, so I'll write up my approach.

Environment

I'm on Ubuntu 24.04. My local build of HEIR is at commit 90bdae81f3ae7320bd00b1e7a2920e63ff0b2ff6. There's no particular reason for being on this commit; it's just the last time I ran git pull.

Lower StableHLO to TOSA with stablehlo-opt

Here is a sample function that might be part of a CNN:

$ cat in.stablehlo
func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}
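
For readers less familiar with StableHLO, here is a plain-Python sketch of what the function above computes (reshape, dot product, bias add, ReLU). The shapes are shrunk from 28x28/784x10 to 2x2/4x3 purely to keep the example readable; the dataflow is the same.

```python
def flatten(matrix):
    """stablehlo.reshape: 2-D matrix -> single row vector."""
    return [x for row in matrix for x in row]

def dot(vec, weights):
    """stablehlo.dot: (1xK) . (KxN) -> length-N row."""
    n = len(weights[0])
    return [sum(vec[k] * weights[k][j] for k in range(len(vec)))
            for j in range(n)]

def main(image, weights, bias):
    v = flatten(image)                               # reshape 2x2 -> 1x4
    logits = dot(v, weights)                         # 1x4 . 4x3 -> 1x3
    biased = [a + b for a, b in zip(logits, bias)]   # stablehlo.add
    return [max(x, 0.0) for x in biased]             # stablehlo.maximum(., 0)

image = [[1.0, 2.0], [3.0, 4.0]]
weights = [[1.0, 0.0, -1.0],
           [0.0, 1.0, -1.0],
           [1.0, 1.0, 0.0],
           [0.0, 0.0, 1.0]]
bias = [0.5, -10.0, 0.0]
print(main(image, weights, bias))
```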

This mostly lowers to TOSA with stablehlo-opt -stablehlo-legalize-to-tosa in.stablehlo

$ stablehlo-opt -stablehlo-legalize-to-tosa in.stablehlo > in.tosa
$ cat in.tosa
module {
  func.func @main(%arg0: tensor<28x28xf32>, %arg1: tensor<784x10xf32>, %arg2: tensor<1x10xf32>) -> tensor<1x10xf32> {
    %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x10xf32>}> : () -> tensor<1x10xf32>
    %1 = stablehlo.reshape %arg0 : (tensor<28x28xf32>) -> tensor<1x784xf32>
    %2 = tosa.reshape %1 {new_shape = array<i64: 1, 1, 784>} : (tensor<1x784xf32>) -> tensor<1x1x784xf32>
    %3 = tosa.reshape %arg1 {new_shape = array<i64: 1, 784, 10>} : (tensor<784x10xf32>) -> tensor<1x784x10xf32>
    %4 = tosa.matmul %2, %3 : (tensor<1x1x784xf32>, tensor<1x784x10xf32>) -> tensor<1x1x10xf32>
    %5 = tosa.reshape %4 {new_shape = array<i64: 1, 10>} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
    %6 = tosa.add %5, %arg2 : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
    %7 = tosa.maximum %6, %0 : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
    return %7 : tensor<1x10xf32>
  }
}

It misses the stablehlo.reshape but that's easy enough to replace by hand:

$ cat in.tosa
module {
  func.func @main(%arg0: tensor<28x28xf32>, %arg1: tensor<784x10xf32>, %arg2: tensor<1x10xf32>) -> tensor<1x10xf32> {
    %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x10xf32>}> : () -> tensor<1x10xf32>
    %1 = tosa.reshape %arg0 {new_shape = array<i64: 1, 784>} : (tensor<28x28xf32>) -> tensor<1x784xf32>
    %2 = tosa.reshape %1 {new_shape = array<i64: 1, 1, 784>} : (tensor<1x784xf32>) -> tensor<1x1x784xf32>
    %3 = tosa.reshape %arg1 {new_shape = array<i64: 1, 784, 10>} : (tensor<784x10xf32>) -> tensor<1x784x10xf32>
    %4 = tosa.matmul %2, %3 : (tensor<1x1x784xf32>, tensor<1x784x10xf32>) -> tensor<1x1x10xf32>
    %5 = tosa.reshape %4 {new_shape = array<i64: 1, 10>} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
    %6 = tosa.add %5, %arg2 : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
    %7 = tosa.maximum %6, %0 : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
    return %7 : tensor<1x10xf32>
  }
}

Lower TOSA with heir-opt

$ heir-opt --heir-tosa-to-arith in.tosa > out.mlir

This output is 33k lines long. Of the lines that mention memref, about 17k are of the forms

    %24298 = affine.load %arg1[1, 0] : memref<784x10xf32, strided<[?, ?], offset: ?>>
    affine.store %24300, %alloc_0[0, 0, 9] : memref<1x1x10xf32>

The other memref lines are:

$ grep memref out.mlir | grep -v affine.store | grep -v affine.load
  func.func @main(%arg0: memref<28x28xf32, strided<[?, ?], offset: ?>>, %arg1: memref<784x10xf32, strided<[?, ?], offset: ?>>, %arg2: memref<1x10xf32, strided<[?, ?], offset: ?>>) -> memref<1x10xf32> {
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<28x28xf32>
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x1x10xf32>
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x10xf32>
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x10xf32>
    memref.dealloc %alloc : memref<28x28xf32>
    memref.dealloc %alloc_0 : memref<1x1x10xf32>
    memref.dealloc %alloc_1 : memref<1x10xf32>
    return %alloc_2 : memref<1x10xf32>

The remaining lines are arith.add, arith.mulf, arith.constant, etc.

Emit OpenFHE with heir-translate

This is where things fail.

$ heir-translate --emit-openfhe-pke-header out.mlir
out.mlir:4:14: error: Dialect `memref' not found for custom op 'memref.alloc'
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<28x28xf32>
             ^
out.mlir:4:14: note: Registered dialects: arith, builtin, func, lwe, openfhe, polynomial ; for more info on dialect registration see https://mlir.llvm.org/getting_started/Faq/#registered-loaded-dependent-whats-up-with-dialects-management
asraa commented 2 months ago

Hey! Thanks for such a detailed report. I think what you ran into is expected, as in, nothing wrong on your part, but let me add some comments about the approach. The memref / RLWE pipeline gap is actually something I was working on this week and I'll update you (and at the WG meeting tomorrow).

  1. --heir-tosa-to-arith is one pipeline to lower TOSA to arithmetic operations (which requires bufferization, hence the tensors are lowered to memrefs). This is an OK approach, but note that it was intended originally for experimentation and as a precursor to feeding it into Verilog for the boolean schemes. See point 3 about why it's not ideal for your pipeline.
  2. Currently, your pipeline does not secretize anything - so there's no data that would be considered private and lowered to ciphertext data. We have a --secretize pass that marks all of the function's inputs as secret (or you can annotate them by hand with an attribute), and a --wrap-generic pass that then converts the program to work with secret data types given those annotations.
  3. The --mlir-to-openfhe-bgv is probably the "right" pass to use (https://github.com/google/heir/blob/aa9d9883406b29067d9a2b68835cd96c5a2f3c5a/tools/heir-opt.cpp#L671) BUT the entrypoint is expected to be a program in standard MLIR (affine, arith, func) with tensors instead of memrefs. Like this example program: https://github.com/google/heir/blob/aa9d9883406b29067d9a2b68835cd96c5a2f3c5a/tests/mlir_to_openfhe_bgv/simple_sum.mlir. This pipeline (and the one you ran) will not actually SIMD-vectorize things nicely; @lawrencekhlim is working on that right now by lowering linalg.matmuls into efficient SIMD operations.
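
Concretely, the secretize step from point 2 would be invoked something like this (the pass flags are the ones named above; the file names and pipeline composition are just a sketch, not a verified command line):

```shell
# Sketch: mark all function inputs as secret, wrap the body in a
# secret.generic, then hand off to the scheme-specific pipeline.
heir-opt --secretize --wrap-generic in.mlir > secretized.mlir
heir-opt --mlir-to-openfhe-bgv='entry-function=main ciphertext-degree=32' secretized.mlir
```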

The gap: if you'll be at the WG meeting tomorrow, let's chat! I'll try to see what happens when I lower the TOSA code to the entry IR for --mlir-to-openfhe-bgv.

asraa commented 2 months ago

So back to the original issue described (which honestly, thank you! this serves as good docs as well):

memref was not a supported dialect for the openfhe emitter because the pipelines that currently emit to openfhe expect tensor + affine + arith IRs.

Of course, we can always add support for memref alloc / loads / stores to the emitter for the sake of compatibility (even if we're not getting nice optimized versions of these programs that use tensor semantics and rotations). WDYT?

asraa commented 2 months ago

This is what I got when I ran your TOSA model through my local linalg to affine with tensor semantics lowering:

func.func @main(%arg0: tensor<28x28xf32>, %arg1: tensor<784x10xf32>, %arg2: tensor<1x10xf32>) -> tensor<1x10xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %0 = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 784>} : (tensor<28x28xf32>) -> tensor<1x1x784xf32>
  %1 = tosa.reshape %arg1 {new_shape = array<i64: 1, 784, 10>} : (tensor<784x10xf32>) -> tensor<1x784x10xf32>
  %2 = tensor.empty() : tensor<1x1x10xf32>
  %3 = affine.for %arg3 = 0 to 10 iter_args(%arg4 = %2) -> (tensor<1x1x10xf32>) {
    %inserted = tensor.insert %cst into %arg4[%c0, %c0, %arg3] : tensor<1x1x10xf32>
    affine.yield %inserted : tensor<1x1x10xf32>
  }
  %4 = affine.for %arg3 = 0 to 10 iter_args(%arg4 = %3) -> (tensor<1x1x10xf32>) {
    %8 = affine.for %arg5 = 0 to 784 iter_args(%arg6 = %arg4) -> (tensor<1x1x10xf32>) {
      %extracted = tensor.extract %0[%c0, %c0, %arg5] : tensor<1x1x784xf32>
      %extracted_0 = tensor.extract %1[%c0, %arg5, %arg3] : tensor<1x784x10xf32>
      %extracted_1 = tensor.extract %3[%c0, %c0, %arg3] : tensor<1x1x10xf32>
      %9 = arith.mulf %extracted, %extracted_0 : f32
      %10 = arith.addf %extracted_1, %9 : f32
      %inserted = tensor.insert %10 into %arg6[%c0, %c0, %arg3] : tensor<1x1x10xf32>
      affine.yield %inserted : tensor<1x1x10xf32>
    }
    affine.yield %8 : tensor<1x1x10xf32>
  }
  %5 = tosa.reshape %4 {new_shape = array<i64: 1, 10>} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
  %6 = tensor.empty() : tensor<1x10xf32>
  %7 = affine.for %arg3 = 0 to 10 iter_args(%arg4 = %6) -> (tensor<1x10xf32>) {
    %extracted = tensor.extract %5[%c0, %arg3] : tensor<1x10xf32>
    %extracted_0 = tensor.extract %arg2[%c0, %arg3] : tensor<1x10xf32>
    %8 = arith.addf %extracted, %extracted_0 : f32
    %9 = arith.maximumf %8, %cst : f32
    %inserted = tensor.insert %9 into %arg4[%c0, %arg3] : tensor<1x10xf32>
    affine.yield %inserted : tensor<1x10xf32>
  }
  return %7 : tensor<1x10xf32>
}

Note: it didn't handle the tosa.reshape, but I think you can handle that with tensor.reshape ops.
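
For instance, the first reshape might be rewritten roughly like this (a sketch, not verified against the emitter; tensor.reshape takes the target shape as a 1-D tensor operand):

```mlir
// Hypothetical replacement for the first tosa.reshape above.
%shape = arith.constant dense<[1, 1, 784]> : tensor<3xi64>
%0 = tensor.reshape %arg0(%shape)
    : (tensor<28x28xf32>, tensor<3xi64>) -> tensor<1x1x784xf32>
```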

You should now be able to pass this to --mlir-to-openfhe-bgv. You will probably get an error that you can't heir-translate the tensor.reshape, but you will definitely get down to a nicer IR! Let's add support for that op.

johnmatter commented 2 months ago

Thank you so much for such a quick and thorough response! I was unable to attend the WG this morning unfortunately. Is it every other Wednesday at 5AM Pacific?

I am able to run heir-opt --mlir-to-openfhe-bgv='entry-function=simple_sum ciphertext-degree=32' simple_sum.mlir for the example program you linked. That's encouraging!

When I try to run the same for the IR you provided in your last comment, I get a core dump that I'm not sure how to parse. It could be an LLVM problem more than an HEIR one, I'm not sure. I get similar output if I replace tosa.reshape with tensor.reshape. (Here are some code search results for future readers to consult.) Apologies if there's an obvious mistake I'm making here.

$ heir-opt --mlir-to-openfhe-bgv='entry-function=main ciphertext-degree=32' lowered_to_good_dialects.mlir
heir-opt: external/llvm-project/mlir/lib/IR/PatternMatch.cpp:181: mlir::RewriterBase::eraseOp(mlir::Operation*)::<lambda(mlir::Operation*)>: Assertion `mayBeGraphRegion(*op->getParentRegion()) && "expected that op has no uses"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.  Program arguments: heir-opt "--mlir-to-openfhe-bgv=entry-function=main ciphertext-degree=32" lowered_to_good_dialects.mlir
 #0 0x00005e945df8ec70 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /proc/self/cwd/external/llvm-project/llvm/lib/Support/Unix/Signals.inc:723:22
 #1 0x00005e945df8f048 PrintStackTraceSignalHandler(void*) /proc/self/cwd/external/llvm-project/llvm/lib/Support/Unix/Signals.inc:798:1
 #2 0x00005e945df8c424 llvm::sys::RunSignalHandlers() /proc/self/cwd/external/llvm-project/llvm/lib/Support/Signals.cpp:105:20
 #3 0x00005e945df8e539 SignalHandler(int) /proc/self/cwd/external/llvm-project/llvm/lib/Support/Unix/Signals.inc:413:1
 #4 0x00007c6034845320 (/lib/x86_64-linux-gnu/libc.so.6+0x45320)
 #5 0x00007c603489eb1c pthread_kill (/lib/x86_64-linux-gnu/libc.so.6+0x9eb1c)
 #6 0x00007c603484526e raise (/lib/x86_64-linux-gnu/libc.so.6+0x4526e)
 #7 0x00007c60348288ff abort (/lib/x86_64-linux-gnu/libc.so.6+0x288ff)
 #8 0x00007c603482881b (/lib/x86_64-linux-gnu/libc.so.6+0x2881b)
 #9 0x00007c603483b507 (/lib/x86_64-linux-gnu/libc.so.6+0x3b507)
#10 0x00005e945dc50c00 mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda'(mlir::Operation*)::operator()(mlir::Operation*) const /proc/self/cwd/external/llvm-project/mlir/lib/IR/PatternMatch.cpp:184:43
#11 0x00005e945dc510ef mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)::operator()(mlir::Operation*) const /proc/self/cwd/external/llvm-project/mlir/lib/IR/PatternMatch.cpp:228:3
#12 0x00005e945dc52928 void std::__invoke_impl<void, mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)&, mlir::Operation*>(std::__invoke_other, mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)&, mlir::Operation*&&) /usr/include/c++/13/bits/invoke.h:61:67
#13 0x00005e945dc527b1 std::enable_if<is_invocable_r_v<void, mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)&, mlir::Operation*>, void>::type std::__invoke_r<void, mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)&, mlir::Operation*>(mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)&, mlir::Operation*&&) /usr/include/c++/13/bits/invoke.h:117:5
#14 0x00005e945dc5261e std::_Function_handler<void (mlir::Operation*), mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)>::_M_invoke(std::_Any_data const&, mlir::Operation*&&) /usr/include/c++/13/bits/std_function.h:291:44
#15 0x00005e945d0f22fb std::function<void (mlir::Operation*)>::operator()(mlir::Operation*) const /usr/include/c++/13/bits/std_function.h:591:66
#16 0x00005e945dc50e8d mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)::operator()(mlir::Operation*) const /proc/self/cwd/external/llvm-project/mlir/lib/IR/PatternMatch.cpp:210:76
#17 0x00005e945dc52928 void std::__invoke_impl<void, mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)&, mlir::Operation*>(std::__invoke_other, mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)&, mlir::Operation*&&) /usr/include/c++/13/bits/invoke.h:61:67
#18 0x00005e945dc527b1 std::enable_if<is_invocable_r_v<void, mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)&, mlir::Operation*>, void>::type std::__invoke_r<void, mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)&, mlir::Operation*>(mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)&, mlir::Operation*&&) /usr/include/c++/13/bits/invoke.h:117:5
#19 0x00005e945dc5261e std::_Function_handler<void (mlir::Operation*), mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda0'(mlir::Operation*)>::_M_invoke(std::_Any_data const&, mlir::Operation*&&) /usr/include/c++/13/bits/std_function.h:291:44
#20 0x00005e945d0f22fb std::function<void (mlir::Operation*)>::operator()(mlir::Operation*) const /usr/include/c++/13/bits/std_function.h:591:66
#21 0x00005e945dc51280 mlir::RewriterBase::eraseOp(mlir::Operation*) /proc/self/cwd/external/llvm-project/mlir/lib/IR/PatternMatch.cpp:231:1
#22 0x00005e945cfb8523 mlir::heir::secret::RemoveUnusedGenericArgs::matchAndRewrite(mlir::heir::secret::GenericOp, mlir::PatternRewriter&) const /proc/self/cwd/lib/Dialect/Secret/IR/SecretPatterns.cpp:156:21
#23 0x00005e9459d9aa6e mlir::detail::OpOrInterfaceRewritePatternBase<mlir::heir::secret::GenericOp>::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/PatternMatch.h:332:3
#24 0x00005e945d64984f mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>)::'lambda'()::operator()() const /proc/self/cwd/external/llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp:212:46
#25 0x00005e945d64a2bc void llvm::function_ref<void ()>::callback_fn<mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>)::'lambda'()>(long) /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46:40
#26 0x00005e945a15b102 llvm::function_ref<void ()>::operator()() const /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:62
#27 0x00005e945d64d6c3 void mlir::MLIRContext::executeAction<mlir::ApplyPatternAction, mlir::Pattern const&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pattern const&) /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/MLIRContext.h:276:3
#28 0x00005e945d649fdb mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>) /proc/self/cwd/external/llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp:233:5
#29 0x00005e945d634353 (anonymous namespace)::GreedyPatternRewriteDriver::processWorklist() /proc/self/cwd/external/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:615:32
#30 0x00005e945d63554b (anonymous namespace)::RegionPatternRewriteDriver::simplify(bool*) &&::'lambda'()::operator()() const /proc/self/cwd/external/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:874:28
#31 0x00005e945d63689b void llvm::function_ref<void ()>::callback_fn<(anonymous namespace)::RegionPatternRewriteDriver::simplify(bool*) &&::'lambda'()>(long) /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46:40
#32 0x00005e945a15b102 llvm::function_ref<void ()>::operator()() const /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:62
#33 0x00005e945d636607 void mlir::MLIRContext::executeAction<(anonymous namespace)::GreedyPatternRewriteIteration, long&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, long&) /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/MLIRContext.h:276:3
#34 0x00005e945d635820 (anonymous namespace)::RegionPatternRewriteDriver::simplify(bool*) && /proc/self/cwd/external/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:887:3
#35 0x00005e945d6359e2 mlir::applyPatternsAndFoldGreedily(mlir::Region&, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) /proc/self/cwd/external/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:919:55
#36 0x00005e9459ae9bb5 mlir::applyPatternsAndFoldGreedily(mlir::Operation*, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) /proc/self/cwd/external/llvm-project/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h:159:37
#37 0x00005e945d5a9324 (anonymous namespace)::Canonicalizer::runOnOperation() /proc/self/cwd/external/llvm-project/mlir/lib/Transforms/Canonicalizer.cpp:63:37
#38 0x00005e945db0f5f9 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::'lambda'()::operator()() const /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:528:22
#39 0x00005e945db134a6 void llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::'lambda'()>(long) /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46:40
#40 0x00005e945a15b102 llvm::function_ref<void ()>::operator()() const /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:62
#41 0x00005e945db188cb void mlir::MLIRContext::executeAction<mlir::PassExecutionAction, mlir::Pass&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pass&) /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/MLIRContext.h:276:3
#42 0x00005e945db0fa21 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:533:23
#43 0x00005e945db0fd6f mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:593:15
#44 0x00005e945db11d38 mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:904:40
#45 0x00005e945db11b09 mlir::PassManager::run(mlir::Operation*) /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:884:69
#46 0x00005e945a13ae00 performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) /proc/self/cwd/external/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:413:13
#47 0x00005e945a13b616 processBuffer(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::MlirOptMainConfig const&, mlir::DialectRegistry&, llvm::ThreadPoolInterface*) /proc/self/cwd/external/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:478:26
#48 0x00005e945a13bd3a mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::'lambda'(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)::operator()(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) const /proc/self/cwd/external/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:561:25
#49 0x00005e945a13cfbe llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::'lambda'(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:52
#50 0x00005e945de6c02c llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::operator()(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) const /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
#51 0x00005e945de6b75f mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) /proc/self/cwd/external/llvm-project/mlir/lib/Support/ToolUtilities.cpp:27:30
#52 0x00005e945a13bed0 mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) /proc/self/cwd/external/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:564:31
#53 0x00005e945a13c1ea mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) /proc/self/cwd/external/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:605:13
#54 0x00005e945a13c497 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) /proc/self/cwd/external/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:621:21
#55 0x00005e9459aa5f17 main /proc/self/cwd/tools/heir-opt.cpp:677:26
#56 0x00007c603482a1ca (/lib/x86_64-linux-gnu/libc.so.6+0x2a1ca)
#57 0x00007c603482a28b __libc_start_main (/lib/x86_64-linux-gnu/libc.so.6+0x2a28b)
#58 0x00005e9459aa1375 _start (/home/ubuntu/.cache/bazel/_bazel_ubuntu/3672d07e31740680b9cbc78fe6ac6df6/execroot/heir/bazel-out/k8-dbg/bin/tools/heir-opt+0x48a375)
Aborted (core dumped)
asraa commented 2 months ago

Thank you so much for such a quick and thorough response! I was unable to attend the WG this morning unfortunately. Is it every other Wednesday at 5AM Pacific?

Ah, it's every other Thursday at 8 AM Pacific: https://heir.dev/community/!

I get similar output if I replace tosa.reshape with tensor.reshape. (Here are some code search results for future readers to consult.) Apologies if there's an obvious mistake I'm making here.

Oh, no apology needed :) I was able to repro the issue, and I think it's an issue in the secret.generic canonicalization patterns @j2kun (FYI, but I'll help debug).

This is the IR before --canonicalize runs:

// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
module {
  func.func @main(%arg0: !secret.secret<tensor<28x28xf32>>, %arg1: !secret.secret<tensor<784x10xf32>>, %arg2: !secret.secret<tensor<1x10xf32>>) -> !secret.secret<tensor<1x10xf32>> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = secret.generic ins(%arg0, %arg1, %arg2 : !secret.secret<tensor<28x28xf32>>, !secret.secret<tensor<784x10xf32>>, !secret.secret<tensor<1x10xf32>>) {
    ^bb0(%arg3: tensor<28x28xf32>, %arg4: tensor<784x10xf32>, %arg5: tensor<1x10xf32>):
      %1 = tosa.reshape %arg3 {new_shape = array<i64: 1, 1, 784>} : (tensor<28x28xf32>) -> tensor<1x1x784xf32>
      %2 = tosa.reshape %arg4 {new_shape = array<i64: 1, 784, 10>} : (tensor<784x10xf32>) -> tensor<1x784x10xf32>
      %3 = tensor.empty() : tensor<1x1x10xf32>
      %4 = affine.for %arg6 = 0 to 10 iter_args(%arg7 = %3) -> (tensor<1x1x10xf32>) {
        %inserted = tensor.insert %cst into %arg7[%c0, %c0, %arg6] : tensor<1x1x10xf32>
        affine.yield %inserted : tensor<1x1x10xf32>
      }
      %5 = affine.for %arg6 = 0 to 10 iter_args(%arg7 = %4) -> (tensor<1x1x10xf32>) {
        %9 = affine.for %arg8 = 0 to 784 iter_args(%arg9 = %arg7) -> (tensor<1x1x10xf32>) {
          %extracted = tensor.extract %1[%c0, %c0, %arg8] : tensor<1x1x784xf32>
          %extracted_0 = tensor.extract %2[%c0, %arg8, %arg6] : tensor<1x784x10xf32>
          %extracted_1 = tensor.extract %4[%c0, %c0, %arg6] : tensor<1x1x10xf32>
          %10 = arith.mulf %extracted, %extracted_0 : f32
          %11 = arith.addf %extracted_1, %10 : f32
          %inserted = tensor.insert %11 into %arg9[%c0, %c0, %arg6] : tensor<1x1x10xf32>
          affine.yield %inserted : tensor<1x1x10xf32>
        }
        affine.yield %9 : tensor<1x1x10xf32>
      }
      %6 = tosa.reshape %5 {new_shape = array<i64: 1, 10>} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
      %7 = tensor.empty() : tensor<1x10xf32>
      %8 = affine.for %arg6 = 0 to 10 iter_args(%arg7 = %7) -> (tensor<1x10xf32>) {
        %extracted = tensor.extract %6[%c0, %arg6] : tensor<1x10xf32>
        %extracted_0 = tensor.extract %arg5[%c0, %arg6] : tensor<1x10xf32>
        %9 = arith.addf %extracted, %extracted_0 : f32
        %10 = arith.maximumf %9, %cst : f32
        %inserted = tensor.insert %10 into %arg7[%c0, %arg6] : tensor<1x10xf32>
        affine.yield %inserted : tensor<1x10xf32>
      }
      secret.yield %8 : tensor<1x10xf32>
    } -> !secret.secret<tensor<1x10xf32>>
    return %0 : !secret.secret<tensor<1x10xf32>>
  }
}

and when I run in debug mode I see the secret RemoveUnusedGenericArgs pattern causing the issue (I've omitted some lines from the stack trace):

  * Pattern mlir::heir::secret::RemoveUnusedGenericArgs : 'secret.generic -> ()' {
Trying to match "mlir::heir::secret::RemoveUnusedGenericArgs"
<block argument> of type 'tensor<1x10xf32>' at index: 2 has no uses; removing
    ** Modified: 'secret.generic'(0x564ae3005870)
<block argument> of type 'tensor<1x10xf32>' at index: 3 is passed through to yield
    ** Modified: 'func.return'(0x564ae3025270)
    ** Insert  : 'tosa.reshape'(0x564ae3090540)
    ** Insert  : 'tosa.reshape'(0x564ae3024de0)
    ** Insert  : 'affine.for'(0x564ae3024f90)
    ** Insert  : 'tensor.insert'(0x564ae2fc0650)
    ** Insert  : 'affine.yield'(0x564ae30905c0)
    ** Insert  : 'affine.for'(0x564ae3024ec0)
    ** Insert  : 'affine.for'(0x564ae3025110)
    ** Insert  : 'tensor.extract'(0x564ae30922d0)
    ** Insert  : 'tensor.extract'(0x564ae2fc06c0)
    ** Insert  : 'tensor.extract'(0x564ae3090700)
    ** Insert  : 'arith.mulf'(0x564ae3090770)
    ** Insert  : 'arith.addf'(0x564ae30907e0)
    ** Insert  : 'tensor.insert'(0x564ae3090850)
    ** Insert  : 'affine.yield'(0x564ae30908b0)
    ** Insert  : 'affine.yield'(0x564ae3087db0)
    ** Insert  : 'tosa.reshape'(0x564ae3090d10)
    ** Insert  : 'secret.yield'(0x564ae3090d90)
    ** Insert  : 'secret.generic'(0x564ae3037cf0)
    ** Erase   : 'secret.yield'(0x564ae307d680)
heir-opt: external/llvm-project/mlir/lib/IR/PatternMatch.cpp:181: mlir::RewriterBase::eraseOp(mlir::Operation*)::<lambda(mlir::Operation*)>: Assertion `mayBeGraphRegion(*op->getParentRegion()) && "expected that op has no uses"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.  Program arguments: /usr/local/google/home/asraa/.cache/bazel/_bazel_asraa/dd6a35b586d0a2d65a5a8939e7a649cf/execroot/heir/bazel-out/k8-dbg/bin/tools/heir-opt --mlir-to-openfhe-bgv= /usr/local/google/home/asraa/git/heir/tests/test.mlir --mlir-print-ir-before-all --debug
 #0 0x0000564aad00116c llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /proc/self/cwd/external/llvm-project/llvm/lib/Support/Unix/Signals.inc:723:22mlir::RewriterBase::eraseOp(mlir::Operation*)::'lambda'(mlir::Operation*)::operator()(mlir::Operation*) const /proc/self/cwd/external/llvm-
...
#21 0x0000564aacce424a mlir::RewriterBase::eraseOp(mlir::Operation*) /proc/self/cwd/external/llvm-project/mlir/lib/IR/PatternMatch.cpp:231:1
#22 0x0000564aac0bf637 mlir::heir::secret::RemoveUnusedGenericArgs::matchAndRewrite(mlir::heir::secret::GenericOp, mlir::PatternRewriter&) const /proc/self/cwd/lib/Dialect/Secret/IR/SecretPatterns.cpp:156:21
#23 0x0000564aa9006c4c mlir::detail::OpOrInterfaceRewritePatternBase<mlir::heir::secret::GenericOp>::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/PatternMatch.h:332:3
johnmatter commented 2 months ago

Ah, it's every other Thursday at 8 AM Pacific: https://heir.dev/community/!

Ohhh I see now that it says Pacific at the bottom of the calendar. I had assumed it showed localized timezones—oops! (I'm on the east coast.) I'll be at the next one

johnmatter commented 1 month ago

Hi @asraa. Has your internal linalg/affine code been merged yet?

I have a simpler matmul example at https://github.com/johnmatter/simplemlir that I put together for the sake of reproducibility. It still starts with --heir-tosa-to-arith, however.

I don't necessarily need optimized SIMD operations, for what it's worth. For my purposes, I just need a fairly general end-to-end TOSA to C++ pipeline that compiles. I'm still at proof-of-concept :)

johnmatter commented 1 month ago

Actually, mlir-opt seems to get me from TOSA to affine+arith+tensor. The script lower_mlir-opt.sh in my repo does that and passes the result to heir-opt. Here's the secretized and wrap-generic'ed output before --mlir-to-openfhe-bgv:

module {
  func.func @main(%arg0: !secret.secret<tensor<2x2xi32>>, %arg1: !secret.secret<tensor<2x2xi32>>) -> !secret.secret<tensor<2x2xi32>> {
    %c0_i32 = arith.constant 0 : i32
    %0 = secret.generic ins(%arg0, %arg1 : !secret.secret<tensor<2x2xi32>>, !secret.secret<tensor<2x2xi32>>) {
    ^bb0(%arg2: tensor<2x2xi32>, %arg3: tensor<2x2xi32>):
      %expanded = tensor.expand_shape %arg2 [[0, 1], [2]] output_shape [1, 2, 2] : tensor<2x2xi32> into tensor<1x2x2xi32>
      %expanded_0 = tensor.expand_shape %arg3 [[0, 1], [2]] output_shape [1, 2, 2] : tensor<2x2xi32> into tensor<1x2x2xi32>
      %1 = tensor.empty() : tensor<1x2x2xi32>
      %2 = linalg.fill ins(%c0_i32 : i32) outs(%1 : tensor<1x2x2xi32>) -> tensor<1x2x2xi32>
      %3 = linalg.batch_matmul ins(%expanded, %expanded_0 : tensor<1x2x2xi32>, tensor<1x2x2xi32>) outs(%2 : tensor<1x2x2xi32>) -> tensor<1x2x2xi32>
      %collapsed = tensor.collapse_shape %3 [[0, 1], [2]] : tensor<1x2x2xi32> into tensor<2x2xi32>
      secret.yield %collapsed : tensor<2x2xi32>
    } -> !secret.secret<tensor<2x2xi32>>
    return %0 : !secret.secret<tensor<2x2xi32>>
  }
}

The openfhe step fails with: error: expected batched secret types to be tensors with dimension matching ring parameter

Is that because I need to flatten/pad the matrices to be 1xN where N is a power of two?
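
In case it helps future readers, here is a minimal sketch of the flatten/pad idea from the question: flatten a matrix and zero-pad its length to the next power of two so it could match a power-of-two ciphertext degree. (Whether this is exactly what the pass expects is the open question above; the function names are made up for illustration.)

```python
def next_pow2(n):
    """Smallest power of two >= n."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def flatten_and_pad(matrix):
    """Flatten a 2-D list and zero-pad to a power-of-two length."""
    flat = [x for row in matrix for x in row]
    return flat + [0] * (next_pow2(len(flat)) - len(flat))

print(flatten_and_pad([[1, 2], [3, 4]]))        # length 4: already a power of two
print(flatten_and_pad([[1, 2, 3], [4, 5, 6]]))  # length 6: padded to 8
```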