google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

Distributing `secret.generic` across a loop-carried dependence #948

Closed: raghav198 closed this 6 days ago

raghav198 commented 2 weeks ago

The `--secret-distribute-generic` pass fails when distributing a `secret.generic` across an `scf.for` loop that has a loop-carried dependence (an `iter_args` value). For example:

func.func @sum(%arr: !secret.secret<memref<8xi8>>) -> !secret.secret<i8> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c8 = arith.constant 8 : index
    %0 = secret.generic ins(%arr: !secret.secret<memref<8xi8>>) {
        ^bb0(%ARR: memref<8xi8>):
            %sum_0 = arith.constant 0 : i8
            %sum = scf.for %i = %c0 to %c8 step %c1
                    iter_args(%iter = %sum_0) -> i8 {
                %cur = memref.load %ARR[%i] : memref<8xi8>
                %cur_sum = arith.addi %iter, %cur : i8
                scf.yield %cur_sum : i8
            }
            secret.yield %sum : i8
    } -> !secret.secret<i8>
    return %0 : !secret.secret<i8>
}

Distributing should produce an `scf.for` loop that carries a `!secret.secret<i8>` from iteration to iteration, but instead it fails with this error message:

error: 'scf.for' op 0-th region iter_arg and 0-th yielded value have different type: 'i8' != '!secret.secret<i8>'

asraa commented 2 weeks ago

Thanks for filing! I'll take a look. I thought I had handled `scf.for` iter_args here: https://github.com/google/heir/pull/792

But there must be something it's not happy with. Thanks for posting the test case, I'll use that.

asraa commented 6 days ago

Hey! I just managed to test this out. Do you know which commit your HEIR build is at? I'm having trouble reproducing it. This is what I get on the IR at head:

$ bazel run -c dbg @//heir/tools:heir-opt -- --secret-distribute-generic $(pwd)/third_party/heir/tests/secret/distribute_loop_vars.mlir
module {
  func.func @sum(%arg0: !secret.secret<memref<8xi8>>) -> !secret.secret<i8> {
    %c0_i8 = arith.constant 0 : i8
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c8 = arith.constant 8 : index
    %0 = secret.conceal %c0_i8 : i8 -> <i8>
    %1 = scf.for %arg1 = %c0 to %c8 step %c1 iter_args(%arg2 = %0) -> (!secret.secret<i8>) {
      %2 = secret.generic ins(%arg0 : !secret.secret<memref<8xi8>>) {
      ^bb0(%arg3: memref<8xi8>):
        %4 = memref.load %arg3[%arg1] : memref<8xi8>
        secret.yield %4 : i8
      } -> !secret.secret<i8>
      %3 = secret.generic ins(%arg2, %2 : !secret.secret<i8>, !secret.secret<i8>) {
      ^bb0(%arg3: i8, %arg4: i8):
        %4 = arith.addi %arg3, %arg4 : i8
        secret.yield %4 : i8
      } -> !secret.secret<i8>
      scf.yield %3 : !secret.secret<i8>
    }
    return %1 : !secret.secret<i8>
  }
}

raghav198 commented 6 days ago

Ahh, I've been working off my fork. I think the most recent commit I pulled was 6b8074e, but it's possible something else I did broke it...

asraa commented 6 days ago

Gotcha! I see. That commit is behind the fix PR I linked above, so I think if you rebase your fork you'll be good :)

raghav198 commented 6 days ago

Ahh, great, thanks!!