facebookresearch / CrypTen

A framework for Privacy Preserving Machine Learning
MIT License

Loading a CrypTen saved model. #370

Open BakingBrains opened 2 years ago

BakingBrains commented 2 years ago

Can I load a model saved using CrypTen without the model architecture?

I loaded a state dict file and then saved that model using CrypTen, but when loading it back with crypten.load(), it asks for the model architecture.

Any suggestions?

Thank You

knottb commented 2 years ago

The answer differs depending on whether you are using crypten.load() or crypten.load_from_party().

My impression is that you shouldn't need an architecture if you are using crypten.load(), since the code is quite minimal (it just loads whatever was saved). See the code here.

However, if you use crypten.load_from_party(), you need to supply a class for the loaded type. There are two reasons for this. First, the parties receiving data need to know what data structure to populate - only values are communicated, not structure. Second, for security reasons, the receiving parties keep a whitelist of classes that they will accept from other parties. Accepting an unknown class type would expose a vulnerability in this code.
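The first point can be illustrated without CrypTen: only the values travel between parties, so the receiving side must already know the container in order to rebuild the object. A toy sketch in plain Python (the helper names are hypothetical, not CrypTen APIs; in CrypTen the structure comes from the class you pass to crypten.load_from_party()):

```python
from collections import OrderedDict

def flatten(state_dict):
    # Sender side: strip the structure, transmit ordered values only.
    return list(state_dict.values())

def rebuild(values, template_keys):
    # Receiver side: the structure (the keys) must be known out-of-band.
    return OrderedDict(zip(template_keys, values))

sd = OrderedDict([("conv1.weight", [1.0, 2.0]), ("conv1.bias", [0.5])])
restored = rebuild(flatten(sd), list(sd.keys()))
assert restored == sd  # values alone are useless without the key template
```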

Let me know if this helps.

BakingBrains commented 2 years ago

Thank you @knottb. What I am doing is:

checkpoint = torch.load('path/to/checkpoint')
model = MyModel(num_classes=num_class)
model.load_state_dict(checkpoint)
model.cuda()
model.eval()

dummy_input = torch.empty((1, 3, 512, 640))
dummy_input = dummy_input.cuda()

private_model = crypten.nn.from_pytorch(model, dummy_input)
private_model = private_model.cuda()
private_model.encrypt()

print("Model successfully encrypted:", private_model.encrypted)

data_file = "Encrypted.pth"
crypten.save(private_model.state_dict(), data_file)

If I do:

model = crypten.load('Encrypted5.pth')

This throws model class/architecture required error. Am I doing this right?

Any suggestions?

Thank you.

knottb commented 2 years ago

I can't seem to reproduce this on my end (though I don't have cuda enabled at the moment). What is the exact error message that you are getting?

Also, try saving from CPU rather than CUDA to see if that solves the issue.

Additionally, note that if you are using multiple parties, your file will be overwritten during crypten.save(). Your data_file string should contain a unique filepath for each party - use data_file = f"encrypted_{rank}.pth", where rank can be obtained via rank = crypten.communicator.get().get_rank().

BakingBrains commented 2 years ago

@knottb Error message:

File "D:/pycharmprojects/model_encoding/encoding/test_cnn/check_encryption.py", line 63, in <module>
    model = crypten.load('D:/model/test_cnn/Encrypted_0.pth')
File "D:\pycharmprojects\model_encoding\crypten\__init__.py", line 352, in load
    obj = load_closure(f)
File "D:\Environments\model_encoding\lib\site-packages\torch\serialization.py", line 594, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
File "D:\Environments\model_encoding\lib\site-packages\torch\serialization.py", line 853, in _load
    result = unpickler.load()
AttributeError: Can't get attribute 'MyConvNet' on <module '__main__' from 'D:/pycharmprojects/model_encoding/encoding/test_cnn/check_encryption.py'>

MyConvNet is the name of the model architecture class. The error occurs when I do: model = crypten.load('D:/model/test_cnn/Encrypted_0.pth')

This is how I am saving the model. (in CPU this time)

checkpoint = torch.load('D:/CNN/weights/best_checkpoint.pth', map_location='cpu')
model = MyConvNet(num_classes=num_class)
model.load_state_dict(checkpoint)
model.eval()

dummy_input = torch.empty((1, 3, 512, 640))
private_model = crypten.nn.from_pytorch(model, dummy_input)
private_model.encrypt()
print("Model successfully encrypted:", private_model.encrypted)

rank = crypten.communicator.get().get_rank()
data_file = f"D:/model/test_cnn/Encrypted_{rank}.pth"
crypten.save(private_model, data_file)
print("Encrypted model saved.")
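For context, the AttributeError in the traceback above comes from Python's pickle layer (torch.load unpickles under the hood, as the traceback shows) rather than from CrypTen itself: saving the whole private_model object pickles a reference to the MyConvNet class by name, so loading requires that class to be defined or importable in the loading script. A minimal stdlib-only sketch of the same failure mode (toy class, no CrypTen or torch involved):

```python
import pickle

# Hypothetical stand-in for the model class defined in the saving script.
class MyConvNet:
    def __init__(self, num_classes):
        self.num_classes = num_classes

# pickle stores a reference like "__main__.MyConvNet", not the class code.
blob = pickle.dumps(MyConvNet(num_classes=2))

# Simulate a loading script in which MyConvNet was never defined:
del MyConvNet

caught = False
msg = ""
try:
    pickle.loads(blob)
except AttributeError as err:
    # Same failure mode as the traceback above:
    # Can't get attribute 'MyConvNet' on <module '__main__' ...>
    caught = True
    msg = str(err)
print("caught:", caught)
```

This is why the class must be visible at load time when a whole object (rather than a plain state_dict) is saved.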