google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

SecretDistributeGeneric Bug on affine.for and scf.for operations #784

Closed: MeronZerihun closed this issue 1 month ago

MeronZerihun commented 1 month ago

This code snippet is taken from tests/openfhe/end_to_end/simple_sum.mlir:

func.func @simple_sum(%arg0: tensor<32xi16>) -> i16 {
  %c0 = arith.constant 0 : index
  %c0_si16 = arith.constant 0 : i16
  %0 = affine.for %i = 0 to 32 iter_args(%sum_iter = %c0_si16) -> i16 {
    %1 = tensor.extract %arg0[%i] : tensor<32xi16>
    %2 = arith.addi %1, %sum_iter : i16
    affine.yield %2 : i16
  }
  return %0 : i16
}

Running the following command crashes, and I'm not sure why:

$ heir-opt --secretize=entry-function=simple_sum --wrap-generic --secret-distribute-generic tests/openfhe/end_to_end/simple_sum.mlir 

Replacing the affine.for with scf.for produces a return-type mismatch error that looks like the following:

tests/openfhe/end_to_end/simple_sum.mlir:15:8: error: 'scf.for' op 0-th region iter_arg and 0-th yielded value have different type: 'i16' != '!secret.secret<i16>'
  %0 = scf.for %i = %lower to %upper step %c1 iter_args(%sum_iter = %c0_si16) -> i16 {
       ^
tests/openfhe/end_to_end/simple_sum.mlir:15:8: note: see current operation: 
%4 = "scf.for"(%3, %1, %2, %0) ({
^bb0(%arg1: index, %arg2: i16):
  %5 = "secret.generic"(%arg0) ({
  ^bb0(%arg4: tensor<32xi16>):
    %8 = "tensor.extract"(%arg4, %arg1) : (tensor<32xi16>, index) -> i16
    "secret.yield"(%8) : (i16) -> ()
  }) : (!secret.secret<tensor<32xi16>>) -> !secret.secret<i16>
  %6 = "secret.generic"(%5) ({
  ^bb0(%arg3: i16):
    %7 = "arith.addi"(%arg3, %arg2) <{overflowFlags = #arith.overflow<none>}> : (i16, i16) -> i16
    "secret.yield"(%7) : (i16) -> ()
  }) : (!secret.secret<i16>) -> !secret.secret<i16>
  "scf.yield"(%6) : (!secret.secret<i16>) -> ()
}) : (index, index, index, i16) -> !secret.secret<i16>

Any idea why? @j2kun @asraa

AlexanderViand-Intel commented 1 month ago

Note: the crash in the first case is due to this assert in DistributeGeneric.cpp:

assert(isa<BlockArgument>(value) &&  "not sure what to do here, file a bug");

See https://github.com/google/heir/blob/main/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp#L283-L285

asraa commented 1 month ago

> Replacing the affine.for with scf.for produces a return-type mismatch error that looks like the following:

I think this is because the secret.generic returns a secret type, but the iter_args value is a plain integer. It seems like we should upgrade the initial iter_args value by secret.cast'ing it into a secret type.

Let me investigate the first crash!

asraa commented 1 month ago

The same sort of fix is likely needed for the first case too; let me see if I can tinker with it!

j2kun commented 1 month ago

Well at least I left a nice error message 😎