tuneinsight / lattigo

A library for lattice-based multiparty homomorphic encryption in Go
Apache License 2.0
1.21k stars 180 forks source link

Questions about computation speed in CKKS #204

Closed pen9u1nlee closed 2 years ago

pen9u1nlee commented 2 years ago

Hello, I encounter a problem when I am using Add and Mul in CKKS. The code shown below is runnable for demonstrating the problem. I generate 3 ciphertexts (a, b and x) to compute a*x + b and a*x + b*x. a*x + b runs fluently, but a*x + b*x takes up a plenty of time and memory (60+ GB or OOM killed), especially in the add after two multiplications. Also, when I do multiple a*x + b*xs and sum them together (which is not shown here), it panics that receiver operand degree is too small. I noticed that there was a bug mentioned in issues which is about the ciphertext level of Mul&Relin. Perhaps this bug reborns, or it is because of my improper use of this library?

package main

import (
    "fmt"

    "github.com/tuneinsight/lattigo/v3/ckks"
    "github.com/tuneinsight/lattigo/v3/ckks/bootstrapping"
    "github.com/tuneinsight/lattigo/v3/drlwe"
    "github.com/tuneinsight/lattigo/v3/ring"
    "github.com/tuneinsight/lattigo/v3/rlwe"
    "github.com/tuneinsight/lattigo/v3/utils"
)

type testContext struct {
    parties int
    params  ckks.Parameters

    ringQ *ring.Ring
    ringP *ring.Ring

    encoder   ckks.Encoder
    evaluator ckks.Evaluator

    /* Encryptor with public key */
    encryptorPk ckks.Encryptor
    /* Decryptor with secret key */
    decryptorSk ckks.Decryptor

    pk *rlwe.PublicKey

    sk *rlwe.SecretKey

    skShares []*rlwe.SecretKey

    crs            drlwe.CRS
    uniformSampler *ring.UniformSampler
}

func Test() {

    params, err := ckks.NewParametersFromLiteral(bootstrapping.N15QP880H16384H32.SchemeParams)
    if err != nil {
        panic(err)
    }

    // init params
    var tc *testContext = new(testContext)
    tc.parties = 3

    fmt.Println("Init params complete!")

    if err = genContext(params, tc); err != nil {
        panic(err)
    }

    // prepare some operands
    a := []float64{1., 2., 3., 4., 5., 6., 7., 8., 9., 10.}
    x := []float64{11., 12., 13., 14., 15., 16., 17., 18., 19., 20.}
    b := []float64{21., 22., 23., 24., 25., 26., 27., 28., 29., 30.}

    // ciphertext generation
    var plain *ckks.Plaintext
    plain = tc.encoder.EncodeNew(a, tc.params.MaxLevel(), tc.params.DefaultScale(), tc.params.LogSlots())
    enc_a := tc.encryptorPk.EncryptNew(plain)
    plain = tc.encoder.EncodeNew(b, tc.params.MaxLevel(), tc.params.DefaultScale(), tc.params.LogSlots())
    enc_b := tc.encryptorPk.EncryptNew(plain)
    plain = tc.encoder.EncodeNew(x, tc.params.MaxLevel(), tc.params.DefaultScale(), tc.params.LogSlots())
    enc_x := tc.encryptorPk.EncryptNew(plain)

    var rn *ckks.Ciphertext = ckks.NewCiphertext(tc.params, 1<<tc.params.LogSlots(), tc.params.MaxLevel(), tc.params.DefaultScale()) // new(ckks.Ciphertext)
    var grad *ckks.Ciphertext = ckks.NewCiphertext(tc.params, 1<<tc.params.LogSlots(), tc.params.MaxLevel(), tc.params.DefaultScale())
    // tc.evaluator.DropLevel(grad, 1)

    /* Test starts */
    /* I have no idea of the reason why the computation speed is that slow. */
    /* This is very fast... */
    fmt.Println("Test of a*x + b")
    tc.evaluator.MulRelin(enc_a, enc_x, rn)
    tc.evaluator.Add(rn, enc_b, rn)
    fmt.Println("This is actually not slow. ")

    /* But this is annoying... */
    fmt.Println("Test of a*x + b*x")
    tc.evaluator.MulRelin(enc_a, enc_x, rn)
    tc.evaluator.MulRelin(enc_b, enc_x, grad)
    fmt.Println("Mul complete")
    /* Especially here... */
    tc.evaluator.Add(rn, grad, rn)
    fmt.Println("Add complete")

    fmt.Println("Done. ")
}

func genContext(params ckks.Parameters, testCtx *testContext) (err error) {
    fmt.Println("Generating context...")

    testCtx.params = params

    testCtx.ringQ = params.RingQ()
    testCtx.ringP = params.RingP()

    /* pesudo-random number generator */
    prng, _ := utils.NewKeyedPRNG([]byte{'1', '1', '4', '5', '1', '4'})
    testCtx.crs = prng
    testCtx.uniformSampler = ring.NewUniformSampler(prng, params.RingQ())

    testCtx.encoder = ckks.NewEncoder(testCtx.params)

    kgen := ckks.NewKeyGenerator(testCtx.params)

    fmt.Println("Secret key generating...")
    /* SecretKeys */
    testCtx.skShares = make([]*rlwe.SecretKey, testCtx.parties)
    /* Init secret key */
    testCtx.sk = ckks.NewSecretKey(testCtx.params)

    ringQP, levelQ, levelP := params.RingQP(), params.QCount()-1, params.PCount()-1
    for j := 0; j < testCtx.parties; j++ {
        testCtx.skShares[j] = kgen.GenSecretKey()
        ringQP.AddLvl(levelQ, levelP, testCtx.sk.Value, testCtx.skShares[j].Value, testCtx.sk.Value)
    }

    /* Evaluation key: relin & rot */
    rlk := kgen.GenRelinearizationKey(testCtx.sk, 2)

    testCtx.evaluator = ckks.NewEvaluator(testCtx.params, rlwe.EvaluationKey{Rlk: rlk})
    /* Publickeys */
    testCtx.pk = kgen.GenPublicKey(testCtx.sk)

    testCtx.encryptorPk = ckks.NewEncryptor(testCtx.params, testCtx.pk)
    testCtx.decryptorSk = ckks.NewDecryptor(testCtx.params, testCtx.sk)
    return
}

func main() {
    Test()
}
Pro7ech commented 2 years ago

So, your problem comes from these two lignes:

var rn *ckks.Ciphertext = ckks.NewCiphertext(tc.params, 1<<tc.params.LogSlots(), tc.params.MaxLevel(), tc.params.DefaultScale()) // new(ckks.Ciphertext)
var grad *ckks.Ciphertext = ckks.NewCiphertext(tc.params, 1<<tc.params.LogSlots(), tc.params.MaxLevel(), tc.params.DefaultScale())

More specifically the 1<<tc.params.LogSlots() part, which creates ciphertexts of degree 2^LogSlots.

The degree of the ciphertext is the number of polynomials forming the ciphertext. A ciphertext can be seen as a polynomial (of polynomials) with respect to the secret key, and the decryption is its evaluation at the point of the secret-key.

A plaintext is a ciphertext of degree zero : [m(x)], a fresh ciphertext is of degree one: [-as + m(x) + e, a], a ciphertext multiplied with an other ciphertext, before its relinearization, is of degree two; [-as - a^2s^2 + m(x)^2 + e, a, a^2], and so on.

So you should replace 1<<params.LogSlots() by 1 if you are creating fresh ciphertexts.

This should make everything 4^Logslots times faster and use 2^LogSlots times less memory ;)

pen9u1nlee commented 2 years ago

Thanks for your reply...Obviously I misused this library.

I thought the degree of newly generated ciphertext should be of degree 1 by default like PALISADE/SEAL, and I directly copied the code in the test.