google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

heir-simd-vectorizer: dot product example for ckks incorrectly transformed #1115

Open ZenithalHourlyRate opened 4 days ago

ZenithalHourlyRate commented 4 days ago

I tried to migrate the BGV dot product example to CKKS, and the result is incorrect. After inspection, the transformation performed by heir-simd-vectorizer appears to be incorrect.

The input to the pipeline is below. It is the BGV example with every i16 replaced by f16.

func.func @dot_product(%arg0: tensor<8xf16>, %arg1: tensor<8xf16>) -> f16 {
  %c0 = arith.constant 0 : index
  %c0_sf16 = arith.constant 0.0 : f16
  %0 = affine.for %arg2 = 0 to 8 iter_args(%iter = %c0_sf16) -> (f16) {
    %1 = tensor.extract %arg0[%arg2] : tensor<8xf16>
    %2 = tensor.extract %arg1[%arg2] : tensor<8xf16>
    %3 = arith.mulf %1, %2 : f16
    %4 = arith.addf %iter, %3 : f16
    affine.yield %4 : f16
  }
  return %0 : f16
}
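
For reference, the scalar loop above is a fold over the eight elements, with the running sum carried in iter_args. The minimal C++ sketch below models its semantics. It uses double as a stand-in for f16; the values in this issue are small integers, which both types represent exactly. The function name referenceDotProduct and its init parameter are illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Scalar semantics of the affine.for loop: %iter starts at `init` (0.0 here)
// and each iteration adds a[i] * b[i]. double stands in for f16 (assumption).
double referenceDotProduct(const std::vector<double>& a,
                           const std::vector<double>& b, double init = 0.0) {
  assert(a.size() == b.size());
  double acc = init;
  for (std::size_t i = 0; i < a.size(); ++i) {
    acc += a[i] * b[i];  // %3 = arith.mulf, then %4 = arith.addf %iter, %3
  }
  return acc;
}
```

Suppose both arguments are the input used later in this issue, (1, 2, 3, 4, 1, 2, 3, 4). Then referenceDotProduct(x, x) evaluates to 60, which is the "Expected" value in the trace.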

After running --mlir-to-secret-arithmetic="entry-function=dot_product", we get:

module {
  func.func @dot_product(%arg0: !secret.secret<tensor<8xf16>>, %arg1: !secret.secret<tensor<8xf16>>) -> !secret.secret<f16> {
    %c6 = arith.constant 6 : index
    %cst = arith.constant dense<0.000000e+00> : tensor<8xf16>
    %c7 = arith.constant 7 : index
    %0 = secret.generic ins(%arg0, %arg1 : !secret.secret<tensor<8xf16>>, !secret.secret<tensor<8xf16>>) {
    ^bb0(%arg2: tensor<8xf16>, %arg3: tensor<8xf16>):
      %1 = arith.mulf %arg2, %arg3 : tensor<8xf16>
      %2 = arith.addf %1, %cst : tensor<8xf16>
      %3 = tensor_ext.rotate %2, %c6 : tensor<8xf16>, index
      %4 = tensor_ext.rotate %1, %c7 : tensor<8xf16>, index
      %5 = arith.addf %3, %4 : tensor<8xf16>
      %6 = arith.addf %5, %1 : tensor<8xf16>
      %7 = tensor_ext.rotate %6, %c6 : tensor<8xf16>, index
      %8 = arith.addf %7, %4 : tensor<8xf16>
      %9 = arith.addf %8, %1 : tensor<8xf16>
      %10 = tensor_ext.rotate %9, %c6 : tensor<8xf16>, index
      %11 = arith.addf %10, %4 : tensor<8xf16>
      %12 = arith.addf %11, %1 : tensor<8xf16>
      %13 = tensor_ext.rotate %12, %c7 : tensor<8xf16>, index
      %14 = arith.addf %13, %1 : tensor<8xf16>
      %extracted = tensor.extract %14[%c7] : tensor<8xf16>
      secret.yield %extracted : f16
    } -> !secret.secret<f16>
    return %0 : !secret.secret<f16>
  }
}

This is clearly different from the BGV dot product output, where the rotate-and-reduce pattern works:

module {
  func.func @dot_product(%arg0: !secret.secret<tensor<8xi16>>, %arg1: !secret.secret<tensor<8xi16>>) -> !secret.secret<i16> {
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c4 = arith.constant 4 : index
    %c7 = arith.constant 7 : index
    %0 = secret.generic ins(%arg0, %arg1 : !secret.secret<tensor<8xi16>>, !secret.secret<tensor<8xi16>>) {
    ^bb0(%arg2: tensor<8xi16>, %arg3: tensor<8xi16>):
      %1 = arith.muli %arg2, %arg3 : tensor<8xi16>
      %2 = tensor_ext.rotate %1, %c4 : tensor<8xi16>, index
      %3 = arith.addi %1, %2 : tensor<8xi16>
      %4 = tensor_ext.rotate %3, %c2 : tensor<8xi16>, index
      %5 = arith.addi %3, %4 : tensor<8xi16>
      %6 = tensor_ext.rotate %5, %c1 : tensor<8xi16>, index
      %7 = arith.addi %5, %6 : tensor<8xi16>
      %extracted = tensor.extract %7[%c7] : tensor<8xi16>
      secret.yield %extracted : i16
    } -> !secret.secret<i16>
    return %0 : !secret.secret<i16>
  }
}
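
Both programs can be checked on plaintext by modeling a tensor<8xf16> or tensor<8xi16> as eight slots. The sketch below assumes that tensor_ext.rotate %x, %k is a cyclic left rotation, that is, result[i] = x[(i + k) mod 8]. This agrees with the OpenFHE trace below, where rotating by 6 moves slot 6 to slot 0. The helper names are hypothetical.

Under this model, the CKKS program uses five rotations and the BGV program uses three. Both still sum all eight products into slot 7.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Plaintext model of a packed tensor<8x...>: one double per slot.
using Slots = std::vector<double>;

// Assumed semantics of tensor_ext.rotate: cyclic left rotation by k slots.
Slots rotateLeft(const Slots& x, std::size_t k) {
  const std::size_t n = x.size();
  Slots r(n);
  for (std::size_t i = 0; i < n; ++i) r[i] = x[(i + k) % n];
  return r;
}

Slots addSlots(const Slots& a, const Slots& b) {
  assert(a.size() == b.size());
  Slots r(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) r[i] = a[i] + b[i];
  return r;
}

Slots mulSlots(const Slots& a, const Slots& b) {
  assert(a.size() == b.size());
  Slots r(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) r[i] = a[i] * b[i];
  return r;
}

// CKKS program above (5 rotations), transcribed op by op.
double ckksProgram(const Slots& a, const Slots& b) {
  Slots v1 = mulSlots(a, b);
  Slots v2 = addSlots(v1, Slots(v1.size(), 0.0));  // %2 = addf %1, %cst (zero)
  Slots v4 = rotateLeft(v1, 7);
  Slots v6 = addSlots(addSlots(rotateLeft(v2, 6), v4), v1);
  Slots v9 = addSlots(addSlots(rotateLeft(v6, 6), v4), v1);
  Slots v12 = addSlots(addSlots(rotateLeft(v9, 6), v4), v1);
  Slots v14 = addSlots(rotateLeft(v12, 7), v1);
  return v14[7];  // tensor.extract %14[%c7]
}

// BGV program above (rotate-and-reduce, 3 rotations), transcribed op by op.
double bgvProgram(const Slots& a, const Slots& b) {
  Slots v1 = mulSlots(a, b);
  Slots v3 = addSlots(v1, rotateLeft(v1, 4));
  Slots v5 = addSlots(v3, rotateLeft(v3, 2));
  Slots v7 = addSlots(v5, rotateLeft(v5, 1));
  return v7[7];  // tensor.extract %7[%c7]
}
```

For x = (1, 2, 3, 4, 1, 2, 3, 4), both ckksProgram(x, x) and bgvProgram(x, x) return 60 in this 8-slot model. So in this model, the CKKS output is only less efficient, not wrong.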

If we execute the emitted CKKS code with input (1, 2, 3, 4, 1, 2, 3, 4), we get an incorrect result, with a trace like this:

v5=EvalMultNoRelin v1, v2
result decrypted: (1, 4, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v6=Relinearize v5
result decrypted: (1, 4, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v8=EvalAdd v6, v7
result decrypted: (1, 4, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v9=EvalRotate v8, 6
result decrypted: (9, 16, -2.35721e-15, 1.85306e-15, 3.1336e-15, 5.86141e-16, -2.08615e-15, 2.21339e-15,  ... ); Estimated precision: 47 bits

v10=EvalRotate v6, 7
result decrypted: (16, -5.25744e-15, 6.75882e-15, 5.41846e-16, 7.02148e-15, -4.51887e-15, 4.41042e-15, -3.54915e-15,  ... ); Estimated precision: 47 bits

v11=EvalAdd v9, v10
result decrypted: (25, 16, 3.80997e-16, 3.49924e-15, -1.00012e-16, -6.83104e-15, -1.55568e-15, -3.20682e-15,  ... ); Estimated precision: 47 bits

v12=EvalAdd v11, v6
result decrypted: (26, 20, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v13=EvalRotate v12, 6
result decrypted: (9, 16, -6.82527e-15, -7.59157e-16, 1.27657e-15, -1.48354e-15, 1.24327e-15, 9.49766e-16,  ... ); Estimated precision: 47 bits

v14=EvalAdd v13, v10
result decrypted: (25, 16, 3.2116e-15, 1.26548e-15, -4.74639e-15, -5.8531e-15, 1.28463e-15, 6.89365e-15,  ... ); Estimated precision: 47 bits

v15=EvalAdd v14, v6
result decrypted: (26, 20, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v16=EvalRotate v15, 6
result decrypted: (9, 16, -3.58021e-15, 3.47007e-15, -2.34613e-15, -1.81569e-15, -5.64447e-16, -2.51926e-17,  ... ); Estimated precision: 47 bits

v17=EvalAdd v16, v10
result decrypted: (25, 16, 3.26906e-15, -8.31767e-16, 2.92677e-15, -2.5431e-15, -1.33646e-15, -2.1228e-15,  ... ); Estimated precision: 47 bits

v18=EvalAdd v17, v6
result decrypted: (26, 20, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v19=EvalRotate v18, 7
result decrypted: (16, 3.52226e-16, 2.59192e-15, 1.11261e-15, -3.50525e-15, 2.17398e-15, -4.26154e-15, -8.52218e-16,  ... ); Estimated precision: 47 bits

v20=EvalAdd v19, v6
result decrypted: (17, 4, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v22=EvalMult v20, v21
result decrypted: (-4.38183e-16, -2.34426e-15, 3.01672e-15, 2.99825e-15, -2.78687e-15, -1.29908e-15, -3.58669e-15, 16,  ... ); Estimated precision: 47 bits

v23=EvalRotate v22, 7
result decrypted: (16, -1.8405e-15, -3.56376e-15, -2.69961e-15, -4.74316e-15, -4.06186e-15, 5.39773e-15, -3.67044e-15,  ... ); Estimated precision: 47 bits

Expected: 60
Actual: (16,  ... ); Estimated precision: 47 bits

j2kun commented 4 days ago

How concerning! I don't immediately see why the rotate-and-reduce pass isn't properly handling the mulf/addf ops. Maybe you could run with --mlir-print-ir-after-all and --mlir-print-ir-tree-dir to print out the IR after each pass in the pipeline, and then you could compare the bgv/ckks versions to see the first pass at which they meaningfully differ.

I suspect the use of addf/mulf is triggering some incorrect match that causes some pattern or pass not to be applied. However, that shouldn't (absent bugs) cause the program to produce an incorrect result; it should just produce a less efficient program.

Looking for other reasons the output might be incorrect, I see this line in the trace above, after the rotation by 6:

result decrypted: (9, 16, -2.35721e-15, 1.85306e-15, 3.1336e-15, 5.86141e-16, -2.08615e-15, 2.21339e-15,  ... )

I believe those zero values should be nonzero. In particular, I recall the encoding used by the simd-vectorizer passes expects 1D tensors to be repeated to fill up the available ciphertext space, since the rotations analyzed are cyclic mod 8 (in your example, because it's a tensor<8xf16>) but the openfhe backend uses larger ciphertext sizes. Cf. https://github.com/google/heir/blob/51ebc8dd6eb2d8c431ee20d25a0fb5ec016aa550/tests/Examples/openfhe/simple_sum_test.cpp#L34-L42 and https://github.com/google/heir/issues/645

Could that be causing the incorrectness?
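
This hypothesis can be checked without OpenFHE. The sketch runs the CKKS program on a plaintext model that has more slots than the tensor. It makes three assumptions:

- The model uses 64 slots as a small stand-in for ringDimension / 2.
- Rotation is a cyclic left rotation over all slots, as in the trace.
- The inputs are packed either zero-padded or cyclically repeated.

The helper names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Slots = std::vector<double>;

// Cyclic left rotation over every slot of the (modeled) ciphertext.
Slots rotateLeft(const Slots& x, std::size_t k) {
  const std::size_t n = x.size();
  Slots r(n);
  for (std::size_t i = 0; i < n; ++i) r[i] = x[(i + k) % n];
  return r;
}

Slots addSlots(const Slots& a, const Slots& b) {
  assert(a.size() == b.size());
  Slots r(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) r[i] = a[i] + b[i];
  return r;
}

Slots mulSlots(const Slots& a, const Slots& b) {
  assert(a.size() == b.size());
  Slots r(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) r[i] = a[i] * b[i];
  return r;
}

// Pack `x` into numSlots slots: zero-padded (as in the failing harness) or
// repeated cyclically (as the simd-vectorizer passes expect).
Slots encode(const Slots& x, std::size_t numSlots, bool replicate) {
  Slots out(numSlots, 0.0);
  for (std::size_t i = 0; i < numSlots; ++i) {
    if (replicate || i < x.size()) out[i] = x[i % x.size()];
  }
  return out;
}

// The CKKS program from this issue. Adding the all-zero %cst is omitted
// because it does not change any slot.
double ckksProgram(const Slots& a, const Slots& b) {
  Slots v1 = mulSlots(a, b);
  Slots v4 = rotateLeft(v1, 7);
  Slots v6 = addSlots(addSlots(rotateLeft(v1, 6), v4), v1);
  Slots v9 = addSlots(addSlots(rotateLeft(v6, 6), v4), v1);
  Slots v12 = addSlots(addSlots(rotateLeft(v9, 6), v4), v1);
  Slots v14 = addSlots(rotateLeft(v12, 7), v1);
  return v14[7];  // tensor.extract %14[%c7]
}
```

With 64 slots, ckksProgram on the zero-padded encoding returns 16, and on the cyclically repeated encoding it returns 60. These match the "Actual" and "Expected" values in the trace.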

ZenithalHourlyRate commented 4 days ago

I recall the encoding used by the simd-vectorizer passes expects 1D tensors to be repeated to fill up the available ciphertext space

Oh, that is exactly the reason: after copying the cyclic-filling code from tests/Examples, the result is correct.

  std::vector<double> x1 = {1, 2, 3, 4, 1, 2, 3, 4};

  // CKKS packs ringDimension / 2 slots per ciphertext. Repeat x1 cyclically
  // across all of them so rotations behave as if the tensor had 8 slots.
  int32_t n =
      cryptoContext->GetCryptoParameters()->GetElementParams()->GetRingDimension() / 2;
  std::vector<double> packed;
  packed.reserve(n);
  for (int i = 0; i < n; ++i) {
    packed.push_back(x1[i % x1.size()]);
  }
  const auto& ptxt1 = cryptoContext->MakeCKKSPackedPlaintext(packed);
  const auto& c1 = cryptoContext->Encrypt(keyPair.publicKey, ptxt1);

ZenithalHourlyRate commented 4 days ago

run with --mlir-print-ir-after-all and --mlir-print-ir-tree-dir

The difference is that the first apply-folders run after full-loop-unroll eliminates the constant 0 (i.e. %c0_si16 = arith.constant 0 : i16) for BGV, since x + 0 -> x is an equivalent transformation, but it does not eliminate the constant 0.0 for CKKS. As a result, rotate-and-reduce cannot recognize the pattern (I thought I have supported this pattern in past PRs?).

If I manually delete the constant after full-loop-unroll and run the remaining passes again, the correct rotate-and-reduce version comes out.

ZenithalHourlyRate commented 4 days ago

I thought I have supported this pattern in past PRs?

The current pattern is

%cst = arith.constant dense<0.000000e+00> : tensor<8xf16>
%1 = arith.mulf %arg2, %arg3 : tensor<8xf16>
%2 = arith.addf %1, %cst : tensor<8xf16>
%3 = tensor_ext.rotate %2, %c6 : tensor<8xf16>, index
%4 = tensor_ext.rotate %1, %c7 : tensor<8xf16>, index
// both %1 and %2 are used afterwards (mixed uses)

This means %3 and %4 do not have the same root in the rotation analysis.

The constant tensor could be saved so that %3 and %4 share the same root, as mentioned in #522, but this is hard to handle because the constant tensor gets rotated later. I avoided handling saved tensors in earlier PRs. For this specific case, I think apply-folders should eliminate the constant tensor.

https://github.com/google/heir/blob/51ebc8dd6eb2d8c431ee20d25a0fb5ec016aa550/lib/Analysis/RotationAnalysis/RotationAnalysis.h#L218-L224

Full IR below

module {
  func.func @dot_product(%arg0: !secret.secret<tensor<8xf16>>, %arg1: !secret.secret<tensor<8xf16>>) -> !secret.secret<f16> {
    %c7 = arith.constant 7 : index
    %cst = arith.constant dense<0.000000e+00> : tensor<8xf16>
    %c6 = arith.constant 6 : index
    %0 = secret.generic ins(%arg0, %arg1 : !secret.secret<tensor<8xf16>>, !secret.secret<tensor<8xf16>>) {
    ^bb0(%arg2: tensor<8xf16>, %arg3: tensor<8xf16>):
      %1 = arith.mulf %arg2, %arg3 : tensor<8xf16>
      %2 = arith.addf %1, %cst : tensor<8xf16>
      %3 = tensor_ext.rotate %2, %c6 : tensor<8xf16>, index
      %4 = tensor_ext.rotate %1, %c7 : tensor<8xf16>, index
      %5 = arith.addf %3, %4 : tensor<8xf16>
      %6 = arith.addf %5, %1 : tensor<8xf16>
      %7 = tensor_ext.rotate %6, %c6 : tensor<8xf16>, index
      %8 = arith.addf %7, %4 : tensor<8xf16>
      %9 = arith.addf %8, %1 : tensor<8xf16>
      %10 = tensor_ext.rotate %9, %c6 : tensor<8xf16>, index
      %11 = arith.addf %10, %4 : tensor<8xf16>
      %12 = arith.addf %11, %1 : tensor<8xf16>
      %13 = tensor_ext.rotate %12, %c7 : tensor<8xf16>, index
      %14 = arith.addf %13, %1 : tensor<8xf16>
      %extracted = tensor.extract %14[%c7] : tensor<8xf16>
      secret.yield %extracted : f16
    } -> !secret.secret<f16>
    return %0 : !secret.secret<f16>
  }
}
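
The "same root" condition can be made concrete with a toy model. This is NOT HEIR's RotationAnalysis (linked above), and all names are hypothetical. The model tracks each value as a sum of rotated copies of a single root value and refuses to merge values whose roots differ.

If the unfolded %2 = %1 + %cst is treated as its own root, the chain breaks at %5 = %3 + %4. If %2 is replaced by %1, which is what folding the zero (or deleting it by hand) achieves, the same IR reaches a full reduction over all eight offsets.

```cpp
#include <cassert>
#include <optional>
#include <set>
#include <string>

// Toy model, NOT HEIR's RotationAnalysis: the value satisfies
//   value[i] = sum over o in offsets of root[(i + o) mod kSlots].
struct PartialReduction {
  std::string root;
  std::set<int> offsets;
};
using Maybe = std::optional<PartialReduction>;  // nullopt = chain broken

constexpr int kSlots = 8;  // tensor<8xf16>

Maybe leaf(const std::string& root) { return PartialReduction{root, {0}}; }

// rotate(v, k)[i] = v[(i + k) mod n], so every offset shifts by k.
Maybe rotateBy(const Maybe& v, int k) {
  if (!v) return std::nullopt;
  PartialReduction r{v->root, {}};
  for (int o : v->offsets) r.offsets.insert((o + k) % kSlots);
  return r;
}

// Merge only if both sides share a root and no offset is counted twice.
Maybe addValues(const Maybe& a, const Maybe& b) {
  if (!a || !b || a->root != b->root) return std::nullopt;
  PartialReduction r = *a;
  for (int o : b->offsets) {
    if (!r.offsets.insert(o).second) return std::nullopt;
  }
  return r;
}

bool isFullReduction(const Maybe& v) {
  return v.has_value() && static_cast<int>(v->offsets.size()) == kSlots;
}

// The CKKS IR above; zeroFolded = true models %2 being replaced by %1.
bool ckksChainReduces(bool zeroFolded) {
  Maybe v1 = leaf("%1");
  Maybe v2 = v1;
  if (!zeroFolded) v2 = leaf("%2");  // %2 = addf %1, %cst seen as a new root
  Maybe v4 = rotateBy(v1, 7);
  Maybe v6 = addValues(addValues(rotateBy(v2, 6), v4), v1);
  Maybe v9 = addValues(addValues(rotateBy(v6, 6), v4), v1);
  Maybe v12 = addValues(addValues(rotateBy(v9, 6), v4), v1);
  Maybe v14 = addValues(rotateBy(v12, 7), v1);
  return isFullReduction(v14);
}

// The BGV rotate-and-reduce chain from the issue description.
bool bgvChainReduces() {
  Maybe v1 = leaf("%1");
  Maybe v3 = addValues(v1, rotateBy(v1, 4));
  Maybe v5 = addValues(v3, rotateBy(v3, 2));
  Maybe v7 = addValues(v5, rotateBy(v5, 1));
  return isFullReduction(v7);
}
```

In this model, bgvChainReduces() and ckksChainReduces(true) hold, while ckksChainReduces(false) does not.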

ZenithalHourlyRate commented 4 days ago

Or consider this case migrated to BGV: a dot product with a non-zero initial sum, which rotate-and-reduce cannot handle either. Should we handle this in insert-rotate so that the inserted rotations share the same root? Currently, the mixed use of %1 and %2 is not friendly to the analysis.

func.func @dot_product(%arg0: tensor<8xi16>, %arg1: tensor<8xi16>) -> i16 {
  %c0 = arith.constant 0 : index
  %c0_si16 = arith.constant 10 : i16
  %0 = affine.for %arg2 = 0 to 8 iter_args(%iter = %c0_si16) -> (i16) {
    %1 = tensor.extract %arg0[%arg2] : tensor<8xi16>
    %2 = tensor.extract %arg1[%arg2] : tensor<8xi16>
    %3 = arith.muli %1, %2 : i16
    %4 = arith.addi %iter, %3 : i16
    affine.yield %4 : i16
  }
  return %0 : i16
}

We get

module {
  func.func @dot_product(%arg0: !secret.secret<tensor<8xi16>>, %arg1: !secret.secret<tensor<8xi16>>) -> !secret.secret<i16> {
    %c6 = arith.constant 6 : index
    %cst = arith.constant dense<10> : tensor<8xi16>
    %c7 = arith.constant 7 : index
    %0 = secret.generic ins(%arg0, %arg1 : !secret.secret<tensor<8xi16>>, !secret.secret<tensor<8xi16>>) {
    ^bb0(%arg2: tensor<8xi16>, %arg3: tensor<8xi16>):
      %1 = arith.muli %arg2, %arg3 : tensor<8xi16>
      %2 = arith.addi %1, %cst : tensor<8xi16>
      %3 = tensor_ext.rotate %2, %c6 : tensor<8xi16>, index
      %4 = tensor_ext.rotate %1, %c7 : tensor<8xi16>, index
      %5 = arith.addi %3, %4 : tensor<8xi16>
      %6 = arith.addi %5, %1 : tensor<8xi16>
      %7 = tensor_ext.rotate %6, %c6 : tensor<8xi16>, index
      %8 = arith.addi %7, %4 : tensor<8xi16>
      %9 = arith.addi %8, %1 : tensor<8xi16>
      %10 = tensor_ext.rotate %9, %c6 : tensor<8xi16>, index
      %11 = arith.addi %10, %4 : tensor<8xi16>
      %12 = arith.addi %11, %1 : tensor<8xi16>
      %13 = tensor_ext.rotate %12, %c7 : tensor<8xi16>, index
      %14 = arith.addi %13, %1 : tensor<8xi16>
      %extracted = tensor.extract %14[%c7] : tensor<8xi16>
      secret.yield %extracted : i16
    } -> !secret.secret<i16>
    return %0 : !secret.secret<i16>
  }
}
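
Numerically, the constant does not break this program; it only hides the reduction. Consider an 8-slot plaintext model that assumes tensor_ext.rotate is a cyclic left rotation, consistent with the OpenFHE trace earlier in the thread. In that model, every slot of %14 ends up holding the full dot product plus exactly one copy of the splat constant. So the same value can also be obtained from the BGV-style log-step reduction over %1, with the initial value added once at the end.

The sketch below checks that equivalence for the input (1, 2, 3, 4, 1, 2, 3, 4). The helper names are hypothetical. Whether insert-rotate or rotate-and-reduce should produce the second form is the open question here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Slots = std::vector<double>;  // small integers, so i16 values are exact

// Assumed semantics of tensor_ext.rotate: cyclic left rotation by k slots.
Slots rotateLeft(const Slots& x, std::size_t k) {
  const std::size_t n = x.size();
  Slots r(n);
  for (std::size_t i = 0; i < n; ++i) r[i] = x[(i + k) % n];
  return r;
}

Slots addSlots(const Slots& a, const Slots& b) {
  assert(a.size() == b.size());
  Slots r(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) r[i] = a[i] + b[i];
  return r;
}

// The emitted program for the init = 10 variant, transcribed op by op.
// `p` holds %1 (the elementwise products); `init` is the splat value of %cst.
double emittedProgram(const Slots& p, double init) {
  Slots v2 = addSlots(p, Slots(p.size(), init));  // %2 = addi %1, %cst
  Slots v4 = rotateLeft(p, 7);
  Slots v6 = addSlots(addSlots(rotateLeft(v2, 6), v4), p);
  Slots v9 = addSlots(addSlots(rotateLeft(v6, 6), v4), p);
  Slots v12 = addSlots(addSlots(rotateLeft(v9, 6), v4), p);
  Slots v14 = addSlots(rotateLeft(v12, 7), p);
  return v14[7];  // tensor.extract %14[%c7]
}

// Hypothetical alternative: log-step rotate-and-reduce over %1 only, then add
// the initial value once. Requires p.size() to be a power of two.
double reduceThenAddInit(const Slots& p, double init) {
  Slots acc = p;
  for (std::size_t k = p.size() / 2; k >= 1; k /= 2) {
    acc = addSlots(acc, rotateLeft(acc, k));
  }
  return acc[p.size() - 1] + init;
}
```

With p = (1, 4, 9, 16, 1, 4, 9, 16), both functions return 70 for init = 10.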

j2kun commented 3 days ago

I'm not sure I fully understand what you're suggesting, but let me try to repeat:

does not eliminate the constant 0.0

If we could figure out why the floating point constant is not folded away, we could solve the immediate problem, but I think in your example with a non-zero initial value of the dot product, this problem would persist in another form. It could maybe be fixed by changing insert-rotate so that it aligns things properly, or maybe it could be changed in rotate-and-reduce to recognize a reduction in a smarter way than looking at a single linear chain (#522).

I support either of those improvements. I also think https://github.com/google/heir/issues/521 might allow a workaround in which rotate-and-reduce works for the entire vector except the first element in the chain. That would be nearly optimal and would give other side benefits to IRs that don't do complete reductions.
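
The thread leaves open why the floating-point zero is not folded away. One plausible cause, offered as an assumption rather than a confirmed diagnosis, is IEEE 754 signed zero. For integers, x + 0 == x always holds, so the i16 constant can be folded. For floats, (-0.0) + (+0.0) evaluates to +0.0 under the default rounding mode. So x + 0.0 -> x does not preserve values when x is -0.0; only x + (-0.0) -> x does. The dense<0.000000e+00> constant in the CKKS IR is +0.0, and upstream MLIR's arith.addf folder appears to fold only the -0.0 case.

The sketch below checks both identities. It uses double as a stand-in for f16, which follows the same signed-zero rules. The function name is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// True if `x + addend` reproduces x exactly, including the sign of zero, for
// every sample, i.e. if rewriting `x + addend` to `x` would preserve values.
bool addIsIdentity(double addend, const std::vector<double>& samples) {
  assert(!samples.empty());
  for (double x : samples) {
    const double y = x + addend;
    // -0.0 == +0.0 compares equal, so the sign bit must be checked separately.
    if (y != x || std::signbit(y) != std::signbit(x)) return false;
  }
  return true;
}
```

Take the samples {-0.0, 0.0, 1.5, -2.25}. addIsIdentity(0.0, samples) is false because of the -0.0 sample, while addIsIdentity(-0.0, samples) is true.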