alibaba-edu / mpc4j

Apache License 2.0

Set intersection computed with KKRT16 (4 hash) gives an incorrect result #19

Closed aaaaaa1 closed 6 months ago

aaaaaa1 commented 1 year ago

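1. I changed ArrayList<Set> sets = PsoUtils.generateBytesSets(serverSetSize, clientSetSize, ELEMENT_BYTE_LENGTH) to use arrays, as shown in the screenshot: [image]
2. I print the result inside the intersection method handleServerPrf, as shown in the screenshot: [image]

Hello, what should I change to get the correct result?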

liuweiran900217 commented 1 year ago

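Each protocol works differently under the hood, so you should not try to change the input type by modifying the protocol itself. The protocol implementations in mpc4j currently use Java generics to support data types such as String. Specifically, you can change PsiServerThread in the test case to: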

class PsiServerThread extends Thread {
    /**
     * PSI server
     */
    private final PsiServer<String> server;
    /**
     * server element set
     */
    private final Set<String> serverElementSet;
    /**
     * number of client elements
     */
    private final int clientElementSize;

    PsiServerThread(PsiServer<String> server, Set<String> serverElementSet, int clientElementSize) {
        this.server = server;
        this.serverElementSet = serverElementSet;
        this.clientElementSize = clientElementSize;
    }

    @Override
    public void run() {
        try {
            server.init(serverElementSet.size(), clientElementSize);
            server.psi(serverElementSet, clientElementSize);
        } catch (MpcAbortException e) {
            e.printStackTrace();
        }
    }
}

Change PsiClientThread in the test case to:

class PsiClientThread extends Thread {
    /**
     * PSI client
     */
    private final PsiClient<String> client;
    /**
     * client element set
     */
    private final Set<String> clientElementSet;
    /**
     * number of server elements
     */
    private final int serverElementSize;
    /**
     * intersection obtained by the client
     */
    private Set<String> intersectionSet;

    PsiClientThread(PsiClient<String> client, Set<String> clientElementSet, int serverElementSize) {
        this.client = client;
        this.clientElementSet = clientElementSet;
        this.serverElementSize = serverElementSize;
    }

    Set<String> getIntersectionSet() {
        return intersectionSet;
    }

    @Override
    public void run() {
        try {
            client.init(clientElementSet.size(), serverElementSize);
            intersectionSet = client.psi(clientElementSet, serverElementSize);
        } catch (MpcAbortException e) {
            e.printStackTrace();
        }
    }
}

Then change the test body in PsiTest to something like the following:

PsiServer<String> server = PsiFactory.createServer(firstRpc, secondRpc.ownParty(), config);
PsiClient<String> client = PsiFactory.createClient(secondRpc, firstRpc.ownParty(), config);
server.setParallel(parallel);
client.setParallel(parallel);
int randomTaskId = Math.abs(SECURE_RANDOM.nextInt());
server.setTaskId(randomTaskId);
client.setTaskId(randomTaskId);
try {
    LOGGER.info("-----test {},server_size = {},client_size = {}-----",
        server.getPtoDesc().getPtoName(), serverSetSize, clientSetSize
    );
    // set your desired inputs here
    Set<String> serverSet = ...;
    Set<String> clientSet = ...;
    PsiServerThread serverThread = new PsiServerThread(server, serverSet, clientSet.size());
    PsiClientThread clientThread = new PsiClientThread(client, clientSet, serverSet.size());
    StopWatch stopWatch = new StopWatch();
    // start
    stopWatch.start();
    serverThread.start();
    clientThread.start();
    // stop
    serverThread.join();
    clientThread.join();
    stopWatch.stop();
    long time = stopWatch.getTime(TimeUnit.MILLISECONDS);
    stopWatch.reset();
    // print the result here and compare it with the expected intersection
    ........
    // destroy
    new Thread(server::destroy).start();
    new Thread(client::destroy).start();
} catch (InterruptedException e) {
    e.printStackTrace();
}
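
For the comparison step in the test body above, one option is to compute the expected intersection in plain Java and check the value returned by clientThread.getIntersectionSet() against it. The following is only a minimal sketch that uses nothing but the JDK; the class name ExpectedIntersection, the method name expected and the sample values are hypothetical and are not part of mpc4j.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

final class ExpectedIntersection {
    /**
     * Computes the plain (non-private) intersection of two sets.
     * Intended only as a reference value when checking a PSI result in a test.
     */
    static <T> Set<T> expected(Set<T> serverSet, Set<T> clientSet) {
        // Copy the client set so that neither input is modified.
        Set<T> result = new HashSet<>(clientSet);
        result.retainAll(serverSet);
        return result;
    }

    public static void main(String[] args) {
        // Hypothetical sample inputs.
        Set<String> serverSet = new HashSet<>(Arrays.asList("alice", "bob", "carol"));
        Set<String> clientSet = new HashSet<>(Arrays.asList("bob", "carol", "dave"));
        Set<String> reference = expected(serverSet, clientSet);
        // In the real test, compare reference with clientThread.getIntersectionSet().
        System.out.println("expected intersection: " + reference);
    }
}
```

Because the helper is generic in T, the same check works for any element type that implements content-based equals and hashCode, such as String.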

Judging from your question, you may need to learn how generics are written in Java.
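
A related point when choosing the element type for a generic Set: the elements need content-based equals and hashCode. This may be relevant to the array change described in the question, although that is only a guess because the screenshots are not available. Java arrays such as byte[] compare by identity, while ByteBuffer and String compare by content. The sketch below uses only the JDK; the class name ElementEquality and its methods are hypothetical.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.HashSet;
import java.util.Set;

final class ElementEquality {
    /** Stores one copy of the bytes as a raw array and looks up a second, equal-content copy. */
    static boolean rawArrayFound(byte[] data) {
        Set<byte[]> set = new HashSet<>();
        set.add(data.clone());
        // byte[] inherits identity-based equals from Object, so the lookup fails.
        return set.contains(data.clone());
    }

    /** Same experiment, but each copy is wrapped in a ByteBuffer. */
    static boolean wrappedFound(byte[] data) {
        Set<ByteBuffer> set = new HashSet<>();
        set.add(ByteBuffer.wrap(data.clone()));
        // ByteBuffer compares the remaining bytes, so the lookup succeeds.
        return set.contains(ByteBuffer.wrap(data.clone()));
    }

    public static void main(String[] args) {
        byte[] data = "element".getBytes(StandardCharsets.UTF_8);
        System.out.println(rawArrayFound(data)); // false
        System.out.println(wrappedFound(data));  // true
    }
}
```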

liuweiran900217 commented 1 year ago

Has the problem been resolved? If so, can this issue be closed?