Open cgouert opened 7 months ago
try using decrypt_without_padding ?
I see some mention of without padding in your code example ?
ah no my bad I mixed things up 😵‍💫
can confirm it reproes on latest main
mul_parallelized does not have the issue, so this looks like bad carry management somewhere; it could be the keyswitching or the LUT generation not handling this properly
This is not a bug on our end but a hard-to-use feature @cgouert
see the updated code below, the main thing is to move the LUT generation (which is adapted to the ciphertext degree) to right before applying the WoPBS
use std::collections::HashMap;
use tfhe::integer::gen_keys_radix;
use tfhe::integer::wopbs::*;
use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

fn foo(x: u64, lut_entries: &HashMap<u64, u64>) -> u64 {
    lut_entries[&x]
}

fn main() {
    let nb_blocks: usize = 4;

    // Generate radix keys
    let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into());

    // Generate key for PBS (without padding)
    let wopbs_key = WopbsKey::new_wopbs_key(
        &client_key,
        &server_key,
        &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
    );

    let clear_1 = 2_u64;
    let clear_2 = 4_u64;

    // Create ciphertexts
    let mut ct = client_key.encrypt(clear_1);
    let mut ct_2 = client_key.encrypt(clear_2);

    // Generate LUTs for WoPBS
    let mut lut_1_map: HashMap<u64, u64> = HashMap::new();
    let mut lut_2_map: HashMap<u64, u64> = HashMap::new();
    for i in 0..256 {
        lut_1_map.insert(i, 2 * i % 256);
        lut_2_map.insert(i, 3 * i % 256);
    }
    let f1 = |x: u64| foo(x, &lut_1_map);
    let f2 = |x: u64| foo(x, &lut_2_map);

    // Multiply input ciphertexts: 2 * 4 = 8
    ct = server_key.smart_mul(&mut ct, &mut ct_2);
    let sanity_dec: u64 = client_key.decrypt(&ct);
    let clear_prod = clear_1 * clear_2;
    assert_eq!(sanity_dec, clear_prod);

    // Lut generation just in time
    let lut_1 = wopbs_key.generate_lut_radix(&ct, f1);
    let lut_2 = wopbs_key.generate_lut_radix(&ct, f2);

    // Apply LUT #1
    ct = wopbs_key.keyswitch_to_wopbs_params(&server_key, &ct);
    let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
    let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
    let test: u64 = client_key.decrypt(&lut_1_res);
    println!("Lut #1 result: {:?}", &test);
    println!("Expected result: {:?}", f1(clear_prod));

    // Apply LUT #2
    let lut_2_res = wopbs_key.wopbs(&ct, &lut_2);
    let lut_2_res = wopbs_key.keyswitch_to_pbs_params(&lut_2_res);
    let test: u64 = client_key.decrypt(&lut_2_res);
    println!("Lut #2 result: {:?}", &test);
    println!("Expected result: {:?}", f2(clear_prod));
}
RUSTFLAGS="-C target-cpu=native" cargo run --profile devo --features=x86_64-unix,integer,internal-keycache --example wop_smart_mul -p tfhe
Compiling tfhe v0.6.0 (/home/***/Documents/zama/code/tfhe-rs/tfhe)
Finished devo [optimized + debuginfo] target(s) in 8.01s
Running `target/devo/examples/wop_smart_mul`
Lut #1 result: 16
Expected result: 16
Lut #2 result: 24
Expected result: 24
we'll consider adding an "apply_function" that just takes the function and manages the JIT LUT generation
Thanks for your help with this @IceTDrinker! Just to clarify, are you saying the smart_mul changes the degree of the ciphertexts? If this is the case, the LUT doesn't work as expected because it was generated based on the original degree of a fresh encryption?
@IceTDrinker Thanks for your responses!
It seems that just putting the smart_mul before the generate_lut_radix solves the issue but what if we want to do something like:
smart_mul(..)
let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
smart_mul(..)
let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
smart_mul(..)
let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
In this case, we'd have to run generate_lut_radix three times although the LUTs will be the same?
> Thanks for your help with this @IceTDrinker! Just to clarify, are you saying the smart_mul changes the degree of the ciphertexts? If this is the case, the LUT doesn't work as expected because it was generated based on the original degree of a fresh encryption?
exactly and yes it's expected that multiplying changes the degrees of the underlying blocks :)
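To make the degree dependence concrete, here is a toy model in plain Rust with no tfhe-rs calls. It assumes, as the discussion suggests, that the LUT layout is derived from the number of bits needed to hold each block's degree. The helper names (bits_for_degree, pack, unpack, decode, generate_lut, lookup) and the "grown" degree values are made up for illustration; this is not the library's actual encoding.

```rust
const BASE: u64 = 4; // 2 message bits per block, as in the thread's parameters
const MODULUS: u64 = 256; // 4 blocks of 2 message bits each

/// Bits needed to represent any value in 0..=degree, i.e. ceil(log2(degree + 1)).
fn bits_for_degree(degree: u64) -> u32 {
    64 - degree.leading_zeros()
}

fn widths(degrees: &[u64]) -> Vec<u32> {
    degrees.iter().map(|&d| bits_for_degree(d)).collect()
}

/// Packs block values into one table index, least significant block first.
fn pack(blocks: &[u64], bits: &[u32]) -> usize {
    let mut index = 0usize;
    let mut shift = 0u32;
    for (&value, &width) in blocks.iter().zip(bits) {
        index |= (value as usize) << shift;
        shift += width;
    }
    index
}

/// Inverse of `pack`: splits a table index back into block values.
fn unpack(mut index: usize, bits: &[u32]) -> Vec<u64> {
    bits.iter()
        .map(|&width| {
            let value = (index & ((1usize << width) - 1)) as u64;
            index >>= width;
            value
        })
        .collect()
}

/// Clear value carried by the blocks (carries included), reduced mod MODULUS.
fn decode(blocks: &[u64]) -> u64 {
    blocks.iter().rev().fold(0, |acc, &b| acc * BASE + b) % MODULUS
}

/// Builds a table whose layout depends on the given degree profile.
fn generate_lut(degrees: &[u64], f: impl Fn(u64) -> u64) -> Vec<u64> {
    let bits = widths(degrees);
    let total_bits: u32 = bits.iter().sum();
    (0..(1usize << total_bits))
        .map(|index| f(decode(&unpack(index, &bits))) % MODULUS)
        .collect()
}

/// Looks up a ciphertext (blocks + its current degrees) in a table.
fn lookup(blocks: &[u64], degrees: &[u64], lut: &[u64]) -> Option<u64> {
    lut.get(pack(blocks, &widths(degrees))).copied()
}

fn main() {
    let f = |x: u64| 2 * x % MODULUS;
    let fresh_degrees = [3u64, 3, 3, 3]; // fresh encryption
    let grown_degrees = [3u64, 6, 9, 12]; // made-up degrees after a multiplication
    let blocks = [0u64, 0, 1, 0]; // base-4 digits of 16

    let stale_lut = generate_lut(&fresh_degrees, f); // built before the multiplication
    let jit_lut = generate_lut(&grown_degrees, f); // built right before evaluation

    println!("expected:  {}", f(decode(&blocks)));
    println!("stale LUT: {:?}", lookup(&blocks, &grown_degrees, &stale_lut));
    println!("JIT LUT:   {:?}", lookup(&blocks, &grown_degrees, &jit_lut));
}
```

In this model the stale table is indexed with a layout it was not built for, so it returns the entry of a different input, while the table built from the current degrees returns the expected value.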
> @IceTDrinker Thanks for your responses!
> It seems that just putting the smart_mul before the generate_lut_radix solves the issue but what if we want to do something like:
> smart_mul(..)
> let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
> let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
> smart_mul(..)
> let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
> let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
> smart_mul(..)
> let lut_1_res = wopbs_key.wopbs(&ct, &lut_1);
> let lut_1_res = wopbs_key.keyswitch_to_pbs_params(&lut_1_res);
> In this case, we'd have to run generate_lut_radix three times although the LUTs will be the same?
my best guess here is that it does a lazy LUT evaluation, filling as few coefficients as it can, as having more 0s in the LUT will IIRC result in less noise in the output (I could be wrong on that, but if you take a trivial 0 LUT you get a noiseless encryption of 0, so there may be some of that)
I don't think the LUT generation would show up in a performance measurement when compared to the runtime of a WoPBS, so yes, you likely need a LUT generation each time. You could write a small wrapper that generates the LUT and applies it right afterwards, so that you never have issues with degrees
or rather the way the LUT is organized is not invariant by the carry bits of the ciphertexts 🤔
maybe there is something to do about it but I'm really not sure, I was not part of the team who initially worked on that. Still, it could be interesting to investigate whether carry-invariant LUTs are "easy" to write for the WoPBS. The lazy eval thing is likely wrong now that I think about it
I'm going to keep this issue open as an "enhancement" issue to see if there is something to be done for the LUT or the API to limit the error-proneness
The solution will likely be to have an API taking a function and building the LUT just in time
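As a rough sketch of what such an API could look like, the snippet below wraps "generate the LUT" and "apply the LUT" behind one call. Everything here is hypothetical: the LutBackend trait, apply_function, and the ClearBackend mock (which works on clear values and only counts LUT builds) are illustration names, not tfhe-rs items. With the real library, the two trait methods would stand in for generate_lut_radix and the keyswitch + wopbs steps shown in the code above.

```rust
use std::cell::Cell;

/// Hypothetical abstraction over "build a LUT for this ciphertext" and "evaluate it".
trait LutBackend {
    type Ciphertext;
    type Lut;
    fn generate_lut(&self, ct: &Self::Ciphertext, f: &dyn Fn(u64) -> u64) -> Self::Lut;
    fn apply_lut(&self, ct: &Self::Ciphertext, lut: &Self::Lut) -> Self::Ciphertext;
}

/// Builds the LUT from the ciphertext's current state and applies it immediately,
/// so the caller never holds a LUT that could go stale.
fn apply_function<B, F>(backend: &B, ct: &B::Ciphertext, f: F) -> B::Ciphertext
where
    B: LutBackend,
    F: Fn(u64) -> u64,
{
    let lut = backend.generate_lut(ct, &f);
    backend.apply_lut(ct, &lut)
}

/// Mock backend on clear values, only there to exercise the wrapper.
struct ClearBackend {
    modulus: u64,
    luts_built: Cell<usize>,
}

impl LutBackend for ClearBackend {
    type Ciphertext = u64;
    type Lut = Vec<u64>;

    fn generate_lut(&self, _ct: &u64, f: &dyn Fn(u64) -> u64) -> Vec<u64> {
        self.luts_built.set(self.luts_built.get() + 1);
        (0..self.modulus).map(|x| f(x) % self.modulus).collect()
    }

    fn apply_lut(&self, ct: &u64, lut: &Vec<u64>) -> u64 {
        lut[(*ct % self.modulus) as usize]
    }
}

fn main() {
    let backend = ClearBackend { modulus: 256, luts_built: Cell::new(0) };
    let mut ct = 1u64;
    for _ in 0..3 {
        ct = (ct * 3) % 256; // stand-in for smart_mul
        ct = apply_function(&backend, &ct, |x| 2 * x % 256);
    }
    println!("result: {ct}, LUTs built: {}", backend.luts_built.get());
}
```

The design trade-off is the one discussed above: the LUT is rebuilt on every call, which is assumed here to be cheap relative to the WoPBS itself.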
Stumbled on the same issue recently.
Generating the lut manually solved the problem. The only thing changed in this function is the computation of vec_deg_basis. In my fix, all the degrees are assumed to be maximal (i.e. equal to the modulus).
Note that this quickfix can probably be improved:
pub fn generate_lut_radix<F, T>(wopbs_key: &tfhe::shortint::WopbsKey, ct: &T, f: F) -> IntegerWopbsLUT
where
    F: Fn(u64) -> u64,
    T: IntegerCiphertext,
{
    let mut total_bit = 0;
    let block_nb = ct.blocks().len();
    let mut modulus = 1;

    // This contains the basis of each block depending on the degree
    let mut vec_deg_basis = vec![];

    for (i, _deg) in ct.moduli().iter().zip(ct.blocks().iter()) {
        modulus *= i;
        let b = f64::log2(*i as f64).ceil() as u64;
        vec_deg_basis.push(b);
        total_bit += b;
    }

    let lut_size = if 1 << total_bit < wopbs_key.param.polynomial_size.0 as u64 {
        wopbs_key.param.polynomial_size.0
    } else {
        1 << total_bit
    };
    let mut lut = IntegerWopbsLUT::new(PlaintextCount(lut_size), CiphertextCount(block_nb));

    let basis = ct.moduli()[0];
    let delta: u64 = (1 << 63)
        / (wopbs_key.param.message_modulus.0 * wopbs_key.param.carry_modulus.0) as u64;

    for lut_index_val in 0..(1 << total_bit) {
        let encoded_with_deg_val = encode_mix_radix(lut_index_val, &vec_deg_basis, basis);
        let decoded_val = decode_radix(&encoded_with_deg_val, basis);
        let f_val = f(decoded_val % modulus) % modulus;
        let encoded_f_val = encode_radix(f_val, basis, block_nb as u64);
        for (lut_number, radix_encoded_val) in encoded_f_val.iter().enumerate().take(block_nb) {
            lut[lut_number][lut_index_val as usize] = radix_encoded_val * delta;
        }
    }
    lut
}
- generate_lut_radix is tied to a specific ciphertext. Thus evaluating the lut at a different ciphertext is currently unsupported.
- generate_lut_radix could instead take as input not a ciphertext, but an argument like maximum_index: u64. Then the function would generate a lut that can compute the value at any index between 0 and maximum_index. The for lut_index_val in ... loop would then iterate in the range 0..maximum_index, and the computation of vec_deg_basis would probably be removed.
I assume the computation of vec_deg_basis is a feature introduced to optimize the lut computation on the one ciphertext used at generate_lut_radix. This optimization is legitimate, but breaks when trying to compute over other ciphertexts.
One solution would be to mimic Rust's FnOnce/Fn pattern:
- an IntegerWopbsLUTOnce type where the current optimization is kept,
- an IntegerWopbsLUT type without the optimization, so that it can be applied to several ciphertexts.
I think the reasoning when this was done was to have the smallest amount of data in the LUT and extract the smallest amount of bits when possible instead of extracting all of them, we'll most likely go the route of the apply function and put warnings on this API
The FnOnce/Fn idea is interesting but how do you enforce running the LUT only on the ciphertext it was designed for ? On the other hand any ciphertext with the same degree profile is compatible with that LUT
I guess you can either:
- keep the ct argument fed in the generate_lut_radix method, then reuse it when evaluating the lut
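One way to picture the "same degree profile" compatibility check is a LUT wrapper that records the degrees it was generated for and refuses anything else. This is only a sketch in plain Rust: ProfiledLut, LutError and table_for are hypothetical names, and the degree vectors are made-up values rather than anything read from a real ciphertext.

```rust
/// Hypothetical LUT wrapper that remembers the degree profile it was generated for.
struct ProfiledLut {
    degrees: Vec<u64>,
    table: Vec<u64>,
}

#[derive(Debug, PartialEq)]
enum LutError {
    DegreeMismatch { expected: Vec<u64>, found: Vec<u64> },
}

impl ProfiledLut {
    fn new(degrees: &[u64], table: Vec<u64>) -> Self {
        Self { degrees: degrees.to_vec(), table }
    }

    /// Hands out the table only to a ciphertext with the recorded degree profile.
    fn table_for(&self, ct_degrees: &[u64]) -> Result<&[u64], LutError> {
        if self.degrees.as_slice() == ct_degrees {
            Ok(&self.table)
        } else {
            Err(LutError::DegreeMismatch {
                expected: self.degrees.clone(),
                found: ct_degrees.to_vec(),
            })
        }
    }
}

fn main() {
    // LUT generated for a fresh 4-block ciphertext (made-up degrees).
    let lut = ProfiledLut::new(&[3, 3, 3, 3], vec![0; 256]);

    // A different ciphertext with the same profile is accepted.
    println!("same profile ok: {}", lut.table_for(&[3, 3, 3, 3]).is_ok());

    // A ciphertext whose degrees grew is rejected instead of silently misread.
    match lut.table_for(&[3, 6, 9, 12]) {
        Ok(_) => println!("unexpectedly accepted"),
        Err(e) => println!("rejected: {e:?}"),
    }
}
```

Compared with keeping a reference to the original ciphertext, a check like this accepts any ciphertext with a matching profile and turns a mismatch into an explicit error.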
Describe the bug: After a smart_mul, the WoPBS does not yield the expected answer.
To Reproduce
Expected behaviour: For the above code, the decrypted results do not match the expected results.
Evidence
Configuration:
OS: Ubuntu 22.04
cc: @jimouris