facebookresearch / CrypTen

A framework for Privacy Preserving Machine Learning
MIT License

Numerical overflow during inference #376

Open AbbyLuHui opened 2 years ago

AbbyLuHui commented 2 years ago

Hi, I am trying to import a pretrained PyTorch MLP model into CrypTen. However, when I convert output_enc to plaintext, the result looks like numerical overflow, on both CPU and GPU. What might cause this? Thank you!

```python
import torch
import crypten

# Net, pretrain_dir and filenames are defined elsewhere in the script.
model = Net()

# Load only the pretrained tensors whose names match this model.
pretrained_dict = torch.load(pretrain_dir)
model_dict = model.state_dict()
model_pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(model_pretrained_dict)
model.load_state_dict(model_dict)

# Convert the PyTorch model to a CrypTen model using a 1 x 256 dummy input.
dummy_input = torch.empty(1, 256)
private_model = crypten.nn.from_pytorch(model, dummy_input)
private_model.eval()

# Each of the five parties contributes one block of features.
feat1 = crypten.load_from_party(filenames[0], src=0)
feat2 = crypten.load_from_party(filenames[1], src=1)
feat3 = crypten.load_from_party(filenames[2], src=2)
feat4 = crypten.load_from_party(filenames[3], src=3)
feat5 = crypten.load_from_party(filenames[4], src=4)
private_model.encrypt()

x_test = crypten.cat([feat1, feat2, feat3, feat4, feat5[:, 0]], dim=1)
output_enc = private_model(x_test)
output = output_enc.get_plain_text()
crypten.print(output)  # numerical overflow occurs here for output.
```

knottb commented 2 years ago

There are a few things I'd need to know to diagnose this:

  1. What kind of model is Net()?
  2. What are the ranges of values in x_test and model.parameters()?
  3. Does x_test.get_plain_text() return what you expect?
  4. How many processes (or parties) are being used during runtime?

Depending on your answers, I may have some follow-ups.

AbbyLuHui commented 2 years ago

Thanks so much for the reply! Here is some additional information.

  1. Net() is a basic MLP with 4 Linear layers and 3 ReLU activations in between. The largest weight matrix is 256 x 128, although I may increase the number of parameters in the future.
  2. Both are small float32 values. According to torch.max and torch.min, the values in both x_test and model.parameters() range from about -1.9 to 1.9.
  3. Yes! x_test.get_plain_text() and model.decrypt() give the expected values.
  4. Currently it uses 5-party computation (5 processes).

Looking forward to the follow-ups. :)
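Plugging these ranges into a rough headroom calculation suggests that plain value overflow is unlikely for this model. The sketch below assumes CrypTen-style defaults of 16 fractional bits in a signed 64-bit ring, and assumes that a product of two encoded values is held at doubled scale until it is truncated. It also assumes the second layer takes the 128 outputs of the 256 x 128 layer; the helper names are hypothetical and bias terms are ignored.

```python
# Rough headroom check under assumed CrypTen-style fixed-point defaults.
PRECISION_BITS = 16  # assumed default of cfg.encoder.precision_bits
RING_BITS = 64       # assumed signed 64-bit integer ring


def max_abs_plaintext(scale_bits: int) -> int:
    """Largest plaintext magnitude that fits in the signed ring at scale 2**scale_bits."""
    return 2 ** (RING_BITS - 1 - scale_bits)


def linear_worst_case(max_abs_input: float, max_abs_weight: float, fan_in: int) -> float:
    """Worst-case |sum_i x_i * w_i| for one output of a linear layer (bias ignored)."""
    return fan_in * max_abs_input * max_abs_weight


# Encoded inputs and weights carry one factor of the scale...
stored_limit = max_abs_plaintext(PRECISION_BITS)       # 2**47
# ...while a product of two encoded values carries the scale squared until truncation.
product_limit = max_abs_plaintext(2 * PRECISION_BITS)  # 2**31

# First layer: 256 inputs in [-1.9, 1.9], weights in [-1.9, 1.9].
layer1 = linear_worst_case(1.9, 1.9, fan_in=256)
# Second layer (assumed fan-in 128): inputs bounded by the worst case above.
layer2 = linear_worst_case(layer1, 1.9, fan_in=128)

print(f"stored limit  ~ {stored_limit:.3g}")
print(f"product limit ~ {product_limit:.3g}")
print(f"layer 1 worst case ~ {layer1:.3g}, layer 2 worst case ~ {layer2:.3g}")
print("both fit:", max(layer1, layer2) < product_limit)
```

Even the second-layer worst case sits several orders of magnitude below the assumed limit. If the model grows, rerunning the same arithmetic with the new widths gives a quick sanity check.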

knottb commented 2 years ago

Thanks for the information. A few more follow-ups to try to troubleshoot.

  1. Does this work if you reduce the world size to 1? What about 2, 3, or 4? (Change the src arguments accordingly.)
  2. Have you changed the default value of cfg.encoder.precision_bits, cfg.mpc.provider or cfg.mpc.protocol?
  3. Other possible sources of outputs that look like numerical overflow:
     a. Gradient explosion: this is inference mode, so this shouldn't occur.
     b. Approximations: MLP + ReLU should not involve any approximations.
     c. Value overflow: unlikely, since your inputs and weights are bounded.
     d. Truncation overflow: should happen with very low probability unless our code is broken.
     e. Share misalignment: this is often the culprit when loading from multiple sources. Since x_test.get_plain_text() returns correctly, this would only occur if the model weights were misaligned, which seems unlikely with this code. (To check, you could verify that torch.nn.utils.parameters_to_vector(private_model.parameters()).get_plain_text() == torch.nn.utils.parameters_to_vector(model.parameters()).)

If these do not solve the problem, it is likely a bug in CrypTen code which we should try to identify.
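To see why a truncation failure (d) shows up as a huge value rather than a small error, here is a toy two-party model in plain Python. It assumes additive sharing over a signed 64-bit ring, 16 fractional bits, and SecureML-style local truncation in which each party shifts only its own share. This is a simplified illustration, not CrypTen's actual truncation protocol, and all helper names are hypothetical.

```python
from __future__ import annotations

import random

MOD = 2 ** 64                # assumed ring size
PRECISION_BITS = 16          # assumed fractional bits
SCALE = 2 ** PRECISION_BITS


def to_signed(v: int) -> int:
    """Map a ring element to the signed range [-2**63, 2**63)."""
    v %= MOD
    return v - MOD if v >= MOD // 2 else v


def share(x: int, rng: random.Random) -> tuple[int, int]:
    """Split ring element x into two additive shares."""
    s0 = to_signed(rng.randrange(MOD))
    return s0, to_signed(x - s0)


def reconstruct(s0: int, s1: int) -> int:
    return to_signed(s0 + s1)


def local_truncate(s0: int, s1: int, bits: int) -> tuple[int, int]:
    """Each party floor-divides its own share by 2**bits, with no communication."""
    return s0 >> bits, s1 >> bits


def decode(v: int) -> float:
    return v / SCALE


# A product of two fixed-point values sits at scale SCALE**2; truncating by
# PRECISION_BITS should bring it back to scale SCALE. Here the product is 3.0.
product = 3 * SCALE * SCALE

# Typical case: a random sharing, which almost never wraps around the ring.
good = decode(reconstruct(*local_truncate(*share(product, random.Random(0)), PRECISION_BITS)))
print(good)  # within one fixed-point step of 3.0

# Rare case, forced: a valid sharing whose plain integer sum wraps around the ring.
s0 = -(MOD // 2)
s1 = to_signed(product - s0)
assert reconstruct(s0, s1) == product  # still a correct sharing of the product
bad = decode(reconstruct(*local_truncate(s0, s1, PRECISION_BITS)))
print(bad)  # -4294967293.0, i.e. 3 - 2**32
```

In this model the bad case requires the integer sum of the two shares to wrap around the ring, which for a uniformly random share happens with probability of roughly |x| / 2^64, so it should be rare for bounded values.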

AbbyLuHui commented 2 years ago

Thanks a lot for the detailed comments. Here is more troubleshooting information.

  1. The code works with world_size = 1. However, it does not work with world_size = 2, 3, 4, or 5.
  2. Nope, I have not changed the defaults. Is there anything I could try?
  3. I agree that (a) to (d) seem unlikely. For (e), torch.nn.utils.parameters_to_vector(private_model.parameters()).get_plain_text() == torch.nn.utils.parameters_to_vector(model.parameters()) returned mostly False. The values differ only slightly, by roughly 1e-5 to 1e-7. The same happens in the world_size = 1 case, which works fine.
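Differences of 1e-5 to 1e-7 are about the size expected from fixed-point encoding alone: with 16 fractional bits (assumed here to be the default cfg.encoder.precision_bits), one step is 2^-16 ≈ 1.5e-5, so an exact == comparison will mostly return False even when the weights are aligned. The sketch below assumes an encoder that truncates toward zero (a rounding encoder would halve the bound) and uses random stand-in weights; in PyTorch the analogous check would be torch.allclose with atol=2**-16.

```python
import math
import random

PRECISION_BITS = 16          # assumed default cfg.encoder.precision_bits
SCALE = 2 ** PRECISION_BITS


def encode(x: float) -> int:
    """Fixed-point encode, truncating toward zero (assumed encoder behavior)."""
    return int(x * SCALE)


def decode(v: int) -> float:
    return v / SCALE


rng = random.Random(0)
weights = [rng.uniform(-1.9, 1.9) for _ in range(10_000)]  # stand-in for real weights
errors = [abs(decode(encode(w)) - w) for w in weights]

print(f"max round-trip error: {max(errors):.2e}, one step: {1 / SCALE:.2e}")
print("exact matches:", sum(decode(encode(w)) == w for w in weights))

# A tolerance of one fixed-point step accepts every weight; exact equality does not.
assert all(math.isclose(decode(encode(w)), w, rel_tol=0.0, abs_tol=1 / SCALE) for w in weights)
```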

For multiprocessing, I am using a script similar to https://github.com/facebookresearch/CrypTen/blob/f4cbdfc685d9064f45a5654dee9f3809f6d93e7f/examples/multiprocess_launcher.py
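When each party runs in its own process, every party must hold its share of the same sharing of a tensor. The toy example below (plain Python, assuming additive sharing over a 64-bit ring with 16 fractional bits; all names are hypothetical) shows that swapping in one share from a different sharing does not produce a small error but an essentially random value, which is why share misalignment (e) can look like overflow.

```python
from __future__ import annotations

import random

MOD = 2 ** 64                # assumed ring size
SCALE = 2 ** 16              # assumed fixed-point scale


def to_signed(v: int) -> int:
    """Map a ring element to the signed range [-2**63, 2**63)."""
    v %= MOD
    return v - MOD if v >= MOD // 2 else v


def share_among(x: int, n: int, rng: random.Random) -> list[int]:
    """Additively share ring element x among n parties."""
    shares = [rng.randrange(MOD) for _ in range(n - 1)]
    shares.append((x - sum(shares)) % MOD)
    return shares


def reconstruct(shares: list[int]) -> int:
    return to_signed(sum(shares))


w = round(0.75 * SCALE)  # encoded weight 0.75
n_parties = 5

aligned = share_among(w, n_parties, random.Random(1))
print(reconstruct(aligned) / SCALE)  # 0.75

# Party 0 ends up holding its share from a *different* sharing of the same value.
other = share_among(w, n_parties, random.Random(2))
misaligned = [other[0]] + aligned[1:]
print(reconstruct(misaligned) / SCALE)  # an essentially random, typically huge value
```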