zama-ai / tfhe-rs

TFHE-rs: A Pure Rust implementation of the TFHE Scheme for Boolean and Integer Arithmetics Over Encrypted Data.

Wrong LUT encoding in case `en(de)crypt_without_padding` is used #116

Closed fakub closed 1 year ago

fakub commented 1 year ago

Describe the bug

The combination of en(de)crypt_without_padding together with generate_accumulator gives incorrect results. E.g., with func = |x:u64| x + 4, the output after decryption shows that x + 2 has been evaluated instead.

My conjecture is that omitting padding with en(de)crypt_without_padding is not addressed correctly in generate_accumulator: it seems to be transforming x -> 2x for the LUT evaluation (~ padding?) and back as y -> y/2, effectively evaluating ((2x) + 4) / 2 = x + 2, which is what I observe instead of the expected x + 4.

To Reproduce

Demonstrated in a toy example.

Evidence

Output of the toy example (parts omitted) with PARAM_MESSAGE_5_CARRY_0:

>>> DEMO PBS with TFHE-rs, function:   x -> x + 4

(i =  0) decrypted IN -> OUT:  0 ->  2
(i =  1) decrypted IN -> OUT:  1 ->  3
(i =  2) decrypted IN -> OUT:  2 ->  4
...
(i = 14) decrypted IN -> OUT: 14 -> 16
(i = 15) decrypted IN -> OUT: 15 -> 17
(i = 16) decrypted IN -> OUT: 16 -> 30   // negacyclic overlap as expected in 5-bit msg space
(i = 17) decrypted IN -> OUT: 17 -> 29
(i = 18) decrypted IN -> OUT: 18 -> 28
...
(i = 30) decrypted IN -> OUT: 30 -> 16
(i = 31) decrypted IN -> OUT: 31 -> 15
(i = 32) decrypted IN -> OUT:  0 ->  2   // indeed 5 bits of msg space

Configuration (please complete the following information):

Cargo.lock

IceTDrinker commented 1 year ago

I believe generate_accumulator assumes a bit of padding, as it's meant to be used for a PBS; this might explain the issue you are seeing.

Could you post a minified version of the code triggering the bug?

IceTDrinker commented 1 year ago

OK, I just saw that you linked a toy example, we'll investigate further! But I think the cause is that accumulator generation expects the LUT to be used with a PBS and so always provisions the bit of padding (without accounting for functions that are already negacyclic).

fakub commented 1 year ago

that's exactly the point: if I do want to use negacyclic functions, I need en(de)crypt_without_padding, because otherwise I cannot encrypt "negative" values, i.e., those with 1 at the MSB. after a small investigation (not in the code, though), it seems to me that the problem is that en(de)crypt_without_padding split the torus into fewer pieces than the parameters allow (and than is usual when padding is used).

it seems to me that in my example with PARAM_MESSAGE_5_CARRY_0, standard encryption occupies the 6 MSBs, whereas encrypt_without_padding only consumes 5 bits, which effectively multiplies input messages by 2 (~left shift). then, the PBS again assumes 6 bits (1 bit of padding + 5 message bits) and evaluates the respective LUT. finally, decrypt_without_padding pops out only 5 bits (again, the MSBs), which effectively divides the result by 2 (~right shift).
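
The shift-by-one-bit explanation above can be checked on plain integers. The following is only a sketch: it involves no encryption and no noise, all function names are hypothetical (none of them is a tfhe-rs API), and it assumes the padded LUT stores f(s) for slots s in 0..32 and returns the negated value on the other half of the torus.

```rust
// Plain-integer model of the mismatch (no encryption, no noise).
// Torus positions are integers modulo 64: 1 padding bit + 5 message bits.
const MSG_MOD: u64 = 1 << 5; // 32 messages when no padding bit is used
const PADDED_MOD: u64 = 2 * MSG_MOD; // 64 positions once a padding bit is assumed

/// Hypothetical model of encrypt_without_padding:
/// x/32 of the torus is the same position as 2x/64.
fn encode_without_padding(x: u64) -> u64 {
    (2 * (x % MSG_MOD)) % PADDED_MOD
}

/// Hypothetical model of a LUT built with a padding bit:
/// slots 0..32 hold f(slot); slots 32..64 are reached negacyclically
/// and yield -f(slot - 32) modulo 64.
fn padded_lut(f: impl Fn(u64) -> u64, slot: u64) -> u64 {
    if slot < MSG_MOD {
        f(slot) % PADDED_MOD
    } else {
        (PADDED_MOD - f(slot - MSG_MOD) % PADDED_MOD) % PADDED_MOD
    }
}

/// Hypothetical model of decrypt_without_padding:
/// only the 5 MSBs are read, i.e. the 6-bit position is divided by 2.
/// (Real decoding rounds; here every position is even, so nothing is lost.)
fn decode_without_padding(pos: u64) -> u64 {
    (pos / 2) % MSG_MOD
}

/// Full chain for func = |x| x + 4, as in the toy example.
pub fn model_mismatch(x: u64) -> u64 {
    let f = |v: u64| v + 4;
    decode_without_padding(padded_lut(f, encode_without_padding(x)))
}
```

Under these assumptions, model_mismatch(x) gives x + 2 for x in 0..16 and 46 - x for x in 16..32, which matches the decrypted values listed in the Evidence section.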

so I think that changing the encoding step of en(de)crypt_without_padding would fix this issue, but I don't really know whether other code depends on these functions. an example covering the behavior of en(de)crypt_without_padding together with PBS would equally be helpful, I believe :)

IceTDrinker commented 1 year ago

Sorry, I forgot to answer here. My feeling is that we may be missing a generate_accumulator_without_padding that could be used together with encrypt/decrypt without padding. However, other primitives using the PBS assume the presence of a padding bit, so I'm not sure yet how the design should evolve.

IceTDrinker commented 1 year ago

Updating the issue with the latest progress on this, as shared on Discord.

The following code works as expected (by re-using the accumulator generator from parmesan found here: https://github.com/fakub/parmesan/blob/0f5c8fdff153134f727281112bd395ef4642a260/src/cloudovo/pbs.rs#L110)

use tfhe::core_crypto::entities::GlweCiphertext;
use tfhe::shortint::ciphertext::Degree;
use tfhe::shortint::parameters::*;
use tfhe::shortint::prelude::*;
use tfhe::shortint::server_key::LookupTableOwned;

use rayon::prelude::*;

fn gen_no_padding_acc<F>(server_key: &ServerKey, f: F) -> LookupTableOwned
where
    F: Fn(u64) -> u64,
{
    let mut accumulator = GlweCiphertext::new(
        0u64,
        server_key.bootstrapping_key.glwe_size(),
        server_key.bootstrapping_key.polynomial_size(),
        server_key.key_switching_key.ciphertext_modulus(),
    );

    let mut accumulator_view = accumulator.as_mut_view();

    accumulator_view.get_mut_mask().as_mut().fill(0);

    // Number of LUT entries: without a padding bit, only half of the message
    // space is stored explicitly (as in parmesan); the other half is reached negacyclically
    let modulus_sup = server_key.message_modulus.0 * server_key.carry_modulus.0 / 2;

    // N/(p/2) = size of each block
    let box_size = server_key.bootstrapping_key.polynomial_size().0 / modulus_sup;

    // Scaling factor applied to the LUT values
    // Main change: delta is multiplied by 2 because no bit of padding is reserved
    let delta =
        ((1u64 << 63) / (server_key.message_modulus.0 * server_key.carry_modulus.0) as u64) * 2;

    let mut body = accumulator_view.get_mut_body();
    let accumulator_u64 = body.as_mut();

    // Tracking the max value of the function to define the degree later
    let mut max_value = 0;

    for i in 0..modulus_sup {
        let index = i * box_size;
        accumulator_u64[index..index + box_size]
            .iter_mut()
            .for_each(|a| {
                let f_eval = f(i as u64);
                // torus arithmetic is modulo 2^64, so wrap instead of panicking on overflow
                *a = f_eval.wrapping_mul(delta);
                max_value = max_value.max(f_eval);
            });
    }

    let half_box_size = box_size / 2;

    // Negate the first half_box_size coefficients
    for a_i in accumulator_u64[0..half_box_size].iter_mut() {
        *a_i = (*a_i).wrapping_neg();
    }

    // Rotate the accumulator
    accumulator_u64.rotate_left(half_box_size);

    LookupTableOwned {
        acc: accumulator,
        degree: Degree(max_value as usize),
    }
}

fn main() {
    println!("\n>>> DEMO PBS with TFHE-rs, function:   x -> x + 4\n");

    // custom params as shortint ignores the bit of padding when defining the message space in the parameter name
    let custom_params = Parameters {
        message_modulus: MessageModulus(1 << 5),
        ..PARAM_MESSAGE_4_CARRY_0
    };

    // generate keys for 5-bit msg space
    let (client_key, server_key) = gen_keys(custom_params);

    // create test vec from func (nice to have: generate_accumulator that inputs a LUT as a vector)
    let func = |x: u64| x + 4;
    // let acc = server_key.generate_accumulator(func);
    let acc = gen_no_padding_acc(&server_key, func);

    // -------------------------------------------------------------------------
    // full loop with one extra value to spot the overlap
    for i in 0..=32 {
        // encrypt into the full msg domain
        let c1 = client_key.encrypt_without_padding(i);

        // eval func
        let cf1 = server_key.apply_lookup_table(&c1, &acc);

        // decrypt (include the padding bit) & print
        let m1 = client_key.decrypt_without_padding(&c1);
        let mf1 = client_key.decrypt_without_padding(&cf1);
        println!("(i = {:2}) decrypted IN -> OUT: {:2} -> {:2}", i, m1, mf1);
    }

    // -------------------------------------------------------------------------
    // the same thing in parallel
    println!("\n\n>>> Same thing in parallel\n");
    let mut c_in: Vec<_> = vec![];
    let mut c_out: Vec<_> = vec![];
    for i in 0..=32 {
        // encrypt into the full msg domain
        c_in.push(client_key.encrypt_without_padding(i));
        // well, this is a bit useless
        c_out.push(client_key.encrypt_without_padding(i));
    }

    c_out
        .par_iter_mut()
        .zip(c_in.par_iter())
        .for_each(|(co, ci)| {
            // eval func
            *co = server_key.apply_lookup_table(&ci, &acc);
        });
    c_out
        .iter()
        .zip(c_in.iter().enumerate())
        .for_each(|(co, (i, ci))| {
            // decrypt (include the padding bit) & print
            let mi = client_key.decrypt_without_padding(&ci);
            let mfi = client_key.decrypt_without_padding(&co);
            println!("(i = {:2}) decrypted IN -> OUT: {:2} -> {:2}", i, mi, mfi);
        });
}
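
The half-box negation and rotation at the end of gen_no_padding_acc can be followed on plain integers. The sketch below is an illustration, not the library's implementation: values are kept modulo p instead of being scaled by delta, both function names are hypothetical, and the blind rotation is modeled by the usual rule that multiplying by X^(-k) modulo X^N + 1 brings coefficient k to position 0, with a sign flip once k reaches N. It assumes n is a multiple of p.

```rust
/// Hypothetical plain-integer stand-in for the accumulator body.
/// `n` plays the role of the polynomial size, `p` of the message modulus.
pub fn build_plain_accumulator(f: impl Fn(u64) -> u64, n: usize, p: u64) -> Vec<u64> {
    let boxes = (p / 2) as usize; // same role as modulus_sup above
    let box_size = n / boxes;
    let half_box = box_size / 2;

    // Box i (coefficients i*box_size .. (i+1)*box_size) holds f(i) modulo p.
    let mut acc: Vec<u64> = (0..n).map(|j| f((j / box_size) as u64) % p).collect();

    // Negate the first half box: after the rotation it sits at the end of the
    // polynomial, where a slightly negative phase picks it up with a sign flip.
    for a in acc[..half_box].iter_mut() {
        *a = (p - *a) % p;
    }

    // Rotate so that phase x * box_size lands in the middle of box x.
    acc.rotate_left(half_box);
    acc
}

/// Noise-free model of the blind rotation followed by sample extraction:
/// the constant coefficient of acc * X^(-k) modulo X^n + 1, for k taken modulo 2n.
pub fn plain_blind_rotate(acc: &[u64], k: usize, p: u64) -> u64 {
    let n = acc.len();
    let k = k % (2 * n);
    if k < n {
        acc[k]
    } else {
        (p - acc[k - n]) % p
    }
}
```

For example, with n = 64, p = 32 and f = |x| x + 4, each box has 4 coefficients, and plain_blind_rotate(&acc, x * 4, p) returns f(x) for x in 0..16 and the negated f(x - 16) modulo 32 for x in 16..32. A phase of 2n - 1 (message 0 with a small negative error) still returns f(0), which is what the negated half box is for.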

Sample output with negacyclic overlap starting at 16:

>>> DEMO PBS with TFHE-rs, function:   x -> x + 4

(i =  0) decrypted IN -> OUT:  0 ->  4
(i =  1) decrypted IN -> OUT:  1 ->  5
(i =  2) decrypted IN -> OUT:  2 ->  6
(i =  3) decrypted IN -> OUT:  3 ->  7
(i =  4) decrypted IN -> OUT:  4 ->  8
(i =  5) decrypted IN -> OUT:  5 ->  9
(i =  6) decrypted IN -> OUT:  6 -> 10
(i =  7) decrypted IN -> OUT:  7 -> 11
(i =  8) decrypted IN -> OUT:  8 -> 12
(i =  9) decrypted IN -> OUT:  9 -> 13
(i = 10) decrypted IN -> OUT: 10 -> 14
(i = 11) decrypted IN -> OUT: 11 -> 15
(i = 12) decrypted IN -> OUT: 12 -> 16
(i = 13) decrypted IN -> OUT: 13 -> 17
(i = 14) decrypted IN -> OUT: 14 -> 18
(i = 15) decrypted IN -> OUT: 15 -> 19
(i = 16) decrypted IN -> OUT: 16 -> 28
(i = 17) decrypted IN -> OUT: 17 -> 27
(i = 18) decrypted IN -> OUT: 18 -> 26
(i = 19) decrypted IN -> OUT: 19 -> 25
(i = 20) decrypted IN -> OUT: 20 -> 24
(i = 21) decrypted IN -> OUT: 21 -> 23
(i = 22) decrypted IN -> OUT: 22 -> 22
(i = 23) decrypted IN -> OUT: 23 -> 21
(i = 24) decrypted IN -> OUT: 24 -> 20
(i = 25) decrypted IN -> OUT: 25 -> 19
(i = 26) decrypted IN -> OUT: 26 -> 18
(i = 27) decrypted IN -> OUT: 27 -> 17
(i = 28) decrypted IN -> OUT: 28 -> 16
(i = 29) decrypted IN -> OUT: 29 -> 15
(i = 30) decrypted IN -> OUT: 30 -> 14
(i = 31) decrypted IN -> OUT: 31 -> 13
(i = 32) decrypted IN -> OUT:  0 ->  4

>>> Same thing in parallel

(i =  0) decrypted IN -> OUT:  0 ->  4
(i =  1) decrypted IN -> OUT:  1 ->  5
(i =  2) decrypted IN -> OUT:  2 ->  6
(i =  3) decrypted IN -> OUT:  3 ->  7
(i =  4) decrypted IN -> OUT:  4 ->  8
(i =  5) decrypted IN -> OUT:  5 ->  9
(i =  6) decrypted IN -> OUT:  6 -> 10
(i =  7) decrypted IN -> OUT:  7 -> 11
(i =  8) decrypted IN -> OUT:  8 -> 12
(i =  9) decrypted IN -> OUT:  9 -> 13
(i = 10) decrypted IN -> OUT: 10 -> 14
(i = 11) decrypted IN -> OUT: 11 -> 15
(i = 12) decrypted IN -> OUT: 12 -> 16
(i = 13) decrypted IN -> OUT: 13 -> 17
(i = 14) decrypted IN -> OUT: 14 -> 18
(i = 15) decrypted IN -> OUT: 15 -> 19
(i = 16) decrypted IN -> OUT: 16 -> 28
(i = 17) decrypted IN -> OUT: 17 -> 27
(i = 18) decrypted IN -> OUT: 18 -> 26
(i = 19) decrypted IN -> OUT: 19 -> 25
(i = 20) decrypted IN -> OUT: 20 -> 24
(i = 21) decrypted IN -> OUT: 21 -> 23
(i = 22) decrypted IN -> OUT: 22 -> 22
(i = 23) decrypted IN -> OUT: 23 -> 21
(i = 24) decrypted IN -> OUT: 24 -> 20
(i = 25) decrypted IN -> OUT: 25 -> 19
(i = 26) decrypted IN -> OUT: 26 -> 18
(i = 27) decrypted IN -> OUT: 27 -> 17
(i = 28) decrypted IN -> OUT: 28 -> 16
(i = 29) decrypted IN -> OUT: 29 -> 15
(i = 30) decrypted IN -> OUT: 30 -> 14
(i = 31) decrypted IN -> OUT: 31 -> 13
(i = 32) decrypted IN -> OUT:  0 ->  4
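
The values in both listings follow a simple closed form, sketched here on plain integers. The function name is hypothetical and the sketch assumes a noise-free evaluation: inputs in the lower half of the message space get f(x), inputs in the upper half get the negation of f(x - p/2) modulo p.

```rust
/// Expected decrypted output of a negacyclic LUT evaluation without padding,
/// for message modulus `p` (32 in the example above). Plain integers only.
pub fn expected_negacyclic(f: impl Fn(u64) -> u64, x: u64, p: u64) -> u64 {
    let half = p / 2;
    let x = x % p; // inputs wrap around the message space, e.g. 32 -> 0
    if x < half {
        f(x) % p
    } else {
        // upper half: the LUT is read with a sign flip
        (p - f(x - half) % p) % p
    }
}
```

With f = |x| x + 4 and p = 32 this gives x + 4 for x in 0..16 and 44 - x for x in 16..32, so 16 maps to 28 and 31 maps to 13, as in the output above.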

IceTDrinker commented 1 year ago

Hello @fakub, is it OK to close with the proposed workaround for now?

fakub commented 1 year ago

sure, at least in my use case, all works fine with the workaround :)