delta-mpc / crypten_vfl_demo

vertical federated learning demo with crypten
MIT License
44 stars 6 forks source link

代码中的一些问题 #6

Open qingqbaby opened 2 years ago

qingqbaby commented 2 years ago

在您的项目中有个函数: `def load_encrypt_tensor(filename: str) -> crypten.CrypTensor: local_tensor = load_local_tensor(filename) # 加载本地数据 rank = comm.get().get_rank() # 本机的rank count = local_tensor.shape[0]

encrypt_tensors = []
for i, (name, feature_size) in enumerate(zip(names, feature_sizes)):
    if rank == i:
        # 加密本机的数据,分享给其他参与方
        assert local_tensor.shape[1] == feature_size, \
            f"{name} feature size should be {feature_size}, but get {local_tensor.shape[1]}"
        tensor = crypten.cryptensor(local_tensor, src=i) # src标识数据的持有方
    else:
        # 从其他参与方中接收加密的数据
        # 假输入,需要与实际的数据维度一致
        dummy_tensor = torch.zeros((count, feature_size), dtype=torch.float32)
        tensor = crypten.cryptensor(dummy_tensor, src=i)
    encrypt_tensors.append(tensor)

res = crypten.cat(encrypt_tensors, dim=1)  # 将所有特征数据拼接起来
return res`

我所理解的该函数的作用是用来进行本地数据加密并且接收来自其他客户端的数据加密结果,但是我看其中您注释了假输入,那么请问真的输入在哪里? 即dummy_tensor = torch.zeros((count, feature_size), dtype=torch.float32)这个我理解的是对未拥有的特征假设为全0数据,然后进行加密,是为了保证维度保持一致的假输入

mh739025250 commented 2 years ago

真输入来自持有这个数据的参与方。

    if rank == i:
        # 加密本机的数据,分享给其他参与方
        assert local_tensor.shape[1] == feature_size, \
            f"{name} feature size should be {feature_size}, but get {local_tensor.shape[1]}"
        tensor = crypten.cryptensor(local_tensor, src=i) # src标识数据的持有方

在这里,持有数据的参与方,也就是rank=i的参与方,会通过crypten.cryptensor把数据秘密分享给网络中其他的参与方。而其他的参与方虽然没有数据,但是也需要一个cryptensor来接收别人分享来的数据,所以有了dummy_tensor那一段。