facebookresearch / CrypTen

A framework for Privacy Preserving Machine Learning

Tutorial 4 - failed converting model #449

Open gillesoldano opened 1 year ago

gillesoldano commented 1 year ago

Hi, I'm trying to convert a resnet18 model from PyTorch to CrypTen with the crypten.nn.from_pytorch() function and I can't seem to make it work. The function returns this error:

Traceback (most recent call last):
  File "/mnt/DATI/Documents/ISIN/fhe4cv/./src/cat_and_dog/cnn_classifier/crypten_test.py", line 32, in <module>
    private_model = crypten.nn.from_pytorch(model, dummy_input)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/crypten/nn/onnx_converter.py", line 57, in from_pytorch
    f = _from_pytorch_to_bytes(pytorch_model, dummy_input)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/crypten/nn/onnx_converter.py", line 121, in _from_pytorch_to_bytes
    _export_pytorch_model(f, pytorch_model, dummy_input)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/crypten/nn/onnx_converter.py", line 146, in _export_pytorch_model
    torch.onnx.export(pytorch_model, dummy_input, f, **kwargs)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/torch/onnx/utils.py", line 504, in export
    _export(
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/torch/onnx/utils.py", line 1506, in _export
    with exporter_context(model, training, verbose):
  File "/home/install/.pyenv/versions/3.9.9/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/torch/onnx/utils.py", line 176, in exporter_context
    with select_model_mode_for_export(
  File "/home/install/.pyenv/versions/3.9.9/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/torch/onnx/utils.py", line 137, in disable_apex_o2_state_dict_hook
    for module in model.modules():
AttributeError: 'collections.OrderedDict' object has no attribute 'modules'

I tried following Tutorial 4 and saw that, at some point, crypten.load_from_party() was used instead of torch.load() to load the model. When using that function, I get another error:

Traceback (most recent call last):
  File "/mnt/DATI/Documents/ISIN/fhe4cv/./src/cat_and_dog/cnn_classifier/crypten_test.py", line 27, in <module>
    model = crypten.load_from_party('./resnet18.pth', model_class=ResNet18, src=ALICE)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/crypten/__init__.py", line 345, in load_from_party
    raise TypeError("Unrecognized load type %s" % type(result))
TypeError: Unrecognized load type <class 'int'>

Which of those two functions should I use, and what is the difference? Could the problem be in how I am generating the model?
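A possible reading of the first traceback, offered as a hedged aside: torch.save(model.state_dict(), ...) stores an OrderedDict, and torch.load() then returns that OrderedDict, which has no .modules() and cannot be handed to crypten.nn.from_pytorch; the converter expects an nn.Module. A minimal sketch of the distinction, assuming the checkpoint path above and the two-class resnet18 from the summary below (the wrapper class in the summary would need matching state_dict keys):

import torch
import torchvision

# Rebuild the architecture, then inspect what the checkpoint holds.
model = torchvision.models.resnet18(num_classes=2)
checkpoint = torch.load('./resnet18.pth', map_location='cpu')

if isinstance(checkpoint, dict):
    # torch.save(model.state_dict(), ...) was used: rebuild the module first
    model.load_state_dict(checkpoint)
else:
    # torch.save(model, ...) was used: the checkpoint is the module itself
    model = checkpoint

Only the resulting nn.Module should be passed to crypten.nn.from_pytorch.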

I'm using torch 1.13.1 and crypten 0.4.1; this is my model's summary:

ResNet18(
  (resnet18): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer2): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer3): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    (fc): Linear(in_features=512, out_features=2, bias=True)
  )
  (transforms): ImageClassification(
      crop_size=[224]
      resize_size=[256]
      mean=[0.485, 0.456, 0.406]
      std=[0.229, 0.224, 0.225]
      interpolation=InterpolationMode.BILINEAR
  )
)
gillesoldano commented 1 year ago

Hi, I was able to solve my problem by saving the whole model instance instead of just the state_dict. Then a CUDA error appeared; from my understanding, GPU is not yet supported when using the run_multiprocess decorator. That was solved by setting the map_location parameter to 'cpu' in the torch.load function:

import crypten
import crypten.mpc as mpc
import torch

ALICE, BOB = 0, 1   # party ranks, as in the CrypTen tutorials
input_shape = 224   # matches the transforms' crop_size in the summary above

@mpc.run_multiprocess(world_size=2)
def main():
    # ALICE loads the full pickled model on CPU and encrypts it;
    # ResNet18 is the wrapper class shown in the model summary above
    model = torch.load('./resnet18.pth', map_location='cpu')
    model = crypten.load_from_party(preloaded=model, src=ALICE, model_class=ResNet18)

    dummy_input = torch.empty(1, 3, input_shape, input_shape)
    private_model = crypten.nn.from_pytorch(model, dummy_input)
    private_model.encrypt(src=ALICE)

    # BOB loads and encrypts data
    data_enc = crypten.load_from_party('./test_images.pth', src=BOB)

    private_model.eval()
    output_enc = private_model(data_enc)

    output = output_enc.get_plain_text()
    print(output)

However, now I get this error: ValueError: Deserialization is restricted for pickled module torchvision.models.resnet.ResNet. Is there another step I have to perform in order to load the torchvision ResNet18?

lvdmaaten commented 1 year ago

I think what is happening here is that the model_class no longer matches the class of the model due to changes in torchvision. Can you try running with model_class=torchvision.models.resnet.ResNet as the input argument?

The reason this input argument exists is to prevent malicious code injection via the unpickler: the unpickler will refuse to unpickle anything that is not the model_class (that is, the code the user says can be trusted).
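For context, the restriction described here follows the standard restricted-unpickling pattern from the Python pickle docs. A minimal sketch of the idea, not CrypTen's actual serial.py:

import io
import pickle

_SAFE_CLASSES = {}  # "module.ClassName" -> class

def register_safe_class(cls):
    _SAFE_CLASSES[f"{cls.__module__}.{cls.__name__}"] = cls

class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # refuse to resolve any class that was not explicitly registered
        key = f"{module}.{name}"
        if key not in _SAFE_CLASSES:
            raise ValueError(f"Deserialization is restricted for pickled module {key}")
        return _SAFE_CLASSES[key]

def restricted_loads(data):
    return RestrictedUnpickler(io.BytesIO(data)).load()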

gillesoldano commented 1 year ago

Based on your comment, I was able to solve the problem by adding a few lines:

crypten.common.serial.register_safe_class(ResNet)
crypten.common.serial.register_safe_class(torch.nn.Sequential)
crypten.common.serial.register_safe_class(BasicBlock)
crypten.common.serial.register_safe_class(torch.nn.modules.pooling.AdaptiveAvgPool2d)
crypten.common.serial.register_safe_class(torchvision.transforms._presets.ImageClassification)

Then I got the error ValueError: Deserialization is restricted for pickled module torchvision.transforms.functional.InterpolationMode, so I tried to solve it by adding the line crypten.common.serial.register_safe_class(torchvision.transforms.functional.InterpolationMode). However, I keep getting the same error.

lvdmaaten commented 1 year ago

Interesting, I'm not sure why that class would be different. Of course, you always have the nuclear option of removing these lines of code. Don't do this in production use cases though, as now you may be susceptible to malicious code injection via the checkpoints you are loading.
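One possible explanation, offered as an assumption rather than a confirmed cause: pickle records a class under its defining module and qualified name, which can differ from the path it is imported (and registered) under when a class is re-exported. A quick diagnostic:

from torchvision.transforms.functional import InterpolationMode

# The module/name pair that pickle will actually reference:
print(InterpolationMode.__module__, InterpolationMode.__qualname__)

If the printed module differs from torchvision.transforms.functional, that mismatch could explain why the registration does not take effect.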

gillesoldano commented 1 year ago

Even that does not work, because the class torchvision.transforms.functional.InterpolationMode is not added to the __SAFE_CLASSES dict, so I get a KeyError at line 122 of serial.py. I tried using VGG instead of ResNet and got a different error, I think a compatibility issue between ONNX and PyTorch; it says:

Traceback (most recent call last):
  File "/home/install/.pyenv/versions/3.9.9/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/install/.pyenv/versions/3.9.9/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/crypten/mpc/context.py", line 30, in _launch
    return_value = func(*func_args, **func_kwargs)
  File "/mnt/DATI/Documents/ISIN/fhe4cv/./src/cat_and_dog/cnn_classifier/crypten_test.py", line 34, in main
    private_model = crypten.nn.from_pytorch(model, dummy_input)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/crypten/nn/onnx_converter.py", line 57, in from_pytorch
    f = _from_pytorch_to_bytes(pytorch_model, dummy_input)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/crypten/nn/onnx_converter.py", line 128, in _from_pytorch_to_bytes
    f = _export_pytorch_model(f, pytorch_model, dummy_input)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/crypten/nn/onnx_converter.py", line 146, in _export_pytorch_model
    torch.onnx.export(pytorch_model, dummy_input, f, **kwargs)
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/torch/onnx/utils.py", line 504, in export
    _export(
  File "/home/install/.cache/pypoetry/virtualenvs/fhe4cv-FF0Yzqo4-py3.9/lib/python3.9/site-packages/torch/onnx/utils.py", line 1654, in _export
    raise errors.CheckerError(e)
torch.onnx.errors.CheckerError: Unrecognized attribute: ratio for operator Dropout

I tried downgrading torch, torchvision, and onnx to the lowest versions required by crypten; however, torchvision 0.9.1 is not compatible with torch 1.7.0, so I had to install torch 1.8.1, and then I got the same error again.
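A hedged note on the CheckerError: in ONNX, ratio was an attribute of the Dropout operator only up to opset 11; from opset 12 onward it is an input, so an exporter/checker mismatch can produce exactly this message. One experiment, an assumption that may not carry over to CrypTen's converter (which calls torch.onnx.export internally), is to put the model in eval mode before conversion so Dropout layers are typically folded away, and to test the export directly:

import torch
import torchvision

model = torchvision.models.vgg16()
model.eval()  # inactive Dropout is usually not exported at all

dummy_input = torch.empty(1, 3, 224, 224)
# direct export outside CrypTen, to test the hypothesis; the filename is arbitrary
torch.onnx.export(model, dummy_input, 'vgg16.onnx', opset_version=11)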

gillesoldano commented 1 year ago

I was able to use the resnet18 with crypten by training it in the same script, but I am still unable to save it, load it in another script, and use it with crypten.
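A hedged sketch of a save-then-load route that sidesteps the restricted unpickler entirely: save only the state_dict (plain tensors, nothing to whitelist), rebuild the architecture in the loading script, and convert from there. The filename, the ALICE rank of 0, and saving from the inner torchvision model (so the state_dict keys match) are all assumptions:

import crypten
import torch
import torchvision

# training script: save parameters only, not the pickled module
torch.save(model.state_dict(), './resnet18_state.pth')

# loading script: rebuild the architecture, load the weights, convert
crypten.init()  # single-process initialization, for the sketch only
model = torchvision.models.resnet18(num_classes=2)
model.load_state_dict(torch.load('./resnet18_state.pth', map_location='cpu'))
model.eval()

dummy_input = torch.empty(1, 3, 224, 224)
private_model = crypten.nn.from_pytorch(model, dummy_input)
private_model.encrypt(src=0)  # ALICE's rank (assumed)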