secretflow / psi

The repo of Private Set Intersection(PSI) and Private Information Retrieval(PIR) from SecretFlow.
https://www.secretflow.org.cn/docs/psi
Apache License 2.0
21 stars 16 forks source link

sealPir 其他安全强度支持(4096安全参数支持) #114

Open Yinbenxin opened 3 months ago

Yinbenxin commented 3 months ago

此问题出现在尝试使用 4096 安全参数(poly_modulus_degree = 4096)进行匿踪查询时。经过为期 1 周的排查,问题终于定位并得以解决。问题代码位于 seal_pir.cc 中的 std::vector<seal::Ciphertext> SealPirServer::ExpandQuery 函数:

// Expands a single query ciphertext into `m` ciphertexts by repeated
// doubling: each round turns every ciphertext in `results` into two, using a
// Galois automorphism plus multiplications by monomials, and the first `m`
// ciphertexts of the final round are returned.
//
// This variant performs the monomial multiplications with
// Evaluator::multiply_plain. The demo output accompanying this code shows
// multiply_plain by such a monomial costing about 20 bits of noise budget per
// call, which is what breaks the expansion at poly_modulus_degree 4096.
//
// @param encrypted  the query ciphertext to expand
// @param m          number of ciphertexts to produce; enforced to satisfy
//                   ceil(log2(m)) <= ceil(log2(n)), n = poly_modulus_degree
// @return           a vector holding the first `m` expanded ciphertexts
//
// NOTE(review): for m == 1, logm is 0 and `logm - 1` wraps around as an
// unsigned value; m is presumably always >= 2 at the call sites - confirm.
std::vector<seal::Ciphertext> SealPirServer::ExpandQuery(
    const seal::Ciphertext &encrypted, std::uint32_t m) {
  uint64_t plain_mod = seal_params_->plain_modulus().value();

  seal::GaloisKeys &galkey = galois_key_;

  // Assume that m is a power of 2. If not, round it to the next power of 2.
  uint32_t logm = std::ceil(std::log2(m));

  std::vector<int> galois_elts;
  auto n = seal_params_->poly_modulus_degree();
  YACL_ENFORCE(logm <= std::ceil(std::log2(n)), "m > n is not allowed.");

  // galois_elts[i] = (n + 2^i) / 2^i, i.e. n / 2^i + 1: the Galois element
  // applied in doubling round i.
  galois_elts.reserve(std::ceil(std::log2(n)));
  for (size_t i = 0; i < std::ceil(std::log2(n)); i++) {
    galois_elts.push_back((n + seal::util::exponentiate_uint(2, i)) /
                          seal::util::exponentiate_uint(2, i));
  }

  std::vector<seal::Ciphertext> results(1);
  results[0] = encrypted;
  // NOTE: tempPt is never used in this function.
  seal::Plaintext tempPt;
  // Rounds 0 .. logm - 2: every ciphertext is expanded, so the vector doubles.
  for (size_t j = 0; j < logm - 1; j++) {
    std::vector<seal::Ciphertext> results2(1 << (j + 1));
    int step = 1 << j;
    seal::Plaintext pt0(n);
    seal::Plaintext pt1(n);

    // pt0 = -x^(n - step): a single coefficient of plain_mod - 1 (i.e. -1).
    // With the polynomial modulus x^n + 1 this monomial equals x^(-step).
    pt0.set_zero();
    pt0[n - step] = plain_mod - 1;
    // Debug output, printed on every round.
    std::cout << "plain_mods:" << plain_mod << std::endl;
    int index_raw = (n << 1) - (1 << j);  // -2^j
    // pt1 = x^index, where index is index_raw scaled by this round's Galois
    // element and reduced modulo 2n.
    int index = (index_raw * galois_elts[j]) % (n << 1);
    pt1.set_zero();
    pt1[index] = 1;
    // Debug output, printed on every round.
    std::cout << "pt0:" << pt0.to_string() << std::endl;
    std::cout << "pt1:" << pt1.to_string() << std::endl;
    // int nstep = -step;
    // Each k reads results[k] and writes results2[k] and results2[k + step],
    // so distinct iterations write to distinct output slots.
    yacl::parallel_for(0, step, [&](int64_t begin, int64_t end) {
      for (int k = begin; k < end; k++) {
        seal::Ciphertext c0;
        seal::Ciphertext c1;
        seal::Ciphertext t0;
        seal::Ciphertext t1;

        c0 = results[k];

        // SPDLOG_INFO("apply_galois j:{} k:{}", j, k);
        evaluator_->apply_galois(c0, galois_elts[j], galkey,
                                 t0);          // t0 = Sub(c0,N/(2^i)+1)
        evaluator_->add(c0, t0, results2[k]);  // c0 + Sub(c0,N/(2^i)+1)
        // multiply_power_of_X(c0, c1, index_raw);
        evaluator_->multiply_plain(c0, pt0, c1);  // c1 = c0 * x^(-2^j)
        evaluator_->multiply_plain(t0, pt1, t1);
        // t1 = Sub(c0,N/(2^j)+1) * x^index, which is intended to equal
        // Sub(c1,N/(2^j)+1); the sum below is therefore c1 + Sub(c1,N/2^j+1).
        evaluator_->add(c1, t1, results2[k + step]);
      }
    });
    results = results2;
  }

  // Last step of the loop
  // Same doubling as above for round logm - 1, except that entries whose
  // second half would fall at or beyond index m are not expanded.
  std::vector<seal::Ciphertext> results2(results.size() << 1);
  seal::Plaintext two("2");

  seal::Plaintext pt0(n);
  seal::Plaintext pt1(n);

  // pt0 = -x^(n - results.size()); results.size() is 2^(logm - 1) here.
  pt0.set_zero();
  pt0[n - results.size()] = plain_mod - 1;

  int index_raw = (n << 1) - (1 << (logm - 1));
  int index = (index_raw * galois_elts[logm - 1]) % (n << 1);
  pt1.set_zero();
  pt1[index] = 1;

  for (uint32_t k = 0; k < results.size(); k++) {
    if (k >= (m - (1 << (logm - 1)))) {  // corner case.
      // Output k + results.size() would be >= m and is discarded, so only the
      // first half is produced; it is multiplied by 2, presumably to match
      // the factor the regular branch applies - confirm.
      evaluator_->multiply_plain(results[k], two,
                                 results2[k]);  // plain multiplication by 2.
    } else {
      seal::Ciphertext c0;
      seal::Ciphertext c1;
      seal::Ciphertext t0;
      seal::Ciphertext t1;

      c0 = results[k];
      evaluator_->apply_galois(c0, galois_elts[logm - 1], galkey, t0);
      evaluator_->add(c0, t0, results2[k]);
      // multiply_power_of_X(c0, c1, index_raw);

      evaluator_->multiply_plain(c0, pt0, c1);
      evaluator_->multiply_plain(t0, pt1, t1);
      evaluator_->add(c1, t1, results2[k + results.size()]);
    }
  }

  // Keep only the first m ciphertexts (m need not be a power of two).
  auto first = results2.begin();
  auto last = results2.begin() + m;
  std::vector<seal::Ciphertext> new_vec(first, last);
  return new_vec;
}

建议修改为:

// Expands a single query ciphertext into `m` ciphertexts by repeated
// doubling: each round turns every ciphertext into two, using a Galois
// automorphism plus multiplications by powers of x, and the first `m`
// ciphertexts of the final round are returned.
//
// Multiplication by a power of x is done with multiply_power_of_X (a
// negacyclic coefficient shift) rather than Evaluator::multiply_plain, so the
// monomial multiplications do not consume noise budget.
//
// @param encrypted  the query ciphertext to expand
// @param m          number of ciphertexts to produce; must be positive and
//                   satisfy ceil(log2(m)) <= ceil(log2(n)),
//                   n = poly_modulus_degree
// @return           a vector of exactly `m` ciphertexts
std::vector<seal::Ciphertext> SealPirServer::ExpandQuery(
    const seal::Ciphertext &encrypted, std::uint32_t m) {
  // log2(0) is -inf, and converting that to an unsigned integer is undefined.
  YACL_ENFORCE(m > 0, "m must be positive.");
  if (m == 1) {
    // logm would be 0 and `logm - 1` below would wrap around as an unsigned
    // value. No doubling round is needed for a single output; this assumes
    // the query only populates its first m coefficients, as the general path
    // does.
    return {encrypted};
  }

  seal::GaloisKeys &galkey = galois_key_;

  // Assume that m is a power of 2. If not, round it to the next power of 2.
  const uint32_t logm = std::ceil(std::log2(m));

  const auto n = seal_params_->poly_modulus_degree();
  const uint32_t logn = std::ceil(std::log2(n));
  YACL_ENFORCE(logm <= logn, "m > n is not allowed.");

  // galois_elts[i] = (n + 2^i) / 2^i, i.e. n / 2^i + 1: the Galois element
  // applied in doubling round i.
  std::vector<int> galois_elts;
  galois_elts.reserve(logn);
  for (uint32_t i = 0; i < logn; i++) {
    galois_elts.push_back((n + seal::util::exponentiate_uint(2, i)) /
                          seal::util::exponentiate_uint(2, i));
  }

  // Exponents are reduced modulo 2n. The product below is formed in 64 bits
  // because (2n - 2^j) * (n / 2^j + 1) exceeds INT_MAX for n = 32768.
  const uint64_t two_n = static_cast<uint64_t>(n) << 1;

  std::vector<seal::Ciphertext> results(1);
  results[0] = encrypted;

  // Rounds 0 .. logm - 2: every ciphertext is expanded, so the vector doubles.
  for (uint32_t j = 0; j + 1 < logm; j++) {
    const int step = 1 << j;
    std::vector<seal::Ciphertext> results2(static_cast<size_t>(step) << 1);

    // x^index_raw == x^(-2^j) modulo x^n + 1.
    const uint32_t index_raw = static_cast<uint32_t>(two_n - (1ULL << j));
    const uint32_t index = static_cast<uint32_t>(
        (static_cast<uint64_t>(index_raw) * galois_elts[j]) % two_n);

    // Each k reads results[k] and writes results2[k] and results2[k + step],
    // so distinct iterations write to distinct output slots.
    yacl::parallel_for(0, step, [&](int64_t begin, int64_t end) {
      for (int64_t k = begin; k < end; k++) {
        const seal::Ciphertext &c0 = results[k];
        seal::Ciphertext c1;
        seal::Ciphertext t0;
        seal::Ciphertext t1;

        // t0 = Sub(c0, n / 2^j + 1)
        evaluator_->apply_galois(c0, galois_elts[j], galkey, t0);
        // First half: c0 + Sub(c0, n / 2^j + 1)
        evaluator_->add(c0, t0, results2[k]);
        // c1 = c0 * x^(-2^j); t1 = t0 * x^index, intended to equal
        // Sub(c1, n / 2^j + 1).
        multiply_power_of_X(c0, c1, index_raw);
        multiply_power_of_X(t0, t1, index);
        // Second half: c1 + Sub(c1, n / 2^j + 1)
        evaluator_->add(c1, t1, results2[k + step]);
      }
    });
    // Hand the new round over without copying every ciphertext.
    results.swap(results2);
  }

  // Last round (j = logm - 1): same doubling, except that entries whose second
  // half would fall at or beyond index m are not expanded.
  const size_t half = results.size();
  std::vector<seal::Ciphertext> results2(half << 1);
  seal::Plaintext two("2");

  const uint32_t index_raw =
      static_cast<uint32_t>(two_n - (1ULL << (logm - 1)));
  const uint32_t index = static_cast<uint32_t>(
      (static_cast<uint64_t>(index_raw) * galois_elts[logm - 1]) % two_n);

  for (uint32_t k = 0; k < half; k++) {
    if (k >= (m - (1u << (logm - 1)))) {  // corner case.
      // Output k + half would be >= m and is discarded, so only the first half
      // is produced; it is multiplied by 2, presumably to match the factor the
      // regular branch applies - confirm.
      evaluator_->multiply_plain(results[k], two, results2[k]);
    } else {
      const seal::Ciphertext &c0 = results[k];
      seal::Ciphertext c1;
      seal::Ciphertext t0;
      seal::Ciphertext t1;

      evaluator_->apply_galois(c0, galois_elts[logm - 1], galkey, t0);
      evaluator_->add(c0, t0, results2[k]);

      multiply_power_of_X(c0, c1, index_raw);
      multiply_power_of_X(t0, t1, index);
      evaluator_->add(c1, t1, results2[k + half]);
    }
  }

  // Keep only the first m ciphertexts (m need not be a power of two).
  return std::vector<seal::Ciphertext>(results2.begin(), results2.begin() + m);
}

// Writes `encrypted` multiplied by x^index into `destination`, by applying a
// negacyclic coefficient shift to every polynomial of the ciphertext under
// each coefficient modulus.
//
// @param encrypted    input ciphertext (left unmodified)
// @param destination  receives the shifted ciphertext; it is first assigned a
//                     copy of `encrypted` so that its size and metadata match
// @param index        shift amount, i.e. the exponent of x
void SealPirServer::multiply_power_of_X(const seal::Ciphertext &encrypted,
                                        seal::Ciphertext &destination,
                                        uint32_t index) {
  const auto &moduli = seal_params_->coeff_modulus();
  const auto poly_degree = seal_params_->poly_modulus_degree();
  // The last modulus in the parameter set is skipped - presumably the special
  // prime, which is not part of the ciphertext data; confirm.
  const auto limb_count = moduli.size() - 1;
  const auto poly_count = encrypted.size();

  destination = encrypted;
  for (size_t poly = 0; poly < poly_count; ++poly) {
    const auto *src = encrypted.data(poly);
    auto *dst = destination.data(poly);
    for (size_t limb = 0; limb < limb_count; ++limb) {
      // Each modulus owns one contiguous run of poly_degree coefficients.
      const size_t offset = limb * poly_degree;
      seal::util::negacyclic_shift_poly_coeffmod(
          src + offset, poly_degree, index, moduli[limb], dst + offset);
    }
  }
}

主要原因是:multiply_plain 会大量消耗 SEAL 密文的噪声预算(noise budget),而 negacyclic_shift_poly_coeffmod 不会使噪声增大,并且在乘以 x 的幂(x^k)时,该函数的计算速度也更快。

Yinbenxin commented 3 months ago

为了说明这个问题,可以用以下例子进行说明:

#include <iostream>
#include "seal/seal.h"
#include "seal/util/polyarithsmallmod.h"
using namespace std;
using namespace seal;
using namespace seal::util;
inline void multiply_power_of_X(const Ciphertext &encrypted,EncryptionParameters enc_params_,
                                           Ciphertext &destination,
                                           uint32_t index) {

    auto coeff_mod_count = enc_params_.coeff_modulus().size() - 1;
    auto coeff_count = enc_params_.poly_modulus_degree();
    auto encrypted_count = encrypted.size();

    destination = encrypted;

    for (int i = 0; i < encrypted_count; i++) {
        for (int j = 0; j < coeff_mod_count; j++) {
            negacyclic_shift_poly_coeffmod(encrypted.data(i) + (j * coeff_count),
                                           coeff_count, index,
                                           enc_params_.coeff_modulus()[j],
                                           destination.data(i) + (j * coeff_count));
        }
    }
}
// Demo: repeatedly multiplies an encrypted polynomial by the same monomial in
// two ways - multiply_power_of_X (negacyclic shift) and
// Evaluator::multiply_plain - and prints the decrypted result and the
// remaining invariant noise budget after each of four rounds.
int main() {
    // Set up BFV parameters with polynomial modulus degree N = 4096.
    int N = 4096;
    EncryptionParameters parms(scheme_type::bfv);
    parms.set_poly_modulus_degree(N);
    parms.set_coeff_modulus(CoeffModulus::BFVDefault(N));
    parms.set_plain_modulus(PlainModulus::Batching(N, 20));
    auto context = SEALContext(parms);
    uint64_t plain_mod = parms.plain_modulus().value();
    // Generate the keys.
    seal::PublicKey public_key;
    seal::SecretKey secret_key;
    KeyGenerator keygen(context);
    keygen.create_public_key(public_key);
    secret_key = keygen.secret_key();
    // Create the encryptor.
    Encryptor encryptor(context, public_key);

    // Build the plaintext polynomial 10 * x^1.
    Plaintext plain_coefficients(N);
    plain_coefficients.set_zero();
    plain_coefficients[1] = 10;
    // Encrypt it.
    Ciphertext ciphertext;
    encryptor.encrypt(plain_coefficients, ciphertext);

    // Build the plaintext monomial -x^(N - step), which equals x^(-step)
    // modulo x^N + 1. index_raw = 2N - step is the same exponent expressed as
    // a shift amount for multiply_power_of_X.
    Plaintext plain_power(N);
    int step = 1 << 4;
    plain_power.set_zero();
    int index_raw = (N << 1) - step;
    plain_power[N - step] = plain_mod - 1;
    Evaluator evaluator(context);
    Decryptor decryptor(context, secret_key);
    // Both chains start from the same ciphertext.
    Ciphertext mpfx = ciphertext;
    Ciphertext mp = ciphertext;

    for (int i = 0; i < 4; ++i) {
        Ciphertext mpfx_result;
        Ciphertext mp_result;
        Plaintext mpfx_plaint;
        Plaintext mp_plaint;
        multiply_power_of_X(mpfx, parms, mpfx_result, index_raw);
        evaluator.multiply_plain(mp, plain_power, mp_result);
        decryptor.decrypt(mpfx_result, mpfx_plaint);
        decryptor.decrypt(mp_result, mp_plaint);

        // Only the first 50 characters of each decrypted polynomial are shown.
        cout << "multiply_power_of_X result: " << mpfx_plaint.to_string().substr(0,50) << endl;
        cout << "multiply_plain result: " << mp_plaint.to_string().substr(0,50) << endl;
        cout << "multiply_power_of_X 剩余可用噪音: " << decryptor.invariant_noise_budget(mpfx_result) << endl;
        cout << "multiply_plain 剩余可用噪音: " << decryptor.invariant_noise_budget(mp_result) << endl;

        // Feed each result into the next round to accumulate the effect.
        mpfx = mpfx_result;
        mp = mp_result;
    }
    return 0;
}

multiply_power_of_X result: FBFF7x^4081
multiply_plain result: FBFF7x^4081
multiply_power_of_X 剩余可用噪音: 45
multiply_plain 剩余可用噪音: 25
multiply_power_of_X result: FBFF7x^4065
multiply_plain result: FBFF7x^4065
multiply_power_of_X 剩余可用噪音: 45
multiply_plain 剩余可用噪音: 5
multiply_power_of_X result: FBFF7x^4049
multiply_plain result: FAD3Ax^4095 + 1E1x^4094 + F0x^4093 + FBF11x^4092 +
multiply_power_of_X 剩余可用噪音: 45
multiply_plain 剩余可用噪音: 0
multiply_power_of_X result: FBFF7x^4033
multiply_plain result: 5DCD2x^4095 + 4065Ex^4094 + 4065Ex^4093 + 9E32Fx^4
multiply_power_of_X 剩余可用噪音: 45
multiply_plain 剩余可用噪音: 0

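可以看到,使用 multiply_plain 时剩余噪声预算迅速下降(45 → 25 → 5 → 0,每次约消耗 20 bit),预算耗尽后解密结果出错;而使用 multiply_power_of_X 时噪声预算始终保持在 45 bit,结果始终正确。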

Yinbenxin commented 3 months ago

之前仅支持 8192 参数而未暴露该问题,是因为查询的数据总量较小,噪声预算尚未耗尽;当数据量较大时,就会出现因噪声预算不足而导致的计算错误。

Jamie-Cui commented 3 months ago

@qxzhou1010 Would you mind to take a look at this?

qxzhou1010 commented 3 months ago

@Yinbenxin 非常感谢您提出这个 issue,并给出了优化的实现。这里我们想在密文下计算 c1 = c0 · x^(-2^j)。由于 BFV 的多项式模是 x^N + 1,环中的乘法具有负循环(negacyclic)性质,密文乘以 x 的幂本质上就是对 c0 的各个多项式做负循环移位。因此可以使用 negacyclic_shift_poly_coeffmod 来加速这一运算;该过程只涉及对密文多项式系数的移位(及取负),不会增大密文中的噪声,也就不消耗噪声预算。

而 multiply_plain 涉及密文与明文多项式相乘,结果密文中的噪声项会被放大,因此每一次操作都会消耗噪声预算。

实际上,在 SealPIR 官方仓库中正是采用的这个实现。可以参考:https://github.com/microsoft/SealPIR/blob/ee1a5a3922fc9250f9bb4e2416ff5d02bfef7e52/src/pir_server.cpp#L415

我们后续将会对这个点的实现进行优化,再次感谢您提出的问题和进行的验证。