tuneinsight / lattigo

A library for lattice-based multiparty homomorphic encryption in Go
Apache License 2.0
1.18k stars 176 forks source link

Question: [t<N] Decryption Following Public Key Generation? #432

Closed aklitzke closed 4 months ago

aklitzke commented 8 months ago

I am running into some issues decrypting after a public key generation protocol using Lattigo. I am wondering if I am doing something wrong, and was hoping the devs could correct any misunderstandings I have.

My understanding of the generation protocol from a high level is this:

  1. Each party i generates a secret s_i'
  2. Each party generates a random polynomial f_i(x) such that f_i(0) => s_i'
  3. Each party i sends every party j (including itself) the share f_i(j)
  4. Each party i sums every share they received from j to create a new secret sum(f_j(i) for all j) => s_i
  5. Each party posts publicly a public key generated from a common public A and secret error e_i: A*s_i+e_i => pk_i
  6. Each party sums the public pk_i's to generate a final public key for the scheme: sum(pk_i for all i) => pk

This is all fine, and it seems to match the implementation provided by Lattigo well. However, I'm a bit confused by Lattigo's implementation of the decryption steps. Generally, I would expect decryption to work roughly as follows:

  1. Given a public ciphertext encrypted for pk: encrypt(plaintext, pk) => ct
  2. Each party i performs a decrypt operation using their private key, adding a small amount of error: decrypt(ct, s_i) + e_i => pt_i
  3. Each party sends partial plaintext pt_i to the receiving party
  4. The receiving party performs lagrange interpolation on the received pt_i shares sum(pt_i * lagrange(i) for i) => pt
  5. The result, pt, is the original plaintext

However, Lattigo seems to do decryption differently. When following the MHE README and the natural flow of the code, the Lagrange coefficient is multiplied into the secret keys before partial decryption, rather than after. So step 2 above actually looks like decrypt(ct, s_i * lagrange(i)) => pt_i, and step 4 is simply sum(pt_i). This feels odd: decryption generally involves addition/subtraction, and it is unclear to me how multiplying by the Lagrange coefficient beforehand would affect the final result. In addition, assuming this works, it requires each party to know ahead of time which shares will be aggregated, which may not always be the case.

Alas, when I tried to implement this, the code runs successfully but I cannot seem to actually reconstitute the plaintext:

    // Question code: three parties, Shamir threshold 2. All three parties take
    // part in both the public-key generation and the decryption below.
    params, err := rlwe.NewParametersFromLiteral(rlwe.ExampleParametersLogN14LogQP438)
    if err != nil {
        panic(err)
    }
    fmt.Println(params)

    crs, err := sampling.NewPRNG()
    if err != nil {
        panic(err)
    }
    fmt.Println(crs)

    kgen := rlwe.NewKeyGenerator(params)
    // each party generates a key
    sk1 := kgen.GenSecretKeyNew()
    sk2 := kgen.GenSecretKeyNew()
    sk3 := kgen.GenSecretKeyNew()

    // each party generates a Shamir polynomial (threshold 2) for its own key
    thr := mhe.NewThresholdizer(params)
    poly1, err := thr.GenShamirPolynomial(2, sk1)
    if err != nil {
        panic(err)
    }
    poly2, err := thr.GenShamirPolynomial(2, sk2)
    if err != nil {
        panic(err)
    }
    poly3, err := thr.GenShamirPolynomial(2, sk3)
    if err != nil {
        panic(err)
    }

    // each party generates shares for each party: polyI_J holds f_I(J),
    // the share of party I's polynomial destined to party J
    poly1_1 := thr.AllocateThresholdSecretShare()
    poly1_2 := thr.AllocateThresholdSecretShare()
    poly1_3 := thr.AllocateThresholdSecretShare()
    poly2_1 := thr.AllocateThresholdSecretShare()
    poly2_2 := thr.AllocateThresholdSecretShare()
    poly2_3 := thr.AllocateThresholdSecretShare()
    poly3_1 := thr.AllocateThresholdSecretShare()
    poly3_2 := thr.AllocateThresholdSecretShare()
    poly3_3 := thr.AllocateThresholdSecretShare()
    thr.GenShamirSecretShare(1, poly1, &poly1_1)
    thr.GenShamirSecretShare(2, poly1, &poly1_2)
    thr.GenShamirSecretShare(3, poly1, &poly1_3)
    thr.GenShamirSecretShare(1, poly2, &poly2_1)
    thr.GenShamirSecretShare(2, poly2, &poly2_2)
    thr.GenShamirSecretShare(3, poly2, &poly2_3)
    thr.GenShamirSecretShare(1, poly3, &poly3_1)
    thr.GenShamirSecretShare(2, poly3, &poly3_2)
    thr.GenShamirSecretShare(3, poly3, &poly3_3)

    // those shares are distributed accordingly and party J sums the three
    // shares it received into its threshold secret share t_skJ
    // NOTE(review): the error values returned by AggregateShares (if any) are not checked here
    tmp := thr.AllocateThresholdSecretShare()
    t_sk1 := thr.AllocateThresholdSecretShare()
    t_sk2 := thr.AllocateThresholdSecretShare()
    t_sk3 := thr.AllocateThresholdSecretShare()
    thr.AggregateShares(poly1_1, poly2_1, &tmp)
    thr.AggregateShares(tmp, poly3_1, &t_sk1)
    thr.AggregateShares(poly1_2, poly2_2, &tmp)
    thr.AggregateShares(tmp, poly3_2, &t_sk2)
    thr.AggregateShares(poly1_3, poly2_3, &tmp)
    thr.AggregateShares(tmp, poly3_3, &t_sk3)

    // public Shamir points of the possible groups of active parties
    ckg := mhe.NewPublicKeyGenProtocol(params)
    pp1_3 := []mhe.ShamirPublicPoint{1, 3}
    pp2_3 := []mhe.ShamirPublicPoint{2, 3}
    pp1_2 := []mhe.ShamirPublicPoint{1, 2}
    pp1_2_3 := []mhe.ShamirPublicPoint{1, 2, 3}

    // Not a plain type conversion: per the maintainer's reply in this thread,
    // GenAdditiveShare pre-multiplies the party's Lagrange coefficient for the
    // active group (here pp1_2_3) into its threshold share t_skI, producing an
    // additive share s_{I,P'} of the collective secret.
    // NOTE(review): each Combiner below is built with only the *other* parties'
    // points; the maintainer's reply says NewCombiner should receive all public
    // points (pp1_2_3).
    com1 := mhe.NewCombiner(*params.GetRLWEParameters(), 1, pp2_3, 2)
    com2 := mhe.NewCombiner(*params.GetRLWEParameters(), 2, pp1_3, 2)
    com3 := mhe.NewCombiner(*params.GetRLWEParameters(), 3, pp1_2, 2)
    // The last argument receives the result, so sk1, sk2 and sk3 are
    // overwritten: from here on they hold the additive shares, not the keys
    // generated above.
    com1.GenAdditiveShare(pp1_2_3, 1, t_sk1, sk1)
    com2.GenAdditiveShare(pp1_2_3, 2, t_sk2, sk2)
    com3.GenAdditiveShare(pp1_2_3, 3, t_sk3, sk3)

    // Generate public shared 'A'
    crp := ckg.SampleCRP(crs)

    // Generate a public-key share from each party's (now overwritten) sk
    pk1 := ckg.AllocateShare()
    pk2 := ckg.AllocateShare()
    pk3 := ckg.AllocateShare()
    ckg.GenShare(sk1, crp, &pk1)
    ckg.GenShare(sk2, crp, &pk2)
    ckg.GenShare(sk3, crp, &pk3)

    // sum up the public keys to generate a master public key
    pk_sum := ckg.AllocateShare()
    tmp2 := ckg.AllocateShare()
    ckg.AggregateShares(pk1, pk2, &tmp2)
    ckg.AggregateShares(tmp2, pk3, &pk_sum)

    // type conversion to create actual pk
    pk := rlwe.NewPublicKey(params)
    ckg.GenPublicKey(pk_sum, crp, pk)

    // encrypt something with our new pk! Just use a plaintext of all zeros
    // NOTE(review): pt.IsNTT is left at its default here; the maintainer's
    // corrected version later in this thread sets pt.IsNTT = false before
    // encrypting, because no encoder is used to set it.
    pt := rlwe.NewPlaintext(params, params.MaxLevel())
    encr := rlwe.NewEncryptor(params, pk)
    ct, err := encr.EncryptNew(pt)
    if err != nil {
        panic(err)
    }

    // Each party partially decrypts the ciphertext by doing a keyswitch to a secret key of '0'
    sk0 := rlwe.NewSecretKey(params)
    // for testing only, zero the noise flooding parameters to remove this as a factor that could be causing issues
    ksw, err := mhe.NewKeySwitchProtocol(params, ring.DiscreteGaussian{Sigma: 0 * rlwe.DefaultNoise, Bound: 0 * 8 * rlwe.DefaultNoise})
    if err != nil {
        panic(err)
    }
    sh1 := ksw.AllocateShare(params.MaxLevel())
    sh2 := ksw.AllocateShare(params.MaxLevel())
    sh3 := ksw.AllocateShare(params.MaxLevel())
    ksw.GenShare(sk1, sk0, ct, &sh1)
    ksw.GenShare(sk2, sk0, ct, &sh2)
    ksw.GenShare(sk3, sk0, ct, &sh3)

    // Shares are aggregated and applied to ct; the result is encrypted under sk0
    sh := ksw.AllocateShare(params.MaxLevel())
    sh_tmp := ksw.AllocateShare(params.MaxLevel())
    ct_new := ct.CopyNew()
    ksw.AggregateShares(sh1, sh2, &sh_tmp)
    ksw.AggregateShares(sh_tmp, sh3, &sh)
    ksw.KeySwitch(ct, sh, ct_new)

    // actually decrypt
    dec := rlwe.NewDecryptor(params, sk0)
    pt_new := dec.DecryptNew(ct_new)

    // should be zeros, but fails and produces a seemingly random output
    fmt.Println(pt_new)

Ultimately I have two questions:

  1. What do I need to do differently in the above code to make [t<N] secret sharing work?
  2. Why multiply lagrange(i) by s_i ahead of partial decryption?

Thanks!! Andrew

ChristianMct commented 8 months ago

Hi @aklitzke,

Your observation that Lattigo's implementation of T-out-of-N threshold decryption differs from the "usual" ones is correct. It is based on the scheme of this paper, which requires the group of >= T parties participating in the decryption to be known before the decryption shares are generated.

Here is the flow of this scheme:

  1. Given a public ciphertext encrypted for pk: encrypt(plaintext, pk) => ct and a decryption group P' subset of P (the set of all parties) of size T.
  2. Each party computes the Lagrange coefficient for its own share s_i in the reconstruction among group P', and pre-multiplies it into its s_i => s_{i,P'}
  3. The parties execute the usual decryption protocol, yet with s_{i,P'} (their additive T-out-of-T share) instead of s'_i (their additive N-out-of-N share).

In the above, step 2 is implemented with the drlwe.Combiner type (mhe.Combiner in Lattigo v5, the package your code uses), and your code does not use it quite correctly. The NewCombiner function takes the ShamirPublicPoint public points of all parties (pp1_2_3 in your code) as its `others` parameter. Then, the GenAdditiveShare method lets you compute the s_{i,P'} of step 2.

Note: The reason for implementing a more restricted scheme is that there isn't yet (to the best of our knowledge) any scheme that implements the "traditional" T-out-of-N decryption flow. In a nutshell, the problem is that you cannot simply multiply your decryption shares by the Lagrange coefficient, because this would make the noise blow up. The above scheme circumvents this limitation, as pre-multiplying the coefficient into the secret does not affect the noise.

aklitzke commented 8 months ago

Thanks for the reply! The approach you describe makes sense: to stop the error from growing when multiplying by Lagrange coefficients you multiply them by the secret key ahead of partial decryption. Got it, I like the approach. It's smart.

On a separate note, I unfortunately still can't get the library to work. I did what you suggested and passed pp1_2_3 into NewCombiner:

    com1 := mhe.NewCombiner(*params.GetRLWEParameters(), 1, pp1_2_3, 2)
    com2 := mhe.NewCombiner(*params.GetRLWEParameters(), 2, pp1_2_3, 2)
    com3 := mhe.NewCombiner(*params.GetRLWEParameters(), 3, pp1_2_3, 2)
    com1.GenAdditiveShare(pp1_2_3, 1, t_sk1, sk1)
    com2.GenAdditiveShare(pp1_2_3, 2, t_sk2, sk2)
    com3.GenAdditiveShare(pp1_2_3, 3, t_sk3, sk3)

However, decryption still produced seemingly random output. I'm not sure I understand your suggestion. It's important that all parties' keys are passed into GenAdditiveShare. However, when taking a look at the NewCombiner(...) source in threshold.go:

    // precomputes lagrange coefficient factors
    cmb.lagrangeCoeffs = make(map[ShamirPublicPoint]ring.RNSScalar)
    for _, spk := range others {
        if spk != own {

Because of the spk != own check, NewCombiner will produce the same output whether or not a party's own point is passed in `others`. Maybe I'm misunderstanding your suggestion?

ChristianMct commented 8 months ago

Hi,

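I believe the problems in your code are the following (the list itself was not preserved in this copy of the thread; each change is marked with an `// EDIT:` comment in the code below):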

Here is an updated main to illustrate the above (look for // EDIT: comments).

package main

import (
    "fmt"

    "github.com/tuneinsight/lattigo/v5/core/rlwe"
    "github.com/tuneinsight/lattigo/v5/mhe"
    "github.com/tuneinsight/lattigo/v5/ring"
    "github.com/tuneinsight/lattigo/v5/utils/sampling"
)

// main runs a 2-out-of-3 threshold flow end to end:
//   - each of the three parties samples a secret key and Shamir-shares it
//     with the given threshold;
//   - each party sums the shares it received into its threshold share t_skI;
//   - parties {1, 2} turn their threshold shares into 2-out-of-2 additive
//     shares (Combiner.GenAdditiveShare) and run the public-key generation
//     protocol with them;
//   - a zero plaintext is encrypted under the resulting public key;
//   - parties {1, 3} derive their own 2-out-of-2 additive shares and run the
//     key-switch protocol towards the secret key sk0, i.e. a collective
//     decryption.
//
// Every error returned by the library calls is checked through must, so that
// a failure (e.g. a share-level mismatch or too few active parties) aborts
// the run instead of silently producing a garbage plaintext.
func main() {
    // Shamir reconstruction threshold: any 2 of the 3 parties can act together.
    const threshold = 2

    params, err := rlwe.NewParametersFromLiteral(rlwe.ExampleParametersLogN14LogQP438)
    must(err)
    fmt.Println(params)

    crs, err := sampling.NewPRNG()
    must(err)
    fmt.Println(crs)

    kgen := rlwe.NewKeyGenerator(params)
    // each party generates a key
    sk1 := kgen.GenSecretKeyNew()
    sk2 := kgen.GenSecretKeyNew()
    sk3 := kgen.GenSecretKeyNew()

    // each party generates a Shamir polynomial for its own key
    thr := mhe.NewThresholdizer(params)
    poly1, err := thr.GenShamirPolynomial(threshold, sk1)
    must(err)
    poly2, err := thr.GenShamirPolynomial(threshold, sk2)
    must(err)
    poly3, err := thr.GenShamirPolynomial(threshold, sk3)
    must(err)

    // each party generates shares for each party: polyI_J holds the share of
    // party I's polynomial destined to party J
    poly1_1 := thr.AllocateThresholdSecretShare()
    poly1_2 := thr.AllocateThresholdSecretShare()
    poly1_3 := thr.AllocateThresholdSecretShare()
    poly2_1 := thr.AllocateThresholdSecretShare()
    poly2_2 := thr.AllocateThresholdSecretShare()
    poly2_3 := thr.AllocateThresholdSecretShare()
    poly3_1 := thr.AllocateThresholdSecretShare()
    poly3_2 := thr.AllocateThresholdSecretShare()
    poly3_3 := thr.AllocateThresholdSecretShare()
    thr.GenShamirSecretShare(1, poly1, &poly1_1)
    thr.GenShamirSecretShare(2, poly1, &poly1_2)
    thr.GenShamirSecretShare(3, poly1, &poly1_3)
    thr.GenShamirSecretShare(1, poly2, &poly2_1)
    thr.GenShamirSecretShare(2, poly2, &poly2_2)
    thr.GenShamirSecretShare(3, poly2, &poly2_3)
    thr.GenShamirSecretShare(1, poly3, &poly3_1)
    thr.GenShamirSecretShare(2, poly3, &poly3_2)
    thr.GenShamirSecretShare(3, poly3, &poly3_3)

    // those shares are distributed accordingly and party J sums the three
    // shares it received into its threshold secret share t_skJ
    tmp := thr.AllocateThresholdSecretShare()
    t_sk1 := thr.AllocateThresholdSecretShare()
    t_sk2 := thr.AllocateThresholdSecretShare()
    t_sk3 := thr.AllocateThresholdSecretShare()
    must(thr.AggregateShares(poly1_1, poly2_1, &tmp))
    must(thr.AggregateShares(tmp, poly3_1, &t_sk1))
    must(thr.AggregateShares(poly1_2, poly2_2, &tmp))
    must(thr.AggregateShares(tmp, poly3_2, &t_sk2))
    must(thr.AggregateShares(poly1_3, poly2_3, &tmp))
    must(thr.AggregateShares(tmp, poly3_3, &t_sk3))

    ckg := mhe.NewPublicKeyGenProtocol(params)
    pp1_3 := []mhe.ShamirPublicPoint{1, 3}
    pp1_2 := []mhe.ShamirPublicPoint{1, 2}
    pp1_2_3 := []mhe.ShamirPublicPoint{1, 2, 3}

    // EDIT: create the parties' Combiners, each from the public points of all parties
    com1 := mhe.NewCombiner(*params.GetRLWEParameters(), 1, pp1_2_3, threshold)
    com2 := mhe.NewCombiner(*params.GetRLWEParameters(), 2, pp1_2_3, threshold)
    com3 := mhe.NewCombiner(*params.GetRLWEParameters(), 3, pp1_2_3, threshold)

    // Generate public shared 'A'
    crp := ckg.SampleCRP(crs)

    // Generate the public key among parties {1, 2}

    // EDIT: step 1, obtain the parties' 2-out-of-2 additive shares
    sk12_1, sk12_2 := rlwe.NewSecretKey(params), rlwe.NewSecretKey(params)
    must(com1.GenAdditiveShare(pp1_2, 1, t_sk1, sk12_1))
    must(com2.GenAdditiveShare(pp1_2, 2, t_sk2, sk12_2))

    // step 2: run the CKG protocol among 2 parties, **with the 2-out-of-2 keys**
    pk1 := ckg.AllocateShare()
    pk2 := ckg.AllocateShare()
    ckg.GenShare(sk12_1, crp, &pk1)
    ckg.GenShare(sk12_2, crp, &pk2)

    // sum up the public keys to generate a master public key
    pk_sum := ckg.AllocateShare()
    ckg.AggregateShares(pk1, pk2, &pk_sum)

    // type conversion to create actual pk
    pk := rlwe.NewPublicKey(params)
    ckg.GenPublicKey(pk_sum, crp, pk)

    // encrypt something with our new pk! Just use a plaintext of all zeros
    pt := rlwe.NewPlaintext(params, params.MaxLevel())
    pt.IsNTT = false // EDIT: little technicality, using the rlwe package directly without encoding requires specifying if the message is in NTT or not
    encr := rlwe.NewEncryptor(params, pk)
    ct, err := encr.EncryptNew(pt)
    must(err)

    // Each party partially decrypts the ciphertext by doing a keyswitch to a secret key of '0'
    sk0 := rlwe.NewSecretKey(params)
    // for testing only, zero the noise flooding parameters to remove this as a factor that could be causing issues
    ksw, err := mhe.NewKeySwitchProtocol(params, ring.DiscreteGaussian{Sigma: 0 * rlwe.DefaultNoise, Bound: 0 * 8 * rlwe.DefaultNoise})
    must(err)

    // EDIT: now performing the decryption among another set of 2 parties
    // Step 1: obtaining the 2-out-of-2 shares
    sk13_1, sk13_3 := rlwe.NewSecretKey(params), rlwe.NewSecretKey(params)
    must(com1.GenAdditiveShare(pp1_3, 1, t_sk1, sk13_1))
    must(com3.GenAdditiveShare(pp1_3, 3, t_sk3, sk13_3))

    // Step 2: run the decryption among 2 parties
    sh1 := ksw.AllocateShare(params.MaxLevel())
    sh3 := ksw.AllocateShare(params.MaxLevel())
    ksw.GenShare(sk13_1, sk0, ct, &sh1)
    ksw.GenShare(sk13_3, sk0, ct, &sh3)

    // Shares are aggregated and applied to ct; ct_new is then encrypted under sk0
    sh := ksw.AllocateShare(params.MaxLevel())
    ct_new := ct.CopyNew()
    must(ksw.AggregateShares(sh1, sh3, &sh))
    ksw.KeySwitch(ct, sh, ct_new)

    // actually decrypt
    dec := rlwe.NewDecryptor(params, sk0)
    pt_new := dec.DecryptNew(ct_new)

    // EDIT: now the output doesn't look random anymore (but is not zero because of noise)
    fmt.Println(pt_new)

    // EDIT: you can print the norm of the vector and see that it is small (ie would be decoded to zero in schemes like BGV/BFV/CKKS)
    fmt.Println("pt_new norm = ", params.RingQ().Log2OfStandardDeviation(pt_new.Value))
}

// must panics when err is non-nil. It keeps this demo's many fallible library
// calls readable while guaranteeing that no returned error is silently dropped.
func must(err error) {
    if err != nil {
        panic(err)
    }
}