zama-ai / concrete-ml

Concrete ML: Privacy Preserving ML framework built on top of Concrete, with bindings to traditional ML frameworks.

Adding encrypted training for other ML models and DL models #789

Open riemanli opened 2 weeks ago

riemanli commented 2 weeks ago

Feature request

The encrypted training documentation only mentions SGDClassifier. I would like to train other ML/DL models on encrypted data, and I would also love to contribute.

Here are some questions I have after reading some of the code related to encrypted training:

  1. What are the reasons that there is no encrypted training for other ML/DL models? Is it because there is some limitation in either concrete or concrete-ml that blocks this development? If so, what are those limitations?
  2. Some potential constraints I observed in the code that does encrypted learning for SGDClassifier:

     (1) Parameter range has to be preset
         * Is this inevitable because of overflow during FHE computation?

     (2) Floating point distribution of input has to be similar
         * Could you elaborate more on this?

     (3) Learning rate == 1
         * Does it mean we cannot have arbitrary learning rates?

     It would be much appreciated if you could explain their causes.

Motivation

I would like to contribute to encrypted training for other models.

Thanks a lot in advance.

jfrery commented 2 weeks ago

Hi @riemanli,

We would be very happy for you to contribute! Let me try to answer all your questions as best I can.

Before I start, let me explain how you can train models using torch together with our compile_torch_model from concrete-ml. This compile function takes everything within the forward of the torch model and converts it to its FHE equivalent. So, instead of writing only the inference part in the forward, we also added the backward part: instead of returning predictions, the model returns its updated parameters.

Here is the code in question -> https://github.com/zama-ai/concrete-ml/blob/main/src/concrete/ml/sklearn/_fhe_training_utils.py#L23-L97.

Then we build around this.
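To make the "forward returns updated parameters" idea concrete, here is a minimal plain-PyTorch sketch for logistic regression. LogisticRegressionTraining and its hand-written gradient formulas are illustrative assumptions, not the code in _fhe_training_utils.py; the sketch only runs in floating point and is not passed through compile_torch_model. learning_rate defaults to 1.0 only to mirror the constraint discussed in this thread.

```python
from typing import Tuple

import torch
from torch import nn


class LogisticRegressionTraining(nn.Module):
    """Hypothetical module whose forward() runs one SGD step instead of inference."""

    def __init__(self, learning_rate: float = 1.0) -> None:
        super().__init__()
        self.learning_rate = learning_rate

    def forward(
        self,
        features: torch.Tensor,  # (n, d)
        targets: torch.Tensor,   # (n, 1), values in {0, 1}
        weights: torch.Tensor,   # (d, 1)
        bias: torch.Tensor,      # (1, 1)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        n_samples = features.shape[0]
        probs = torch.sigmoid(features @ weights + bias)
        # Gradient of the mean binary cross-entropy, written out by hand (no autograd)
        error = probs - targets
        grad_weights = features.transpose(0, 1) @ error / n_samples
        grad_bias = error.sum(dim=0, keepdim=True) / n_samples
        # Return the updated parameters rather than predictions
        new_weights = weights - self.learning_rate * grad_weights
        new_bias = bias - self.learning_rate * grad_bias
        return new_weights, new_bias


# Plain floating-point usage: feed the returned parameters back in at each step
torch.manual_seed(0)
x = torch.randn(64, 2)
y = (x[:, :1] + x[:, 1:] > 0).float()
w, b = torch.zeros(2, 1), torch.zeros(1, 1)
trainer = LogisticRegressionTraining(learning_rate=1.0)
for _ in range(50):
    w, b = trainer(x, y, w, b)
accuracy = ((torch.sigmoid(x @ w + b) > 0.5).float() == y).float().mean().item()
print(f"train accuracy: {accuracy:.2f}")
```

The idea described above is to compile a module of this shape with compile_torch_model; the exact inputset and quantization settings Concrete ML uses are in the linked file and are not reproduced here.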

What are the reasons that there is no encrypted training for other ML/DL models? Is it because there is some limitation in either concrete or concrete-ml that blocks this development? If so, what are those limitations?

Internally, we did train a few-layer MLP following the method I described. It works to some extent. The main limitation is probably the limited precision of the gradients and weights. We currently have 6 bits of precision for the gradients, which is not ideal given that we use uniform quantization: as the model trains and converges, the gradients tend to become orders of magnitude smaller.
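A minimal pure-Python sketch of why that hurts, assuming a symmetric 6-bit uniform quantizer whose range is fixed from early-training gradients. The gradient values are made up, quantize_symmetric is a hypothetical helper, and Concrete ML's actual quantizers may differ (for example by using a zero point).

```python
N_BITS = 6
Q_MAX = 2 ** (N_BITS - 1) - 1  # 31: largest magnitude of a signed 6-bit integer


def quantize_symmetric(values, max_abs):
    """Map floats to signed 6-bit integers over [-max_abs, max_abs] with a uniform step."""
    scale = max_abs / Q_MAX
    return [max(-Q_MAX, min(Q_MAX, round(v / scale))) for v in values], scale


# Quantizer range calibrated on early-training gradients (illustrative values)
early_grads = [0.8, -0.5, 0.3, -0.9]
max_abs = max(abs(g) for g in early_grads)

# Near convergence, gradients are orders of magnitude smaller
late_grads = [0.004, -0.002, 0.01, -0.007]

q_early, scale = quantize_symmetric(early_grads, max_abs)
q_late, _ = quantize_symmetric(late_grads, max_abs)
print("step:", scale)   # about 0.029
print("early:", q_early)
print("late:", q_late)  # every late gradient rounds to 0
```

With a uniform grid, the step is set by the largest values seen, so gradients much smaller than one step carry no information at all.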

That being said, there are different ways to implement FHE training. We currently return the model parameters after each iteration, but we could certainly return the gradients instead and accumulate them, either in FHE or on the client's machine, which would give us much more precision on the gradients.
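A rough plain-Python comparison of the two options, assuming a symmetric 6-bit weight grid over a hypothetical range [-1, 1], a constant made-up gradient, and a client that can keep its own float copy of the weights. to_grid is an illustrative helper, not part of Concrete ML.

```python
Q_MAX = 31             # signed 6-bit weights
W_SCALE = 1.0 / Q_MAX  # hypothetical predefined weight range [-1, 1]


def to_grid(w):
    """Round a weight onto the 6-bit grid, clipping to the predefined range."""
    return max(-Q_MAX, min(Q_MAX, round(w / W_SCALE))) * W_SCALE


lr, grad, n_steps = 0.01, 0.5, 20  # illustrative constant gradient

# (a) The circuit returns updated parameters, requantized to 6 bits every step
w_params = to_grid(0.3)
for _ in range(n_steps):
    w_params = to_grid(w_params - lr * grad)

# (b) The circuit returns gradients; the client accumulates updates in float
# and only quantizes the weights when feeding them to the next encrypted step
w_client = to_grid(0.3)
for _ in range(n_steps):
    w_client -= lr * grad

print("(a) weight step index:", round(w_params / W_SCALE))
print("(b) weight step index:", round(to_grid(w_client) / W_SCALE))
```

With these numbers, option (a) never leaves its starting step (index 9), while option (b) moves the quantized weight down to index 6 after 20 updates, because the small increments add up in float before they are rounded.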

Some potential constraints I observed from the code that does encrypted learning on SGDClassifier:

(1) Parameter range has to be preset

  • Is this inevitable due to overflowing during FHE computation?

Yes. This is a counterintuitive parameter to have, but we have not found a better way for now. Our precision is very limited, so the range has to be as tight as possible around the values the parameters can actually take.
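The trade-off behind a tight preset range can be seen with a symmetric 6-bit quantizer. This is an assumption for illustration; quantize_param and the two candidate ranges are hypothetical.

```python
Q_MAX = 31  # signed 6-bit parameters


def quantize_param(w, param_range):
    """Symmetric 6-bit quantization of w over [-param_range, param_range]."""
    scale = param_range / Q_MAX
    return max(-Q_MAX, min(Q_MAX, round(w / scale))) * scale


for param_range in (1.0, 8.0):
    print(
        f"range +/-{param_range}: step={param_range / Q_MAX:.3f}, "
        f"0.05 -> {quantize_param(0.05, param_range):.3f}, "
        f"1.5 -> {quantize_param(1.5, param_range):.3f}"
    )
```

A [-1, 1] range resolves 0.05 (to about 0.065) but clips 1.5 to 1.0, whereas a [-8, 8] range covers 1.5 but rounds 0.05 to zero. With only 64 levels, the range has to match the values the parameters really take.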

(2) Floating point distribution of input has to be similar

  • Could you elaborate more on this?

I suppose you are referring to the fine-tuning part, where we reuse the same FHE training circuit to learn a different dataset.

Here we just meant that the min/max values of the inputs should basically be the same. If they are not, there will be a loss of information (because we clip to the previous min/max) or of precision (because we would not effectively use the available quantizer precision).
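Both failure modes can be illustrated with an unsigned 6-bit input quantizer calibrated on a hypothetical original range [0, 10]. quantize_input and the datasets are made up for this sketch.

```python
N_BITS = 6
N_LEVELS = 2 ** N_BITS  # 64 unsigned levels: 0 .. 63

# min/max recorded when the training circuit was built (hypothetical values)
CALIB_MIN, CALIB_MAX = 0.0, 10.0
SCALE = (CALIB_MAX - CALIB_MIN) / (N_LEVELS - 1)


def quantize_input(x):
    """Unsigned uniform quantization with the original calibration range."""
    x_clipped = min(max(x, CALIB_MIN), CALIB_MAX)
    return round((x_clipped - CALIB_MIN) / SCALE)


# New dataset with a wider range: everything above 10 is clipped (information loss)
print(quantize_input(15.0), quantize_input(20.0))

# New dataset with a narrower range [0, 2]: only a fraction of the 64 levels is used
narrow = [i * 0.1 for i in range(21)]
levels_used = len({quantize_input(x) for x in narrow})
print(f"{levels_used} of {N_LEVELS} levels used")
```

Inputs of 15 and 20 both map to level 63 and become indistinguishable, and data confined to [0, 2] uses only 14 of the 64 levels.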

(3) Learning rate == 1

  • Does it mean we cannot have arbitrary learning rates?

The issue with learning rates is related to how numbers are represented and manipulated. Let me try to explain this clearly.

First, when we do an addition or subtraction, both numbers must use the same level of precision (the same quantization step). If they don't, the less precise number is "upscaled" to match the more precise one. The learning rate is applied to the gradient by multiplication. At this point, everything is fine.
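A toy model of that scale alignment, assuming values are stored as (integer, scale) pairs with power-of-two scales. align_and_add is a hypothetical helper and the numbers are illustrative; this is a simplified picture, not Concrete ML internals.

```python
def align_and_add(q_a, scale_a, q_b, scale_b):
    """Add two quantized values (integer, scale) after bringing them to a common scale.

    Assumes power-of-two scales, so the coarser scale is an integer multiple of the finer one.
    """
    if scale_a >= scale_b:
        factor = round(scale_a / scale_b)
        return q_a * factor + q_b, scale_b  # "upscale" the integer of a
    factor = round(scale_b / scale_a)
    return q_a + q_b * factor, scale_a


# Weight 0.75 and gradient 0.5, both stored with step 2**-3
q_w, s_w = 6, 2 ** -3
q_g, s_g = 4, 2 ** -3

# Multiplying by learning rate 2**-3 only changes the scale: lr*g has step 2**-6
q_lr_g, s_lr_g = q_g, s_g * 2 ** -3

# w - lr*g: the weight integer is multiplied by 8 to match the finer scale
q_new, s_new = align_and_add(q_w, s_w, -q_lr_g, s_lr_g)
print(q_new, s_new, q_new * s_new)  # 44 0.015625 0.6875
```

The result is exact, but it lives on the finer 2**-6 grid; the trouble only starts when it has to be squeezed back onto the weights' coarser grid.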

The issue occurs when updating the weights by subtracting the gradient. If the learning rate is small, the gradients become much more precise than the weights, which forces the weights to be upscaled to the gradients' precision. Since the weights are 6 bits and their range is predefined, with a small learning rate the change to the weights becomes smaller than one "step" of the quantized representation (the input of the circuit). The parameter quantizer essentially rounds this tiny change to zero, resulting in no actual update to the weights.
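The rounding-to-zero effect can be reproduced in a few lines of plain Python. This is a simplified model, assuming a symmetric 6-bit weight grid over a hypothetical predefined range [-1, 1] and made-up weight and gradient values; requantize_weight and sgd_step are illustrative helpers, not Concrete ML functions.

```python
Q_MAX = 31             # signed 6-bit weights
W_SCALE = 1.0 / Q_MAX  # one weight "step" for a hypothetical predefined range [-1, 1]


def requantize_weight(w):
    """Round onto the 6-bit weight grid, clipping to the predefined range."""
    return max(-Q_MAX, min(Q_MAX, round(w / W_SCALE))) * W_SCALE


def sgd_step(w, grad, lr):
    """One update whose result is returned on the weight grid, as the circuit output is."""
    updated = w - lr * grad  # computed at the finer scale of lr * grad
    return requantize_weight(updated)


w0 = requantize_weight(0.3)
# An update only survives if |lr * grad| is large enough to cross a rounding boundary
print(sgd_step(w0, grad=0.2, lr=1.0) - w0)   # moves the weight down by 6 steps
print(sgd_step(w0, grad=0.2, lr=0.01) - w0)  # 0.0: the update is lost
```

With lr=0.01 the change is 0.002, far below half a step (about 0.016), so the requantized weight is identical to the input weight.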

So the way we currently do this cannot support small learning rates. But as I said above, the method can be updated, e.g. by accumulating gradients, to mitigate this limitation.

riemanli commented 1 week ago

Thanks so much for such a detailed and clear explanation, @jfrery! The information is great and sufficient for me to start something. I will reach out again when I encounter issues or questions during development.

riemanli commented 1 day ago

Hi @jfrery, what's the best way to add a backward pass for an NN? I don't want to reinvent the wheel for the backward pass, so I tried to embed the torch optimizer in the forward, but it seems the Concrete compiler doesn't support that.

My experiment code:

```python
import torch
from torch import nn

# TinyCNN (the model being trained) is defined elsewhere in the experiment.


class TinyCNNTrainer(nn.Module):

    def __init__(self, n_classes) -> None:
        super().__init__()
        self.trainee = TinyCNN(n_classes)
        self.loss = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.trainee.parameters())

    def forward(self, x, y):
        self.trainee.train()

        output = self.trainee(x)

        # The update below runs after `output` is computed and does not change it.
        loss = self.loss(output, y.reshape(-1).long())
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        return output
```

jfrery commented 1 day ago

Hi @riemanli,

You've got the right idea on how to do this, but there are a few problems:

When you use compile_torch_model, we trace the computation from the inputs to the outputs and create an FHE circuit that reproduces the outputs. Anything that is not on the path from the inputs to the outputs is simply discarded.

If you want this to work, you will need to provide the parameters of the model as inputs to the forward and return either the gradients or the updated parameters.

This is what our logistic regression training forward looks like:

    def forward(
        self,
        features: torch.Tensor,
        targets: torch.Tensor,
        weights: torch.Tensor,
        bias: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
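One way to apply this to an NN, as a minimal sketch rather than Concrete ML code: replace loss.backward() and the torch optimizer with explicit gradient formulas, and pass the weights in and out of forward. TinyMLPTrainer and its one-hidden-layer, ReLU, cross-entropy setup are hypothetical stand-ins for the TinyCNN in the experiment. Whether every op here (e.g. softmax, or the comparison used for the ReLU derivative) is supported by the Concrete ML compiler, and how it behaves at low bit-width, has not been checked.

```python
from typing import Tuple

import torch
from torch import nn


class TinyMLPTrainer(nn.Module):
    """Hypothetical: one SGD step of a one-hidden-layer MLP, parameters passed as inputs."""

    def __init__(self, learning_rate: float = 1.0) -> None:
        super().__init__()
        self.learning_rate = learning_rate

    def forward(
        self,
        x: torch.Tensor,         # (n, d_in)
        y_onehot: torch.Tensor,  # (n, n_classes)
        w1: torch.Tensor,        # (d_in, d_hidden)
        b1: torch.Tensor,        # (1, d_hidden)
        w2: torch.Tensor,        # (d_hidden, n_classes)
        b2: torch.Tensor,        # (1, n_classes)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        n_samples = x.shape[0]
        # Forward pass
        h_pre = x @ w1 + b1
        h = torch.relu(h_pre)
        probs = torch.softmax(h @ w2 + b2, dim=1)
        # Backward pass written out by hand: no loss.backward(), no torch.optim
        d_logits = (probs - y_onehot) / n_samples  # gradient of the mean cross-entropy
        grad_w2 = h.transpose(0, 1) @ d_logits
        grad_b2 = d_logits.sum(dim=0, keepdim=True)
        d_h_pre = (d_logits @ w2.transpose(0, 1)) * (h_pre > 0).float()  # ReLU derivative
        grad_w1 = x.transpose(0, 1) @ d_h_pre
        grad_b1 = d_h_pre.sum(dim=0, keepdim=True)
        # Return updated parameters (returning the gradients instead would also work)
        lr = self.learning_rate
        return w1 - lr * grad_w1, b1 - lr * grad_b1, w2 - lr * grad_w2, b2 - lr * grad_b2
```

Each call returns the four updated tensors, which are fed back in on the next call. In a float run, the updates match what autograd plus plain SGD would produce for the mean cross-entropy loss.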

I am curious to see what problem you encounter.