facebookresearch / CrypTen

A framework for Privacy Preserving Machine Learning
MIT License

Model decryption fails. #421

Open DemonWang555 opened 2 years ago

DemonWang555 commented 2 years ago

When I try to decrypt an encrypted LeNet-5 model, the following error occurs:

  File "/home/by/anaconda3/lib/python3.7/site-packages/crypten/nn/module.py", line 175, in encrypt
    requires_grad = param.requires_grad
AttributeError: 'MPCTensor' object has no attribute 'requires_grad'

I don't know what I did wrong. How can I solve this problem?

lvdmaaten commented 2 years ago

Can you post a code snippet that reproduces this error so I can try and reproduce the issue on my end?

DemonWang555 commented 2 years ago

Sure, my code is shown below:

for epoch in range(epochs):
    w_input_model = []
    w_output_model = []

    batch_acc_train = []
    batch_loss_train = []

    for batch_idx, (images, labels) in enumerate(trainset):
        images = crypten.cryptensor(images)
        y = crypten.cryptensor(label_eye[labels])  # one-hot labels as an MPCTensor

        fx = input_model(images)
        output = output_model(fx)

        loss_value = loss(output, y)

        output_model.zero_grad()
        input_model.zero_grad()

        loss_value.backward()
        fx.backward(fx)

        output_model.update_parameters(lr)
        input_model.update_parameters(lr)

        batch_loss_train.append(loss_value.get_plain_text())

    loss_avg_train = sum(batch_loss_train)/len(batch_loss_train)
    print('Client{} Training \tLoss: {:.4f}'.format(idx, loss_avg_train))

    w_input_model.append(input_model.decrypt())
    w_output_model.append(output_model.decrypt())

As the code shows, the output of input_model is the input to output_model. After each epoch of training, I need to work with that round's model parameters, so I need to decrypt the encrypted model. That is where I run into the error described above.

The error is raised at input_model.decrypt().
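The traceback shows the failure happening when a parameter's requires_grad attribute is read, so one way to narrow this down is to list which parameters lack that attribute before calling decrypt(). The helper below is only a diagnostic sketch, not a fix. The function name params_missing_attr is hypothetical, and it assumes the model exposes named_parameters() yielding (name, parameter) pairs, as torch.nn.Module does.

```python
def params_missing_attr(named_params, attr="requires_grad"):
    """Return (name, type name) for every parameter that lacks `attr`.

    `named_params` is any iterable of (name, parameter) pairs, for example
    the result of model.named_parameters() (assumed to exist on the model).
    """
    missing = []
    for name, param in named_params:
        # hasattr mirrors the attribute lookup that fails in the traceback
        if not hasattr(param, attr):
            missing.append((name, type(param).__name__))
    return missing
```

Under those assumptions, calling print(params_missing_attr(input_model.named_parameters())) right before input_model.decrypt() would show which parameters trigger the AttributeError and what type they have at that point.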

I'm also wondering whether there is a function that can turn a CrypTen model generated by crypten.nn.from_pytorch back into a PyTorch model, because I want to use the .state_dict() function. Or does CrypTen have a function that provides the same functionality as .state_dict()?
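Until that question is answered, a possible workaround is to build a state_dict-like mapping by hand, decrypting each parameter individually with get_plain_text() (the same call used on loss_value in the training loop) rather than decrypting the whole module. This is a minimal sketch under stated assumptions: the function name plaintext_state_dict is hypothetical, the model is assumed to expose named_parameters(), and the parameter names it yields may not match those of the original PyTorch model.

```python
from collections import OrderedDict


def plaintext_state_dict(named_params):
    """Build an ordered name -> plaintext mapping from (name, param) pairs.

    Parameters that offer get_plain_text() are decrypted one by one;
    anything else is assumed to be plaintext already and is kept as-is.
    """
    state = OrderedDict()
    for name, param in named_params:
        if hasattr(param, "get_plain_text"):
            # Reveals the value, so only use this where that is acceptable
            state[name] = param.get_plain_text()
        else:
            state[name] = param
    return state
```

For example, w_input_model.append(plaintext_state_dict(input_model.named_parameters())) would collect plaintext weights without going through input_model.decrypt(). Note that this reveals the parameter values, which may not be acceptable in every deployment.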

DemonWang555 commented 2 years ago

Is there any solution to this problem?

hemant5454 commented 1 year ago

@DemonWang555 I am also getting the same error. Could you please send me a message?