zama-ai / tfhe-rs

TFHE-rs: A Pure Rust implementation of the TFHE Scheme for Boolean and Integer Arithmetics Over Encrypted Data.

Improve WoPBS API to avoid mistakes when using ciphertexts with varying degrees, consider having a carry invariant radix LUT generation #1016

Open cgouert opened 7 months ago

cgouert commented 7 months ago

Describe the bug: After a smart_mul, the WoPBS does not yield the expected result.

To Reproduce

use std::{collections::HashMap};
use tfhe::{
    integer::{
        gen_keys_radix, wopbs::*,
    },
    shortint::parameters::{
        parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
        PARAM_MESSAGE_2_CARRY_2_KS_PBS,
    },
};

fn foo(x: u64, lut_entries: &HashMap<u64, u64>) -> u64 {
    lut_entries[&x]
}

fn main() {
    let nb_blocks: usize = 4;

    // Generate radix keys
    let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into());

    // Generate key for PBS (without padding)
    let wopbs_key = WopbsKey::new_wopbs_key(
        &client_key,
        &server_key,
        &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
    );

    // Create ciphertexts 
    let mut ct = client_key.encrypt(2_u64);
    let mut ct_2 = client_key.encrypt(4_u64);

    // Generate LUTs for WoPBS
    let mut lut_1_map : HashMap<u64, u64> = HashMap::new();
    let mut lut_2_map : HashMap<u64, u64> = HashMap::new();
    for i in 0..256 {
      lut_1_map.insert(i, 2*i % 256);
      lut_2_map.insert(i, 3*i % 256);
    }
    let lut_1 = wopbs_key.generate_lut_radix(&ct, |x: u64| foo(x, &lut_1_map));
    let lut_2 = wopbs_key.generate_lut_radix(&ct, |x: u64| foo(x, &lut_2_map));

    // Multiply input ciphertexts: 2 * 4 = 8
    ct = server_key.smart_mul(&mut ct, &mut ct_2);

    // Apply LUT #1
    ct = wopbs_key.keyswitch_to_wopbs_params(&server_key, &ct);
    let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
    let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
    let test: u64 = client_key.decrypt(&lut_1_res);
    println!("Lut #1 result: {:?}", &test);
    println!("Expected result: {:?}", &lut_1_map[&8]);

    // Apply LUT #2
    let lut_2_res = wopbs_key.wopbs(&ct, &lut_2);
    let lut_2_res = wopbs_key.keyswitch_to_pbs_params(&lut_2_res);
    let test: u64 = client_key.decrypt(&lut_2_res);
    println!("Lut #2 result: {:?}", &test);
    println!("Expected result: {:?}", &lut_2_map[&8]);

}

Expected behaviour: The decrypted results should match the expected results. For the above code, they do not.


IceTDrinker commented 7 months ago

try using decrypt_without_padding ?

https://docs.rs/tfhe/latest/tfhe/integer/client_key/struct.RadixClientKey.html#method.decrypt_without_padding

IceTDrinker commented 7 months ago

I see some mention of without padding in your code example ?

IceTDrinker commented 7 months ago

ah no my bad I mixed things up 😵‍💫

IceTDrinker commented 7 months ago

can confirm it reproes on latest main

IceTDrinker commented 7 months ago

mul_parallelized does not have the issue. It looks like bad carry management somewhere: it could be the keyswitching or the LUT generation that does not handle this properly.

IceTDrinker commented 7 months ago

This is not a bug on our end but a hard to use feature @cgouert

See the updated code below. The main change is to move the LUT generation (which is adapted to the ciphertext degree) to right before applying the WoPBS.

use std::collections::HashMap;
use tfhe::integer::gen_keys_radix;
use tfhe::integer::wopbs::*;
use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

fn foo(x: u64, lut_entries: &HashMap<u64, u64>) -> u64 {
    lut_entries[&x]
}

fn main() {
    let nb_blocks: usize = 4;

    // Generate radix keys
    let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into());

    // Generate key for PBS (without padding)
    let wopbs_key = WopbsKey::new_wopbs_key(
        &client_key,
        &server_key,
        &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
    );

    let clear_1 = 2_u64;
    let clear_2 = 4_u64;

    // Create ciphertexts
    let mut ct = client_key.encrypt(clear_1);
    let mut ct_2 = client_key.encrypt(clear_2);

    // Generate LUTs for WoPBS
    let mut lut_1_map: HashMap<u64, u64> = HashMap::new();
    let mut lut_2_map: HashMap<u64, u64> = HashMap::new();
    for i in 0..256 {
        lut_1_map.insert(i, 2 * i % 256);
        lut_2_map.insert(i, 3 * i % 256);
    }

    let f1 = |x: u64| foo(x, &lut_1_map);
    let f2 = |x: u64| foo(x, &lut_2_map);

    // Multiply input ciphertexts: 2 * 4 = 8
    ct = server_key.smart_mul(&mut ct, &mut ct_2);

    let sanity_dec: u64 = client_key.decrypt(&ct);
    let clear_prod = clear_1 * clear_2;
    assert_eq!(sanity_dec, clear_prod);

    // Lut generation just in time
    let lut_1 = wopbs_key.generate_lut_radix(&ct, f1);
    let lut_2 = wopbs_key.generate_lut_radix(&ct, f2);

    // Apply LUT #1
    ct = wopbs_key.keyswitch_to_wopbs_params(&server_key, &ct);
    let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
    let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
    let test: u64 = client_key.decrypt(&lut_1_res);
    println!("Lut #1 result: {:?}", &test);
    println!("Expected result: {:?}", f1(clear_prod));

    // Apply LUT #2
    let lut_2_res = wopbs_key.wopbs(&ct, &lut_2);
    let lut_2_res = wopbs_key.keyswitch_to_pbs_params(&lut_2_res);
    let test: u64 = client_key.decrypt(&lut_2_res);
    println!("Lut #2 result: {:?}", &test);
    println!("Expected result: {:?}", f2(clear_prod));
}

RUSTFLAGS="-C target-cpu=native" cargo run --profile devo --features=x86_64-unix,integer,internal-keycache --example wop_smart_mul -p tfhe
   Compiling tfhe v0.6.0 (/home/***/Documents/zama/code/tfhe-rs/tfhe)
    Finished devo [optimized + debuginfo] target(s) in 8.01s
     Running `target/devo/examples/wop_smart_mul`
Lut #1 result: 16
Expected result: 16
Lut #2 result: 24
Expected result: 24

IceTDrinker commented 7 months ago

we'll consider adding an "apply_function" that just takes the function and manages the JIT LUT generation

cgouert commented 7 months ago

Thanks for your help with this @IceTDrinker! Just to clarify, are you saying the smart_mul changes the degree of the ciphertexts? If this is the case, the LUT doesn't work as expected because it was generated based on the original degree of a fresh encryption?

jimouris commented 7 months ago

@IceTDrinker Thanks for your responses!

It seems that just putting the smart_mul before the generate_lut_radix solves the issue but what if we want to do something like:

smart_mul(..)
let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);

smart_mul(..)
let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);

smart_mul(..)
let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);

In this case, we'd have to run generate_lut_radix three times although the LUTs will be the same?

IceTDrinker commented 7 months ago

> Thanks for your help with this @IceTDrinker! Just to clarify, are you saying the smart_mul changes the degree of the ciphertexts? If this is the case, the LUT doesn't work as expected because it was generated based on the original degree of a fresh encryption?

exactly and yes it's expected that multiplying changes the degrees of the underlying blocks :)
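
To make the degree dependence concrete, here is a plaintext-only toy model. It assumes that the LUT index is formed by packing each block into `ceil(log2(degree + 1))` bits, least significant block first. That layout is an assumption made for illustration, and `bits_for_degree` and `lut_index` are hypothetical helpers, not tfhe-rs functions.

```rust
/// Bits needed to represent every value in 0..=degree,
/// i.e. ceil(log2(degree + 1)).
fn bits_for_degree(degree: u64) -> u32 {
    u64::BITS - degree.leading_zeros()
}

/// Pack per-block values into one LUT index, least significant block first,
/// giving each block only as many bits as its degree requires.
/// Toy model: the real layout inside tfhe-rs may differ.
fn lut_index(block_values: &[u64], degrees: &[u64]) -> u64 {
    let mut index = 0u64;
    let mut shift = 0u32;
    for (value, degree) in block_values.iter().zip(degrees) {
        debug_assert!(value <= degree);
        index |= *value << shift;
        shift += bits_for_degree(*degree);
    }
    index
}
```

In this model the same block values `[2, 1, 0, 0]` map to index 6 under a fresh degree profile `[3, 3, 3, 3]`, but to index 18 once the first block's degree has grown to 15. A LUT laid out for the first profile is therefore read at the wrong position for the second.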

IceTDrinker commented 7 months ago

> @IceTDrinker Thanks for your responses!
>
> It seems that just putting the smart_mul before the generate_lut_radix solves the issue but what if we want to do something like:
>
> smart_mul(..)
> let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
> let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
>
> smart_mul(..)
> let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
> let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
>
> smart_mul(..)
> let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
> let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
>
> In this case, we'd have to run generate_lut_radix three times although the LUTs will be the same?

My best guess here is that it does a lazy LUT evaluation, filling as few coefficients as it can, since having more 0s in the LUT will, IIRC, result in less noise in the output (I could be wrong on that, but a trivial all-zero LUT gives a noiseless encryption of 0, so there may be some of that).

I don't think the LUT generation would show up in a performance measurement when compared to the runtime of a WoPBS, so yes, you likely need a LUT generation each time. You could write a small wrapper that generates the LUT and applies it right afterwards, so you never have issues with degrees.
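
A minimal sketch of such a wrapper, written against a hypothetical `LutBackend` trait rather than the real tfhe-rs types so that it runs without the library. `ToyCt` and `ToyBackend` are plaintext stand-ins invented for this example. The point is only the shape of `apply_function`: the LUT is generated for the exact ciphertext it is applied to and is never reused.

```rust
/// Hypothetical abstraction over "something that can build and apply LUTs".
trait LutBackend {
    type Ct;
    type Lut;

    /// Build a LUT for `f`, laid out for the degree profile of `ct`.
    fn generate_lut(&self, ct: &Self::Ct, f: &dyn Fn(u64) -> u64) -> Self::Lut;

    /// Apply a previously generated LUT to `ct`.
    fn apply_lut(&self, ct: &Self::Ct, lut: &Self::Lut) -> Self::Ct;

    /// Generate the LUT just in time and apply it immediately, so the LUT
    /// can never be used with a ciphertext of a different degree.
    fn apply_function(&self, ct: &Self::Ct, f: &dyn Fn(u64) -> u64) -> Self::Ct {
        let lut = self.generate_lut(ct, f);
        self.apply_lut(ct, &lut)
    }
}

/// Plaintext stand-in for a ciphertext: a value and its upper bound (degree).
#[derive(Debug, Clone, PartialEq)]
struct ToyCt {
    value: u64,
    degree: u64,
}

/// Plaintext backend used only to exercise the wrapper.
struct ToyBackend;

impl LutBackend for ToyBackend {
    type Ct = ToyCt;
    type Lut = Vec<u64>;

    fn generate_lut(&self, ct: &ToyCt, f: &dyn Fn(u64) -> u64) -> Vec<u64> {
        // One entry per value the input can take, given its degree.
        (0..=ct.degree).map(|x| f(x)).collect()
    }

    fn apply_lut(&self, ct: &ToyCt, lut: &Vec<u64>) -> ToyCt {
        let value = lut[ct.value as usize];
        // The output can be at most the largest entry of the LUT.
        let degree = lut.iter().copied().max().unwrap_or(0);
        ToyCt { value, degree }
    }
}
```

With the real library, the body of `apply_function` would presumably chain the calls already shown in this thread (`generate_lut_radix`, `keyswitch_to_wopbs_params`, `wopbs`, `keyswitch_to_pbs_params`); that mapping is a suggestion, not an existing API.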

IceTDrinker commented 7 months ago

or rather the way the LUT is organized is not invariant by the carry bits of the ciphertexts 🤔

Maybe there is something to do about it, but I'm really not sure, as I was not part of the team who initially worked on that. Still, it could be interesting to investigate whether carry-invariant LUTs are "easy" to write for the WoPBS. The lazy eval explanation is likely wrong, now that I think about it.

IceTDrinker commented 7 months ago

I'm going to keep this issue open as an "enhancement" issue to see if there is something to be done for the LUT or the API to make it less error-prone.

IceTDrinker commented 6 months ago

The solution will likely be to have an API taking a function and building the LUT just in time.

Juul-Mc-Goa commented 5 months ago

Stumbled on the same issue recently.

Generating the LUT manually solved the problem. The only thing changed in this function is the computation of vec_deg_basis. In my fix, all the degrees are assumed to be maximal (i.e. equal to the modulus).

Note that this quickfix can probably be improved:

    pub fn generate_lut_radix<F, T>(wopbs_key: &tfhe::shortint::WopbsKey, ct: &T, f: F) -> IntegerWopbsLUT
    where
        F: Fn(u64) -> u64,
        T: IntegerCiphertext,
    {
        let mut total_bit = 0;
        let block_nb = ct.blocks().len();
        let mut modulus = 1;

        //This contains the basis of each block depending on the degree
        let mut vec_deg_basis = vec![];

        for (i, _deg) in ct.moduli().iter().zip(ct.blocks().iter()) {
            modulus *= i;
            let b = f64::log2(*i as f64).ceil() as u64;
            vec_deg_basis.push(b);
            total_bit += b;
        }

        let lut_size = if 1 << total_bit < wopbs_key.param.polynomial_size.0 as u64 {
            wopbs_key.param.polynomial_size.0
        } else {
            1 << total_bit
        };
        let mut lut = IntegerWopbsLUT::new(PlaintextCount(lut_size), CiphertextCount(block_nb));

        let basis = ct.moduli()[0];
        let delta: u64 = (1 << 63)
            / (wopbs_key.param.message_modulus.0 * wopbs_key.param.carry_modulus.0) as u64;

        for lut_index_val in 0..(1 << total_bit) {
            let encoded_with_deg_val = encode_mix_radix(lut_index_val, &vec_deg_basis, basis);
            let decoded_val = decode_radix(&encoded_with_deg_val, basis);
            let f_val = f(decoded_val % modulus) % modulus;
            let encoded_f_val = encode_radix(f_val, basis, block_nb as u64);
            for (lut_number, radix_encoded_val) in encoded_f_val.iter().enumerate().take(block_nb) {
                lut[lut_number][lut_index_val as usize] = radix_encoded_val * delta;
            }
        }
        lut
    }

Remarks

  1. This issue appears because generate_lut_radix is tied to a specific ciphertext. Thus evaluating the lut at a different ciphertext is currently unsupported.
  2. I guess one could change the signature of generate_lut_radix to take as input not a ciphertext, but an argument like maximum_index: u64. Then the function would generate a lut that can compute the value at any index between 0 and maximum_index. The for lut_index_val in ... loop would then iterate in the range 0..maximum_index, and the computation of vec_deg_basis would probably be removed.
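
Remark 2 can be sketched in plaintext as follows. This is a toy under stated assumptions: `generate_lut_for_range` is a hypothetical name, the LUT index is taken to be the plain input value itself (the real radix encoding is ignored), and the range is treated as inclusive of `maximum_index`.

```rust
/// Build a LUT covering every index in 0..=maximum_index, without looking
/// at any ciphertext. Inputs and outputs are reduced modulo `modulus`,
/// mirroring the `f(decoded_val % modulus) % modulus` step in the snippet above.
fn generate_lut_for_range(
    maximum_index: u64,
    modulus: u64,
    f: impl Fn(u64) -> u64,
) -> Vec<u64> {
    (0..=maximum_index)
        .map(|index| f(index % modulus) % modulus)
        .collect()
}
```

A LUT built this way is valid for any input whose value does not exceed `maximum_index`, at the cost of always filling the whole range.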

Juul-Mc-Goa commented 5 months ago

I assume the computation of vec_deg_basis is a feature introduced to optimize the lut computation on the one ciphertext used at generate_lut_radix. This optimization is legitimate, but breaks when trying to compute over other ciphertexts.

One solution would be to mimic Rust's FnOnce/Fn pattern:

IceTDrinker commented 5 months ago

I think the reasoning when this was done was to have the smallest amount of data in the LUT and to extract the smallest number of bits when possible, instead of extracting all of them. We'll most likely go the route of the apply function and put warnings on this API.

IceTDrinker commented 5 months ago

The FnOnce/Fn idea is interesting but how do you enforce running the LUT only on the ciphertext it was designed for ? On the other hand any ciphertext with the same degree profile is compatible with that LUT
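
One way to enforce this at run time is to store the degree profile inside the LUT and check it on application. The sketch below is a plaintext toy: `ToyCt`, `ProfiledLut` and `ProfileMismatch` are hypothetical types, and the index layout (each block packed into `ceil(log2(degree + 1))` bits) is an assumption, not the library's actual encoding.

```rust
/// Plaintext stand-in for a radix ciphertext (hypothetical, for illustration).
struct ToyCt {
    blocks: Vec<u64>,  // block values, least significant first
    degrees: Vec<u64>, // upper bound on each block value
}

/// Returned when a LUT meets a ciphertext with another degree profile.
#[derive(Debug, PartialEq)]
struct ProfileMismatch {
    expected: Vec<u64>,
    found: Vec<u64>,
}

/// A LUT that remembers the degree profile it was generated for.
struct ProfiledLut {
    degrees: Vec<u64>,
    table: Vec<u64>,
}

/// Bits needed for every value in 0..=degree.
fn bits_for_degree(degree: u64) -> u32 {
    u64::BITS - degree.leading_zeros()
}

impl ProfiledLut {
    /// Build the table for `f`; the input is the base-`base` number
    /// spelled by the blocks.
    fn generate(ct: &ToyCt, base: u64, f: impl Fn(u64) -> u64) -> Self {
        let widths: Vec<u32> = ct.degrees.iter().map(|d| bits_for_degree(*d)).collect();
        let total_bits: u32 = widths.iter().sum();
        let table = (0..1u64 << total_bits)
            .map(|index| {
                // Unpack the index into blocks, then recombine them in radix form.
                let (mut shift, mut weight, mut value) = (0u32, 1u64, 0u64);
                for w in &widths {
                    let block = (index >> shift) & ((1u64 << *w) - 1);
                    value += block * weight;
                    weight *= base;
                    shift += *w;
                }
                f(value)
            })
            .collect();
        Self { degrees: ct.degrees.clone(), table }
    }

    /// Look up `ct`, refusing ciphertexts whose degree profile differs.
    fn apply(&self, ct: &ToyCt) -> Result<u64, ProfileMismatch> {
        if ct.degrees != self.degrees {
            return Err(ProfileMismatch {
                expected: self.degrees.clone(),
                found: ct.degrees.clone(),
            });
        }
        let (mut index, mut shift) = (0u64, 0u32);
        for (block, degree) in ct.blocks.iter().zip(&ct.degrees) {
            index |= *block << shift;
            shift += bits_for_degree(*degree);
        }
        Ok(self.table[index as usize])
    }
}
```

This reflects the observation that any ciphertext with the same degree profile is compatible with the LUT: the check compares profiles, not ciphertext identity. Whether a run-time error or a type-level guarantee is preferable is left open here.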

Juul-Mc-Goa commented 5 months ago

I guess you can either: