facebookresearch / CrypTen

A framework for Privacy Preserving Machine Learning
MIT License

How to get return value from @mpc.run_multiprocess? #486

Open DongHan9722 opened 1 year ago

DongHan9722 commented 1 year ago
import torch
import crypten.mpc as mpc
import crypten
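import crypten.communicator as comm

def l2_norm(x, dim):
    # Not defined in the original post; assumed to L2-normalize x along dim,
    # which matches the printed res values below.
    return x / x.norm(p=2, dim=dim)
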
A = torch.tensor([1.0, 2.0, 3.0])
B = torch.tensor([4.0, 5.0, 6.0])

@mpc.run_multiprocess(world_size=2, maxsize=2**30)
def examine_arithmetic_shares(A, B):
    x_enc = crypten.cryptensor(A, src=0)
    y_enc = crypten.cryptensor(B, src=1)

    print(x_enc.share)
    print(y_enc.share)
    res = l2_norm(x_enc.share.float(), 0)
    print(res)
    return res

res = examine_arithmetic_shares(A, B)

How can I get x_enc.share or res back from here? It looks like I cannot return them.

The output is as follows:

tensor([ 5023079003317308407, -6900536229428560026, -6471207578338020124])tensor([-5023079003316259831,  6900536229430657178,  6471207578341165852])

tensor([-1078915575736659511, -7309411383013801875,  1627045180772323973])tensor([ 1078915575740853815,  7309411383019044755, -1627045180766032517])

tensor([ 0.4690, -0.6442, -0.6042])tensor([-0.4690,  0.6442,  0.6042])

     19     return res
     20 
---> 21 output = examine_arithmetic_shares(A,B)

/home/benchmark/CrypTen/venv/lib/python3.7/site-packages/crypten-0.4.0-py3.7.egg/crypten/mpc/context.py in wrapper(*args, **kwargs)
    107             return_values = []
    108             while not queue.empty():
--> 109                 return_values.append(queue.get())
    110 
    111             return [value for _, value in sorted(return_values, key=itemgetter(0))]

/usr/lib/python3.7/multiprocessing/queues.py in get(self, block, timeout)
    111                 self._rlock.release()
    112         # unserialize the data after having released the lock
--> 113         return _ForkingPickler.loads(res)
    114 
    115     def qsize(self):

/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/reductions.py in rebuild_storage_fd(cls, df, size)
    280 
    281 def rebuild_storage_fd(cls, df, size):
--> 282     fd = df.detach()
...
--> 619         s.connect(address)
    620         return Connection(s.detach())
    621 

ConnectionRefusedError: [Errno 111] Connection refused
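Unrelated to the error, the printed shares can be checked by hand. The first output line holds rank 0's and rank 1's additive shares of x_enc, and the second line holds their shares of y_enc. In this output each pair sums to exactly 2**20 times A and B, so adding the shares (with 64-bit wraparound) and dividing by that scale recovers the plaintext. The third line is each party's own share run through l2_norm, which is why the two results are near-negatives of each other rather than anything derived from A. The scale is inferred from these particular numbers, and wrap_int64, reconstruct and l2_normalize are hypothetical helpers, not CrypTen APIs. Inside the decorated function, x_enc.get_plain_text() is the library's own way to decode a CrypTensor.

```python
# Additive shares copied from the output above (rank 0 first, then rank 1).
X_SHARES = (
    [5023079003317308407, -6900536229428560026, -6471207578338020124],
    [-5023079003316259831, 6900536229430657178, 6471207578341165852],
)
Y_SHARES = (
    [-1078915575736659511, -7309411383013801875, 1627045180772323973],
    [1078915575740853815, 7309411383019044755, -1627045180766032517],
)

# Fixed-point scale inferred from this output (the share sums are exact
# multiples of 2**20); an assumption, not a documented CrypTen constant.
SCALE = 2 ** 20


def wrap_int64(v):
    """Map an unbounded Python int into signed 64-bit range (arithmetic mod 2**64)."""
    v &= (1 << 64) - 1
    return v - (1 << 64) if v >= (1 << 63) else v


def reconstruct(share0, share1, scale=SCALE):
    """Add the two parties' shares element-wise mod 2**64, then decode fixed point."""
    return [wrap_int64(a + b) / scale for a, b in zip(share0, share1)]


def l2_normalize(values):
    """Hypothetical stand-in for the post's l2_norm(x, 0): divide by the L2 norm."""
    norm = sum(v * v for v in values) ** 0.5
    return [v / norm for v in values]


print(reconstruct(*X_SHARES))  # [1.0, 2.0, 3.0]
print(reconstruct(*Y_SHARES))  # [4.0, 5.0, 6.0]
# Normalizing a single share reproduces the first printed res tensor.
print([round(v, 4) for v in l2_normalize(X_SHARES[0])])  # [0.469, -0.6442, -0.6042]
```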
Timo9Madrid7 commented 5 months ago

Try return res.tolist() instead of returning the tensor itself.
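This likely works because of how the decorator hands results back. As the traceback shows (context.py, lines 107 to 111), the parent drains a multiprocessing queue of pairs, presumably (rank, value), and sorts them by rank, so the call returns a list with one entry per party. A torch tensor placed on that queue appears to be shared through a file descriptor that the parent reopens by connecting back to the child (rebuild_storage_fd in the traceback); if the child has already exited, that connection is refused. A plain list is pickled by value and avoids this. Below is a minimal stdlib-only sketch of that collect-and-sort pattern. run_multiprocess_sketch and scaled_copy are hypothetical names, the rank is passed in explicitly (CrypTen sets up a communicator instead), and the fork start method limits it to POSIX systems.

```python
import multiprocessing as mp
from operator import itemgetter


def run_multiprocess_sketch(world_size):
    """Hypothetical, simplified stand-in for crypten.mpc.run_multiprocess.

    Runs the decorated function once per rank in a child process and returns
    the per-rank results as a list ordered by rank.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            ctx = mp.get_context("fork")  # POSIX only; the target is not pickled
            queue = ctx.Queue()

            def worker(rank):
                # The result is pickled onto the queue; plain lists travel by value.
                queue.put((rank, func(rank, *args, **kwargs)))

            procs = [ctx.Process(target=worker, args=(r,)) for r in range(world_size)]
            for p in procs:
                p.start()
            # Read one result per rank before joining: joining first can deadlock
            # while a child's queued data is still waiting to be consumed.
            results = [queue.get(timeout=60) for _ in range(world_size)]
            for p in procs:
                p.join()
            return [value for _, value in sorted(results, key=itemgetter(0))]
        return wrapper
    return decorator


@run_multiprocess_sketch(world_size=2)
def scaled_copy(rank, values):
    # Each "party" returns a plain Python list, as suggested above.
    return [v * (rank + 1) for v in values]


if __name__ == "__main__":
    print(scaled_copy([1.0, 2.0, 3.0]))  # [[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]]
```

Applied to the original post, that means ending examine_arithmetic_shares with return res.tolist() and, in the parent, rebuilding tensors if needed with [torch.tensor(v) for v in output], one per rank.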

By the way, if your return value is a long list, you can refer to #354.