Open DemonWang555 opened 2 years ago
Can you post a code snippet that reproduces this error so I can try and reproduce the issue on my end?
Sure, my code is shown below:
```python
for epoch in range(epochs):
    w_input_model = []
    w_output_model = []
    batch_acc_train = []
    batch_loss_train = []
    for batch_idx, (images, labels) in enumerate(trainset):
        images = crypten.cryptensor(images)
        y = crypten.cryptensor(label_eye[labels])  # one-hot labels as MPCTensor
        fx = input_model(images)        # first half of the split model
        output = output_model(fx)       # second half
        loss_value = loss(output, y)
        output_model.zero_grad()
        input_model.zero_grad()
        loss_value.backward()
        fx.backward(fx)
        output_model.update_parameters(lr)
        input_model.update_parameters(lr)
        batch_loss_train.append(loss_value.get_plain_text())
    loss_avg_train = sum(batch_loss_train) / len(batch_loss_train)
    print('Client{} Training \tLoss: {:.4f}'.format(idx, loss_avg_train))
    w_input_model.append(input_model.decrypt())   # <-- error occurs here
    w_output_model.append(output_model.decrypt())
```
As the code shows, the output of input_model is the input of output_model. After each epoch of training, I need to operate on that round's model parameters, so I need to decrypt the encrypted models. However, that is where I run into the problem mentioned above: the error occurs at input_model.decrypt().
Besides, I'm also wondering whether there is a function that can turn a CrypTen model generated by crypten.nn.from_pytorch back into a PyTorch model, because I want to use the .state_dict() function. Or does CrypTen provide a function with the same functionality as .state_dict()?
Is there any solution to this problem?
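To make the second question concrete, this is the kind of round trip I'm imagining: after decrypting, collect the model's (name, parameter) pairs into an OrderedDict that torch_model.load_state_dict() would accept. The helper name is hypothetical, and plain lists stand in for tensors so the sketch runs without CrypTen or PyTorch; whether CrypTen actually exposes such a parameter iterator is exactly what I'm asking.

```python
from collections import OrderedDict

def crypten_to_state_dict(named_params):
    """Build a PyTorch-style state dict from (name, value) pairs.

    `named_params` stands in for iterating a decrypted CrypTen
    module's parameters; the iterator itself is assumed, not a
    confirmed CrypTen API.
    """
    return OrderedDict((name, value) for name, value in named_params)

# Stand-in for a decrypted two-parameter model (lists instead of tensors):
decrypted = [("fc1.weight", [[0.1, 0.2], [0.3, 0.4]]),
             ("fc1.bias", [0.0, 0.0])]
state = crypten_to_state_dict(decrypted)
# The result would then go into torch_model.load_state_dict(state).
```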
@DemonWang555 I am also getting the same error. Could you please ping me in a message?
When I try to decrypt an encrypted LeNet-5 model, an error occurs. I don't know what I did wrong; how can I solve this problem?