Status: Open. gillesoldano opened this issue 1 year ago.
Hi, I was able to solve my problem by saving the whole model instance rather than just its state_dict. Then a CUDA error appeared. As I understand it, GPU is not yet supported when using the run_multiprocess decorator, so I solved that by setting the map_location parameter of torch.load to 'cpu':
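import crypten
import crypten.mpc as mpc
import torch

ALICE, BOB = 0, 1  # party ranks, as in the CrypTen tutorials
input_shape = 224  # assumption: the input resolution is not given in the post

@mpc.run_multiprocess(world_size=2)
def main():
    # Load the whole pickled model onto the CPU (no GPU under run_multiprocess)
    model = torch.load('./resnet18.pth', map_location='cpu')
    # ResNet18 is the model class from the user's own code (not shown here)
    model = crypten.load_from_party(preloaded=model, src=ALICE, model_class=ResNet18)
    # ALICE converts and encrypts the model
    dummy_input = torch.empty(1, 3, input_shape, input_shape)
    private_model = crypten.nn.from_pytorch(model, dummy_input)
    private_model.encrypt(src=ALICE)
    # BOB loads and encrypts the data
    data_enc = crypten.load_from_party('./test_images.pth', src=BOB)
    private_model.eval()
    output_enc = private_model(data_enc)
    output = output_enc.get_plain_text()
    print(output)

if __name__ == "__main__":
    main()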
However, now I get this error: ValueError: Deserialization is restricted for pickled module torchvision.models.resnet.ResNet
Is there another step I have to perform in order to load the torchvision ResNet18?
I think what is happening here is that the model_class no longer matches the class of the model, due to changes in torchvision. Can you try running with model_class=torchvision.models.resnet.ResNet as the input argument?
The reason this input argument exists is to prevent malicious code injection via the unpickler. The unpickler will refuse to unpickle anything that is not a model_class (that is, code the user has declared trustworthy).
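The allowlist mechanism described above can be illustrated with Python's standard pickle module. This is a generic sketch, not CrypTen's actual implementation; the names SAFE_CLASSES, RestrictedUnpickler, and restricted_loads are hypothetical. It overrides pickle.Unpickler.find_class, which the unpickler calls for every class or function the stream references, and refuses anything that has not been registered.

```python
import collections
import io
import pickle

# Hypothetical allowlist: fully qualified name -> class object.
SAFE_CLASSES = {"collections.OrderedDict": collections.OrderedDict}


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals listed in SAFE_CLASSES."""

    def find_class(self, module, name):
        # Called for every class/function reference found in the pickle stream.
        classname = f"{module}.{name}"
        if classname not in SAFE_CLASSES:
            raise ValueError(
                f"Deserialization is restricted for pickled module {classname}"
            )
        return SAFE_CLASSES[classname]


def restricted_loads(data: bytes):
    """Deserialize data, refusing any class that is not on the allowlist."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


if __name__ == "__main__":
    allowed = pickle.dumps(collections.OrderedDict(a=1))
    print(restricted_loads(allowed))  # OrderedDict is registered, so this loads

    blocked = pickle.dumps(collections.Counter("aab"))
    try:
        restricted_loads(blocked)
    except ValueError as err:
        print(err)  # Counter is not registered, so loading is refused
```

Registering a class amounts to adding its fully qualified name to the allowlist, which is why every class in the pickled object graph (submodules, enums, transforms) has to be trusted before the load succeeds.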
Based on your comment, I was able to solve the problem by adding a few lines:
import crypten.common.serial
import torch
import torchvision
from torchvision.models.resnet import BasicBlock, ResNet

crypten.common.serial.register_safe_class(ResNet)
crypten.common.serial.register_safe_class(torch.nn.Sequential)
crypten.common.serial.register_safe_class(BasicBlock)
crypten.common.serial.register_safe_class(torch.nn.modules.pooling.AdaptiveAvgPool2d)
crypten.common.serial.register_safe_class(torchvision.transforms._presets.ImageClassification)
Then I got the error ValueError: Deserialization is restricted for pickled module torchvision.transforms.functional.InterpolationMode, so I tried to fix it by adding the line crypten.common.serial.register_safe_class(torchvision.transforms.functional.InterpolationMode). However, I keep getting the same error.
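One plausible explanation for why registering InterpolationMode has no effect (an assumption, not confirmed in this thread): if register_safe_class derives its registry key from str(cls), for example by taking the text between the first pair of single quotes, then an ordinary class yields its fully qualified name, but an Enum subclass such as InterpolationMode yields only its bare name, because Enum's metaclass renders classes as <enum 'Name'>. The unpickler looks classes up as module.name, so the two keys would never match. The sketch below reproduces this with stand-in classes; PlainClass, Mode, key_from_str, and key_from_qualname are hypothetical, and the str(cls)-based keying should be checked against the serial.py you have installed.

```python
import enum


class PlainClass:
    """Stand-in for an ordinary class such as torchvision's ResNet."""


class Mode(enum.Enum):
    """Stand-in for an Enum such as torchvision's InterpolationMode."""

    BILINEAR = "bilinear"


def key_from_str(cls: type) -> str:
    # Hypothetical: mirrors a registry that keys classes by str(cls).split("'")[1].
    return str(cls).split("'")[1]


def key_from_qualname(cls: type) -> str:
    # The name an unpickler's find_class sees: f"{module}.{name}".
    return f"{cls.__module__}.{cls.__qualname__}"


if __name__ == "__main__":
    for cls in (PlainClass, Mode):
        print(str(cls), key_from_str(cls), key_from_qualname(cls))
    # For PlainClass both keys agree; for Mode the str()-based key lacks the module prefix.
```

If this is the cause, registering the enum through a str()-keyed helper stores it under a key that find_class never queries, which would explain why the error does not change.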
Interesting, I'm not sure why that class would be different. Of course, you always have the nuclear option of removing these lines of code. Don't do this in production use cases though, as now you may be susceptible to malicious code injection via the checkpoints you are loading.
Even that does not work, because torchvision.transforms.functional.InterpolationMode is not in the __SAFE_CLASSES dict, so I get a KeyError at line 122 of serial.py.
I also tried using VGG instead of ResNet and got a different error, which I think is a compatibility issue between ONNX and PyTorch:
Traceback (most recent call last):
  File "/home/install/.pyenv/versions/3.9.9/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/install/.pyenv/versions/3.9.9/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/crypten/mpc/context.py", line 30, in _launch
    return_value = func(*func_args, **func_kwargs)
  File "/mnt/DATI/Documents/ISIN/fhe4cv/./src/cat_and_dog/cnn_classifier/crypten_test.py", line 34, in main
    private_model = crypten.nn.from_pytorch(model, dummy_input)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/crypten/nn/onnx_converter.py", line 57, in from_pytorch
    f = _from_pytorch_to_bytes(pytorch_model, dummy_input)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/crypten/nn/onnx_converter.py", line 128, in _from_pytorch_to_bytes
    f = _export_pytorch_model(f, pytorch_model, dummy_input)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/crypten/nn/onnx_converter.py", line 146, in _export_pytorch_model
    torch.onnx.export(pytorch_model, dummy_input, f, **kwargs)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/torch/onnx/utils.py", line 504, in export
    _export(
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/torch/onnx/utils.py", line 1654, in _export
    raise errors.CheckerError(e)
torch.onnx.errors.CheckerError: Unrecognized attribute: ratio for operator Dropout
I tried downgrading torch, torchvision and onnx to the lowest versions required by CrypTen. However, torchvision 0.9.1 is not compatible with torch 1.7.0, so I had to install torch 1.8.1, and then I got the same error again.
I was able to use resnet18 with CrypTen by training it in the same script, but I am still unable to save it, load it in another script, and use it with CrypTen.
Hi, I'm trying to convert a resnet18 model from PyTorch to CrypTen with the crypten.nn.from_pytorch() function, and I can't seem to make it work. The function returns this error:
I tried following Tutorial 4 and saw that at some point crypten.load_from_party() was used to load the model instead of torch.load(). When I use that function I get another error:
Which of those two functions should I use, and what is the difference? Could the problem be how I am generating the model?
I'm using torch 1.13.1 and crypten 0.4.1, and this is my model's summary: