secretflow / spu

SPU (Secure Processing Unit) aims to be a provable, measurable secure computation device, which provides computation ability while keeping your private data protected.
https://www.secretflow.org.cn/docs/spu/en/
Apache License 2.0

[Bug]: One more minus sign #641

Closed: maths644311798 closed this issue 7 months ago.

maths644311798 commented 7 months ago

Issue Type

Usability

Modules Involved

MPC protocol

Have you reproduced the bug with SPU HEAD?

No

Have you searched existing issues?

Yes

SPU Version

commit 34d7f30806306dcfadceffd866a6bb3205ff2d84 (HEAD -> main, origin/main, origin/HEAD)

OS Platform and Distribution

Ubuntu 20.04

Python Version

3.10

Compiler Version

gcc 13

Current Behavior?

In mpc/cheetah/rlwe/lwe_ct.cc, void PhantomLWECt::CastAsRLWE(const seal::SEALContext &context, uint64_t multiplier, RLWECt *out) const confuses me. The following code

if (multiplier == num_coeff) {
  fixed_mul = ntt_tables[l].inv_degree_modulo();
} else {
  // compute multiplier^{-1} mod p
  uint64_t inv_multiplier;
  SPU_ENFORCE(
      try_invert_uint_mod(multiplier, modulus[l], inv_multiplier),
      fmt::format("inverse mod for multiplier={} failed", multiplier));
  fixed_mul.set(negate_uint_mod(inv_multiplier, modulus[l]), modulus[l]);
}

negates inv_multiplier in the second branch but not in the first. After careful consideration, I think the negation should be removed, i.e. fixed_mul.set(negate_uint_mod(inv_multiplier, modulus[l]), modulus[l]); should become fixed_mul.set(inv_multiplier, modulus[l]);, because the code that follows really needs fixed_mul = multiplier^{-1} mod p without an extra minus sign.

Standalone code to reproduce the issue

//original code
void PhantomLWECt::CastAsRLWE(const seal::SEALContext &context,
                              uint64_t multiplier, RLWECt *out) const {
  SPU_ENFORCE(out != nullptr);
  if (!IsValid()) {
    out->release();
    return;
  }

  auto cntxt_data = context.get_context_data(parms_id());
  SPU_ENFORCE(cntxt_data != nullptr, "invalid pid for this context");

  out->resize(context, parms_id(), 2);
  const auto &modulus = cntxt_data->parms().coeff_modulus();
  const auto *ntt_tables = cntxt_data->small_ntt_tables();
  auto num_modulus = this->coeff_modulus_size();
  auto num_coeff = this->poly_modulus_degree();

  const uint64_t *src_ptr = base_->data(1);
  uint64_t *dst_ptr = out->data(1);

  std::fill_n(out->data(0), num_coeff * num_modulus, 0);
  for (size_t l = 0; l < num_modulus; ++l) {
    using namespace seal::util;
    // multiply by multiplier^{-1} mod p (N^{-1} mod p when multiplier == N)
    // to cancel out the multiplier
    MultiplyUIntModOperand fixed_mul;
    if (multiplier == num_coeff) {
      fixed_mul = ntt_tables[l].inv_degree_modulo();
    } else {
      // compute multiplier^{-1} mod p
      uint64_t inv_multiplier;
      SPU_ENFORCE(
          try_invert_uint_mod(multiplier, modulus[l], inv_multiplier),
          fmt::format("inverse mod for multiplier={} failed", multiplier));
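      // Reported issue: this extra negation gives fixed_mul = -multiplier^{-1} mod p,
      // unlike the first branch.
      fixed_mul.set(negate_uint_mod(inv_multiplier, modulus[l]), modulus[l]);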
    }

    dst_ptr[0] =
        multiply_uint_mod(src_ptr[coeff_index_], fixed_mul, modulus[l]);

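    // compute c1 * X^{-coeff_index_} in Z_p[X]/(X^N + 1), scaled by fixed_mul:
    // shift coefficients down by coeff_index_ and negate the ones that wrap around
    size_t offset = num_coeff - coeff_index_;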
    for (size_t i = 1; i < offset; ++i) {
      size_t src_rlwe_idx = coeff_index_ + i;
      dst_ptr[i] =
          multiply_uint_mod(src_ptr[src_rlwe_idx], fixed_mul, modulus[l]);
    }

    for (size_t i = offset; i < num_coeff; ++i) {
      size_t src_rlwe_idx = i - offset;
      dst_ptr[i] = multiply_uint_mod(modulus[l].value() - src_ptr[src_rlwe_idx],
                                     fixed_mul, modulus[l]);
    }

    out->data(0)[l * num_coeff] = multiply_uint_mod(
        base_->data(0)[l * num_coeff + coeff_index_], fixed_mul, modulus[l]);

    src_ptr += num_coeff;
    dst_ptr += num_coeff;
  }
}

Relevant log output

No response

anakinxc commented 7 months ago

@fionser Mind taking a look?

fionser commented 7 months ago

@maths644311798 Nice catch. It is a bug.