OpenMined / TenSEAL

A library for doing homomorphic encryption operations on tensors
Apache License 2.0

Multiply a plaintext vector with an encrypted vector using CKKS #499

Open mun3 opened 1 month ago


Question

Does anyone know why using CKKS to multiply a plaintext vector with an encrypted vector gives the same result as multiplying two encrypted vectors?

Further Information

AFAIK, a plaintext vector cannot be multiplied with an encrypted vector using CKKS. But in the following code, both multiplications return the same result (up to a small approximation error). Does anyone know whether the plaintext vector is encrypted internally when the * operator is applied? If not, how can the two vectors be multiplied? Thanks.

import tenseal as ts

def setup_ckks_context():
    context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=8192,
        coeff_mod_bit_sizes=[60, 40, 40, 60],
    )
    context.global_scale = 2**40
    # Galois keys enable ciphertext rotations, which sum() relies on
    context.generate_galois_keys()
    return context

def encrypt_vector(context, vector):
    encrypted_vector = ts.ckks_vector(context, vector)
    return encrypted_vector

def create_mask_vector(length, isolate_position):
    mask_vector = [0] * length
    mask_vector[isolate_position] = 1
    return mask_vector

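def perform_operations(encrypted_vector, encrypted_mask):
    # Note: the first argument is a CKKS vector in the first call
    # and a plain Python list in the second call below.
    result_vector = encrypted_vector * encrypted_mask
    result = result_vector.sum()
    return result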

def decrypt_result(encrypted_vector):
    decrypted_result = encrypted_vector.decrypt()
    return decrypted_result

if __name__ == "__main__":
    original_vector = [1.2, 2.3, 3.4, 4.5, 5.6]
    isolate_position = 3

    context = setup_ckks_context()

    encrypted_vector = encrypt_vector(context, original_vector)

    mask_vector = create_mask_vector(len(original_vector), isolate_position)
    encrypted_mask = encrypt_vector(context, mask_vector)

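    # Case 1: encrypted vector * encrypted mask
    encrypted_result1 = perform_operations(encrypted_vector, encrypted_mask)

    # Case 2: plaintext list * encrypted mask (original_vector is never encrypted here)
    encrypted_result2 = perform_operations(original_vector, encrypted_mask)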

    decrypted_result1 = decrypt_result(encrypted_result1)
    print("Decrypted result (encrypted * encrypted):", decrypted_result1) # [4.500001868981295]

    decrypted_result2 = decrypt_result(encrypted_result2)
    print("Decrypted result (plaintext * encrypted):", decrypted_result2) # [4.50000184976374]
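
One possible reading of the second call, sketched in plain Python: when the left operand of `*` is a list and the right operand is an object that defines `__rmul__`, Python hands the operation to that object, which can then treat the list as a plaintext operand instead of encrypting it. The sketch below only illustrates this dispatch. `ToyEncryptedVector` and `last_op` are hypothetical names, no encryption is performed, and it is an assumption (not confirmed in this thread) that TenSEAL's CKKS vector handles a plain list along these lines.

```python
class ToyEncryptedVector:
    """Hypothetical stand-in for an encrypted vector; performs NO encryption."""

    def __init__(self, values):
        self.values = list(values)
        self.last_op = None  # records which multiplication path produced this object

    def __mul__(self, other):
        if isinstance(other, ToyEncryptedVector):
            operand, path = other.values, "ciphertext * ciphertext"
        else:
            # A plain list is used as-is; it is never wrapped in ToyEncryptedVector.
            operand, path = list(other), "ciphertext * plaintext"
        result = ToyEncryptedVector(a * b for a, b in zip(self.values, operand))
        result.last_op = path
        return result

    def __rmul__(self, other):
        # Python falls back to this reflected method for `plain_list * vector`,
        # since a list cannot multiply itself by a non-integer object.
        return self.__mul__(other)


values = [1.2, 2.3, 3.4, 4.5, 5.6]
mask = ToyEncryptedVector([0, 0, 0, 1, 0])

enc_times_enc = ToyEncryptedVector(values) * mask
plain_times_enc = values * mask  # dispatches to mask.__rmul__(values)

print(enc_times_enc.last_op, enc_times_enc.values)
print(plain_times_enc.last_op, plain_times_enc.values)
```

In this toy model both paths give the same element-wise product, but the second one never turns the list into a `ToyEncryptedVector`. CKKS itself does define multiplication of a ciphertext by an encoded (unencrypted) plaintext, so identical decrypted values up to approximation error would be consistent with the list being encoded rather than encrypted.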