Closed: fakub closed this issue 1 year ago.
I believe generate_accumulator assumes a bit of padding, since it is meant to be used for a PBS; this might explain the issue you are seeing.
Could you post a minimal version of the code that triggers the bug?
OK, I just saw you linked a toy example, we'll investigate further! But I think the cause is that accumulator generation expects the LUT to be used with a PBS and therefore always provisions the bit of padding (without accounting for functions that are already negacyclic).
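To make the padding argument concrete, here is a small plaintext model of the two encodings on a 64-bit torus. It is a sketch under stated assumptions: no noise and no encryption, p is a power-of-two message modulus, and the helper names (delta_with_padding, delta_without_padding, encode, msb_set) are hypothetical, not tfhe-rs API. It shows that with one bit of padding every message keeps the MSB at zero, so the PBS never reaches the negacyclic half of the torus, whereas without padding the upper half of the message space sets the MSB.

```rust
// Plaintext-only model (hypothetical helpers, not tfhe-rs API) of the two
// encodings on a 64-bit torus. Assumption: p is the message modulus (a power
// of two) and "padding" means one extra MSB that is kept at zero.

/// Scaling factor when one bit of padding is reserved: 2^63 / p.
fn delta_with_padding(p: u64) -> u64 {
    (1u64 << 63) / p
}

/// Scaling factor when the whole torus holds the message: 2^64 / p.
fn delta_without_padding(p: u64) -> u64 {
    // 2^64 does not fit in a u64, so compute (2^63 / p) * 2 instead.
    ((1u64 << 63) / p) * 2
}

/// Encode m (reduced mod p) at the given scale, wrapping around the torus.
fn encode(m: u64, p: u64, delta: u64) -> u64 {
    (m % p).wrapping_mul(delta)
}

/// True if the most significant bit is set, i.e. the encoded value lies in
/// the "negative" half of the torus.
fn msb_set(x: u64) -> bool {
    x >> 63 == 1
}
```

With p = 32, the padded scale is 2^58 and the largest message 31 encodes to 31 * 2^58 < 2^63, while the unpadded scale is 2^59 and the message 16 already encodes to exactly 2^63.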
That's exactly the point: if I want to use negacyclic functions, I need en(de)crypt_without_padding, because otherwise I cannot encrypt "negative" values, i.e., those with a 1 in the MSB. After a small investigation (not in the code, though), it seems to me that the problem is that en(de)crypt_without_padding splits the torus into fewer pieces than the parameters allow (and than standard encryption uses when padding is present).
What I see in my example with PARAM_MESSAGE_5_CARRY_0 is that standard encryption occupies the 6 most significant bits (1 bit of padding + 5 message bits), whereas encrypt_without_padding uses only the top 5 bits, which effectively multiplies input messages by 2 (~left shift). Then the PBS again assumes 6 bits (1 bit of padding + 5 bits) and evaluates the respective LUT. Finally, decrypt_without_padding reads only the top 5 bits (again, MSB), which effectively divides the result by 2 (~right shift).
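The shift-by-one explanation above can be checked with a noise-free plaintext model of the pipeline encrypt_without_padding -> PBS with a padded accumulator -> decrypt_without_padding. This is only a sketch under assumptions: a 64-bit torus, carry modulus 1, and an accumulator that lays p boxes over the half torus [0, 2^63), as a padded LUT would. The function names are hypothetical and do not correspond to tfhe-rs internals.

```rust
// Noise-free plaintext model (hypothetical helpers, not tfhe-rs internals).
// Assumptions: 64-bit torus, message modulus p (power of two), carry
// modulus 1, padded accumulator with p boxes over [0, 2^63).

/// Scale used by en(de)crypt_without_padding in this model: 2^64 / p.
fn delta_no_pad(p: u64) -> u64 {
    ((1u64 << 63) / p) * 2
}

/// Scale assumed by the padded accumulator in this model: 2^63 / p.
fn delta_pad(p: u64) -> u64 {
    (1u64 << 63) / p
}

/// Round a torus value to the nearest multiple of delta, then reduce mod p.
fn decode(phase: u64, delta: u64, p: u64) -> u64 {
    (phase.wrapping_add(delta / 2) / delta) % p
}

/// The padded accumulator reads the phase at scale 2^63 / p, i.e. with
/// twice the resolution of the unpadded encoding.
fn pbs_with_padded_acc(phase: u64, p: u64, f: impl Fn(u64) -> u64) -> u64 {
    let box_index = phase / delta_pad(p); // in [0, 2p)
    if box_index < p {
        // first half of the torus: plain lookup
        f(box_index).wrapping_mul(delta_pad(p))
    } else {
        // negacyclic half: the result comes out negated
        f(box_index - p).wrapping_mul(delta_pad(p)).wrapping_neg()
    }
}

/// Whole pipeline on a cleartext x.
fn pipeline(x: u64, p: u64, f: impl Fn(u64) -> u64) -> u64 {
    let phase = (x % p).wrapping_mul(delta_no_pad(p)); // encrypt_without_padding
    let out = pbs_with_padded_acc(phase, p, f); // PBS with a padded LUT
    decode(out, delta_no_pad(p), p) // decrypt_without_padding
}
```

For p = 32, x = 3 and f(x) = x + 4: the input sits in box 2x = 6, the LUT returns f(6) = 10 at scale 2^58, and decoding at scale 2^59 yields 5 = 3 + 2, i.e. ((2x) + 4) / 2.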
So I think that changing the encoding step of en(de)crypt_without_padding would fix this issue, but I don't really know whether other code depends on these functions. An example covering the behavior of en(de)crypt_without_padding together with PBS would be helpful as well, I believe :)
Sorry, I forgot to answer here. My feeling is that we may be missing a generate_accumulator_without_padding that could be used in conjunction with encrypt/decrypt without padding; however, other primitives using the PBS assume the presence of a padding bit, so I'm not sure yet how the design should evolve.
Updating the issue with the latest progress on this, shared on Discord:
The following code works as expected (it reuses the accumulator generator from parmesan, found here: https://github.com/fakub/parmesan/blob/0f5c8fdff153134f727281112bd395ef4642a260/src/cloudovo/pbs.rs#L110):
```rust
use rayon::prelude::*;
use tfhe::core_crypto::entities::GlweCiphertext;
use tfhe::shortint::ciphertext::Degree;
use tfhe::shortint::parameters::*;
use tfhe::shortint::prelude::*;
use tfhe::shortint::server_key::LookupTableOwned;

/// LUT for inputs encrypted with encrypt_without_padding (no padding bit).
fn gen_no_padding_acc<F>(server_key: &ServerKey, f: F) -> LookupTableOwned
where
    F: Fn(u64) -> u64,
{
    let mut accumulator = GlweCiphertext::new(
        0u64,
        server_key.bootstrapping_key.glwe_size(),
        server_key.bootstrapping_key.polynomial_size(),
        server_key.key_switching_key.ciphertext_modulus(),
    );
    let mut accumulator_view = accumulator.as_mut_view();
    // The mask stays zero: the LUT is a trivial GLWE encryption
    accumulator_view.get_mut_mask().as_mut().fill(0);

    // Only half of the message space fits in the polynomial; the other half
    // is reached negacyclically, hence the division by two (as in parmesan)
    let modulus_sup = server_key.message_modulus.0 * server_key.carry_modulus.0 / 2;

    // N / (p/2) = size of each box
    let box_size = server_key.bootstrapping_key.polynomial_size().0 / modulus_sup;

    // Main change: delta is doubled (2^64 / p instead of 2^63 / p) because the
    // padding bit is now part of the message
    let delta =
        ((1u64 << 63) / (server_key.message_modulus.0 * server_key.carry_modulus.0) as u64) * 2;

    let mut body = accumulator_view.get_mut_body();
    let accumulator_u64 = body.as_mut();

    // Track the max value of the function to set the degree
    let mut max_value = 0;
    for i in 0..modulus_sup {
        // Evaluate f once per box; wrapping_mul reduces the output mod p
        // instead of panicking on overflow in debug builds
        let f_eval = f(i as u64);
        max_value = max_value.max(f_eval);
        let index = i * box_size;
        accumulator_u64[index..index + box_size].fill(f_eval.wrapping_mul(delta));
    }

    let half_box_size = box_size / 2;
    // Negate the first half box: an input near 0 with negative noise wraps
    // around during blind rotation and reads these coefficients negated
    for a_i in accumulator_u64[0..half_box_size].iter_mut() {
        *a_i = (*a_i).wrapping_neg();
    }
    // Rotate so that box i is centred on coefficient i * box_size
    accumulator_u64.rotate_left(half_box_size);

    LookupTableOwned {
        acc: accumulator,
        degree: Degree(max_value as usize),
    }
}

fn main() {
    println!("\n>>> DEMO PBS with TFHE-rs, function: x -> x + 4\n");

    // Custom params: shortint does not count the padding bit in the parameter
    // name, so PARAM_MESSAGE_4_CARRY_0 plus the padding bit gives 5 bits
    let custom_params = Parameters {
        message_modulus: MessageModulus(1 << 5),
        ..PARAM_MESSAGE_4_CARRY_0
    };

    // generate keys for a 5-bit message space
    let (client_key, server_key) = gen_keys(custom_params);

    // build the LUT from the function (nice to have: a generate_accumulator
    // that takes the LUT as a vector)
    let func = |x: u64| x + 4;
    // let acc = server_key.generate_accumulator(func); // assumes a padding bit
    let acc = gen_no_padding_acc(&server_key, func);

    // -------------------------------------------------------------------------
    // full loop with one extra value (32) to spot the wrap-around
    for i in 0..=32 {
        // encrypt into the full message domain (no padding bit)
        let c1 = client_key.encrypt_without_padding(i);
        // evaluate func
        let cf1 = server_key.apply_lookup_table(&c1, &acc);
        // decrypt (the former padding bit is used as a message bit) & print
        let m1 = client_key.decrypt_without_padding(&c1);
        let mf1 = client_key.decrypt_without_padding(&cf1);
        println!("(i = {:2}) decrypted IN -> OUT: {:2} -> {:2}", i, m1, mf1);
    }

    // -------------------------------------------------------------------------
    // the same thing in parallel
    println!("\n\n>>> Same thing in parallel\n");
    // encrypt into the full message domain
    let c_in: Vec<_> = (0..=32)
        .map(|i| client_key.encrypt_without_padding(i))
        .collect();
    // evaluate func on all inputs in parallel (order is preserved)
    let c_out: Vec<_> = c_in
        .par_iter()
        .map(|ci| server_key.apply_lookup_table(ci, &acc))
        .collect();
    for (i, (ci, co)) in c_in.iter().zip(c_out.iter()).enumerate() {
        // decrypt (the former padding bit is used as a message bit) & print
        let mi = client_key.decrypt_without_padding(ci);
        let mfi = client_key.decrypt_without_padding(co);
        println!("(i = {:2}) decrypted IN -> OUT: {:2} -> {:2}", i, mi, mfi);
    }
}
```
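To see why this workaround yields f(x) on the lower half and a negated copy on the upper half, here is a noise-free plaintext model of the same accumulator layout (p/2 boxes, doubled delta, first half box negated, rotation by half a box) and of how blind rotation reads it modulo X^N + 1. It is a sketch under assumptions: only the GLWE body is modelled, the modulus-switched phase is exact, the helper names are hypothetical, and the polynomial size 1024 used in the example is arbitrary rather than the one of PARAM_MESSAGE_4_CARRY_0.

```rust
// Noise-free plaintext model (hypothetical helpers, not tfhe-rs API).
// Assumptions: body only, polynomial size n, message modulus p (powers of
// two, p/2 divides n), exact modulus switch of the phase to 2n.

/// Same layout as gen_no_padding_acc, on a plain vector of coefficients.
fn no_padding_acc(n: usize, p: u64, f: impl Fn(u64) -> u64) -> Vec<u64> {
    let boxes = (p / 2) as usize;
    let box_size = n / boxes;
    let delta = ((1u64 << 63) / p) * 2; // 2^64 / p
    let mut acc = vec![0u64; n];
    for i in 0..boxes {
        acc[i * box_size..(i + 1) * box_size].fill(f(i as u64).wrapping_mul(delta));
    }
    let half = box_size / 2;
    for a in acc[..half].iter_mut() {
        *a = (*a).wrapping_neg();
    }
    acc.rotate_left(half);
    acc
}

/// Constant coefficient of X^(-k) * acc(X) mod (X^n + 1), for k in [0, 2n).
fn rotate_read(acc: &[u64], k: usize) -> u64 {
    let n = acc.len();
    if k < n {
        acc[k]
    } else {
        acc[k - n].wrapping_neg() // X^n = -1: the negacyclic wrap
    }
}

/// Encode x without padding, switch to modulus 2n, rotate, decode.
fn eval_no_padding(acc: &[u64], p: u64, x: u64) -> u64 {
    let n = acc.len() as u64;
    let k = ((x % p) * 2 * n / p) as usize; // exact in this noise-free model
    let delta = ((1u64 << 63) / p) * 2;
    let out = rotate_read(acc, k);
    (out.wrapping_add(delta / 2) / delta) % p
}
```

For example, eval_no_padding(&no_padding_acc(1024, 32, |x| x + 4), 32, 16) gives 28 = -4 mod 32, and the inputs 31 and 32 give 13 and 4, matching the corresponding lines of the sample output.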
Sample output, with the negacyclic overlap starting at i = 16 (from there on, OUT = -f(i - 16) mod 32):
>>> DEMO PBS with TFHE-rs, function: x -> x + 4
(i = 0) decrypted IN -> OUT: 0 -> 4
(i = 1) decrypted IN -> OUT: 1 -> 5
(i = 2) decrypted IN -> OUT: 2 -> 6
(i = 3) decrypted IN -> OUT: 3 -> 7
(i = 4) decrypted IN -> OUT: 4 -> 8
(i = 5) decrypted IN -> OUT: 5 -> 9
(i = 6) decrypted IN -> OUT: 6 -> 10
(i = 7) decrypted IN -> OUT: 7 -> 11
(i = 8) decrypted IN -> OUT: 8 -> 12
(i = 9) decrypted IN -> OUT: 9 -> 13
(i = 10) decrypted IN -> OUT: 10 -> 14
(i = 11) decrypted IN -> OUT: 11 -> 15
(i = 12) decrypted IN -> OUT: 12 -> 16
(i = 13) decrypted IN -> OUT: 13 -> 17
(i = 14) decrypted IN -> OUT: 14 -> 18
(i = 15) decrypted IN -> OUT: 15 -> 19
(i = 16) decrypted IN -> OUT: 16 -> 28
(i = 17) decrypted IN -> OUT: 17 -> 27
(i = 18) decrypted IN -> OUT: 18 -> 26
(i = 19) decrypted IN -> OUT: 19 -> 25
(i = 20) decrypted IN -> OUT: 20 -> 24
(i = 21) decrypted IN -> OUT: 21 -> 23
(i = 22) decrypted IN -> OUT: 22 -> 22
(i = 23) decrypted IN -> OUT: 23 -> 21
(i = 24) decrypted IN -> OUT: 24 -> 20
(i = 25) decrypted IN -> OUT: 25 -> 19
(i = 26) decrypted IN -> OUT: 26 -> 18
(i = 27) decrypted IN -> OUT: 27 -> 17
(i = 28) decrypted IN -> OUT: 28 -> 16
(i = 29) decrypted IN -> OUT: 29 -> 15
(i = 30) decrypted IN -> OUT: 30 -> 14
(i = 31) decrypted IN -> OUT: 31 -> 13
(i = 32) decrypted IN -> OUT: 0 -> 4
>>> Same thing in parallel
(i = 0) decrypted IN -> OUT: 0 -> 4
(i = 1) decrypted IN -> OUT: 1 -> 5
(i = 2) decrypted IN -> OUT: 2 -> 6
(i = 3) decrypted IN -> OUT: 3 -> 7
(i = 4) decrypted IN -> OUT: 4 -> 8
(i = 5) decrypted IN -> OUT: 5 -> 9
(i = 6) decrypted IN -> OUT: 6 -> 10
(i = 7) decrypted IN -> OUT: 7 -> 11
(i = 8) decrypted IN -> OUT: 8 -> 12
(i = 9) decrypted IN -> OUT: 9 -> 13
(i = 10) decrypted IN -> OUT: 10 -> 14
(i = 11) decrypted IN -> OUT: 11 -> 15
(i = 12) decrypted IN -> OUT: 12 -> 16
(i = 13) decrypted IN -> OUT: 13 -> 17
(i = 14) decrypted IN -> OUT: 14 -> 18
(i = 15) decrypted IN -> OUT: 15 -> 19
(i = 16) decrypted IN -> OUT: 16 -> 28
(i = 17) decrypted IN -> OUT: 17 -> 27
(i = 18) decrypted IN -> OUT: 18 -> 26
(i = 19) decrypted IN -> OUT: 19 -> 25
(i = 20) decrypted IN -> OUT: 20 -> 24
(i = 21) decrypted IN -> OUT: 21 -> 23
(i = 22) decrypted IN -> OUT: 22 -> 22
(i = 23) decrypted IN -> OUT: 23 -> 21
(i = 24) decrypted IN -> OUT: 24 -> 20
(i = 25) decrypted IN -> OUT: 25 -> 19
(i = 26) decrypted IN -> OUT: 26 -> 18
(i = 27) decrypted IN -> OUT: 27 -> 17
(i = 28) decrypted IN -> OUT: 28 -> 16
(i = 29) decrypted IN -> OUT: 29 -> 15
(i = 30) decrypted IN -> OUT: 30 -> 14
(i = 31) decrypted IN -> OUT: 31 -> 13
(i = 32) decrypted IN -> OUT: 0 -> 4
Hello @fakub, is it OK to close this with the proposed workaround for now?
Sure, at least in my use case everything works fine with the workaround :)
Describe the bug
The combination of en(de)crypt_without_padding together with generate_accumulator gives incorrect results. E.g., with func = |x: u64| x + 4, the output after decryption shows that x + 2 has been evaluated instead.
My conjecture is that omitting padding with en(de)crypt_without_padding is not handled correctly in generate_accumulator: it seems to transform x -> 2x for the LUT evaluation (~ padding?) and back as y -> y/2, effectively evaluating ((2x) + 4) / 2 = x + 2, which is what I observe instead of the expected x + 4.
To Reproduce
Demonstrated in a toy example.
Evidence
Output of the toy example (parts omitted) with PARAM_MESSAGE_5_CARRY_0:
Configuration (please complete the following information):
Cargo.lock