Open qingqbaby opened 2 years ago
真输入来自持有这个数据的参与方。
if rank == i:
# 加密本机的数据,分享给其他参与方
assert local_tensor.shape[1] == feature_size, \
f"{name} feature size should be {feature_size}, but get {local_tensor.shape[1]}"
tensor = crypten.cryptensor(local_tensor, src=i) # src标识数据的持有方
在这里,持有数据的参与方,也就是rank=i
的参与方,会通过crypten.cryptensor
把数据秘密分享给网络中其他的参与方。而其他的参与方虽然没有数据,但是也需要一个cryptensor
来接收别人分享来的数据,所以有了dummy_tensor
那一段。
在您的项目中有个函数: `def load_encrypt_tensor(filename: str) -> crypten.CrypTensor: local_tensor = load_local_tensor(filename) # 加载本地数据 rank = comm.get().get_rank() # 本机的rank count = local_tensor.shape[0]
我所理解的该函数的作用是用来进行本地数据加密并且接收来自其他客户端的数据加密结果,但是我看其中您注释了假输入,那么请问真的输入在哪里? 即
dummy_tensor = torch.zeros((count, feature_size), dtype=torch.float32)
这个我理解的是对未拥有的特征假设为全0数据,然后进行加密,是为了保证维度保持一致的假输入