DeltaML / federated-aggregator

Participant in Federated Learning scheme that aggregates the updates from the remote trainers and calculates each trainer's contribution to the model training.
MIT License

Activate MSE validation #16

Open GFibrizo opened 4 years ago

GFibrizo commented 4 years ago

Problem

The MSE validation was left deactivated when encryption was reactivated. The process for calculating the MSE has changed, so a refactor is needed.

Solution

Reactivate the validation, making the necessary refactors.

Alternatives for the solution

We are currently considering 2 schemes for validating the MSE.

  1. First alternative:

    • In the federated aggregator (FA), generate N random numbers between MIN and MAX, where MIN and MAX are the min and max values specified in the data requirements, and N is the length of the E(diff) vector.
    • For every i, compute E(noised_diff)[i] = E(diff)[i] + noise[i] (i.e. add the two vectors component by component).
    • The FA sends E(noised_diff) to the model buyer (MB).
    • The MB decrypts E(noised_diff), obtaining noised_diff. The model buyer can't calculate the MSE because they don't know the random values added to the vector.
    • The MB sends noised_diff to the FA.
    • The FA removes the random values from noised_diff (i.e. for every i, diff[i] = noised_diff[i] - noise[i]), obtaining diff.
    • The FA calculates the MSE using diff.
    • The FA sends the calculated MSE to the MB alongside E(diff).
    • The MB decrypts E(diff), calculates the MSE from diff, and validates that it matches the MSE received from the FA.
  2. Second alternative: @agrojas remembers this one better.
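
The steps of the first alternative can be sketched as follows. This is a minimal, self-contained illustration and not the project's code. It assumes a toy Paillier implementation with small fixed primes (insecure, for demonstration only) in place of the real encryption service, fixed-point integers instead of floats, and hypothetical values for MIN, MAX and diff. The names `run_first_alternative`, `add_plain`, `mse` and `SCALE` are made up for this example.

```python
import math
import random

# Toy Paillier cryptosystem (assumption: stands in for the real encryption
# service). Small fixed primes, NOT secure; for illustration only.
P, Q = 2**31 - 1, 2**61 - 1      # two Mersenne primes
N = P * Q                        # public modulus
N2 = N * N
PHI = (P - 1) * (Q - 1)          # private key
MU = pow(PHI, -1, N)             # modular inverse of PHI mod N (Python 3.8+)

SCALE = 1000                     # hypothetical fixed-point scale: 3 decimal places


def encrypt(m):
    """E(m) = (1 + m*N) * r^N mod N^2, with generator g = N + 1."""
    r = random.randrange(1, N)
    while math.gcd(r, N) != 1:
        r = random.randrange(1, N)
    return (1 + (m % N) * N) * pow(r, N, N2) % N2


def decrypt(c):
    """Recover m from E(m); values above N/2 are read as negative numbers."""
    m = (pow(c, PHI, N2) - 1) // N * MU % N
    return m - N if m > N // 2 else m


def add_plain(c, k):
    """Homomorphic addition of a plaintext: E(m) * g^k mod N^2 = E(m + k)."""
    return c * (1 + (k % N) * N) % N2


def mse(fixed_point_values):
    """Mean of the squared values, after undoing the fixed-point scaling."""
    return sum((v / SCALE) ** 2 for v in fixed_point_values) / len(fixed_point_values)


def run_first_alternative(diff, min_value=-10, max_value=10):
    # Starting point: the FA holds E(diff) but not the private key.
    enc_diff = [encrypt(round(d * SCALE)) for d in diff]

    # FA: draw one noise value per component and add it homomorphically.
    noise = [random.randint(min_value * SCALE, max_value * SCALE) for _ in enc_diff]
    enc_noised_diff = [add_plain(c, n) for c, n in zip(enc_diff, noise)]

    # MB: decrypt E(noised_diff); the result is still masked by the noise.
    noised_diff = [decrypt(c) for c in enc_noised_diff]

    # FA: remove the noise and compute the MSE.
    recovered = [nd - n for nd, n in zip(noised_diff, noise)]
    mse_fa = mse(recovered)

    # MB: decrypt E(diff), recompute the MSE, and compare with the FA's value.
    mse_mb = mse([decrypt(c) for c in enc_diff])
    return mse_fa, mse_mb


mse_fa, mse_mb = run_first_alternative([0.5, -1.25, 2.0])
print(mse_fa == mse_mb)
```

In the real system the FA and the MB run on different machines and only the MB holds the private key; here everything runs in one process just to show the data flow.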

agrojas commented 4 years ago

2° Alternative

image

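This solution doesn't work:

from phe import paillier

if __name__ == "__main__":
    value = 1
    # Two independent key pairs (PK = public key, pK = private key)
    PK1, pK1 = paillier.generate_paillier_keypair()
    PK2, pK2 = paillier.generate_paillier_keypair()

    x = PK1.encrypt(value)   # encrypt with key pair 1
    y = pK1.decrypt(x)       # back to plaintext
    z = PK2.encrypt(y)       # encrypt with key pair 2
    w = pK2.decrypt(z)       # back to plaintext
    X = PK1.encrypt(w)       # encrypt again with key pair 1
    print("Verify x==X")
    # Paillier encryption is randomized, so the two ciphertexts differ
    # even though both encrypt the same value.
    print(x.ciphertext() == X.ciphertext())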
GFibrizo commented 4 years ago

First alternative

This solution works. The current code is in the model_buyer source:

import numpy as np
from commons.encryption.encryption_service import EncryptionService

def test_mse_validation_scheme():
    encryption_service = EncryptionService(is_active=True)
    public_key, private_key = encryption_service.generate_key_pair(1024)
    encryption_service.set_public_key(public_key.n)

    # Actor 1 - First step - generating the serialized encrypted array with noise
    # Sending the array to Actor 2
    array = np.asarray([34.0, 56.0, 1.1, 0.56, 88.9, 0.112, 22.13])
    # randint works on integers, so make the truncation of the float bounds explicit
    noise = np.random.randint(int(array.min()), int(array.max()), array.size)
    encrypted_noised_array = encryption_service.encrypt_collection(array + noise)
    serialized_encrypted_noised_array = encryption_service.get_serialized_collection(encrypted_noised_array)

    # Actor 2 - Second step - deserializing and decrypting the array with noise received from Actor 1
    # Sending the array to Actor 1
    deserialized_encrypted_noised_array = encryption_service.get_deserialized_collection(serialized_encrypted_noised_array)
    deserialized_decrypted_noised_array = encryption_service.decrypt_collection(deserialized_encrypted_noised_array)

    # Actor 1 - Third step - Using the decrypted array, remove the noise, calculate the square of
    # each component and then the mean of the array
    # Send the value and the noise to Actor 2 for validation
    noised_array = np.asarray(deserialized_decrypted_noised_array)
    array2 = noised_array - noise
    value = np.mean(array2 ** 2)

    # Actor 2 - Fourth step - Subtract the noise from the decrypted array and repeat Actor 1's calculation
    # Validate the value sent by Actor 1 against the value calculated here
    noised_array2 = np.asarray(deserialized_decrypted_noised_array)
    array3 = noised_array2 - noise
    value2 = np.mean(array3 ** 2)

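    # Compare with a tolerance: adding and removing the noise can change the low-order bits
    assert np.isclose(value, value2)
    assert np.isclose(value, np.mean(array ** 2))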

test_mse_validation_scheme()
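
One caveat when validating the MSE: adding noise to a float and subtracting it again does not always return the original value bit for bit, so comparing with a tolerance is likely safer than comparing with ==. A minimal illustration, using a hypothetical helper `mask_and_unmask` and made-up values:

```python
import math


def mask_and_unmask(value, noise):
    """Add noise and subtract it again, as happens around the decryption step."""
    return (value + noise) - noise


recovered = mask_and_unmask(0.1, 1.0)
print(recovered == 0.1)              # False: the round trip changed the low-order bits
print(math.isclose(recovered, 0.1))  # True: equal within the default relative tolerance
```

Whether this matters in practice depends on the noise values drawn and on how the encryption layer encodes floats, so the tolerance should be chosen with the expected magnitude of the MSE in mind.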