secretflow / spu

SPU (Secure Processing Unit) aims to be a provable, measurable secure computation device, which provides computation ability while keeping your private data protected.
https://www.secretflow.org.cn/docs/spu/en/
Apache License 2.0

[Bug]: Problem with emulator.run running time #921

Open zhou-pz opened 1 week ago

zhou-pz commented 1 week ago

Issue Type

Performance

Modules Involved

MPC protocol, SPU runtime

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

spu 0.9.3dev20241009

OS Platform and Distribution

Ubuntu 22.04

Python Version

3.9

Compiler Version

GCC 11.2.1

Current Behavior?

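In the code below, the last statement in floyd_opt, "batch_2 = batch_2.at[indices].set(batch_2_upper_triangle)", makes the program run very slowly. With it removed, the program runs fast. But this is a simple assignment and should not add this much overhead, so what is the cause? The output logs and the code follow.

  1. Log output with batch_2 = batch_2.at[indices].set(batch_2_upper_triangle):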

    [2024-11-24 12:01:18.766] [info] [api.cc:172] [Profiling] SPU execution floyd_opt completed, input processing took 1.28e-06s, execution took 76.40721284s, output processing took 2.3442e-05s, total time 76.407237562s.
    [2024-11-24 12:01:19.898] [info] [api.cc:220] HLO profiling: total time 76.390170644
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - pphlo.while, executed 100 times, duration 76.100738252s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - pphlo.custom_call: spu.gather, executed 100 times, duration 0.26589784s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - pphlo.concatenate, executed 196 times, duration 0.020669029s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - pphlo.slice, executed 394 times, duration 0.001173913s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - pphlo.reshape, executed 100 times, duration 0.000812756s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - pphlo.constant, executed 8 times, duration 0.000444215s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - pphlo.free, executed 792 times, duration 0.000280899s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - pphlo.convert, executed 1 times, duration 0.00015374s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:220] HAL profiling: total time 21.108059529
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - i_less, executed 1455400 times, duration 8.139717218s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - logical_not, executed 970200 times, duration 4.869683125s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - _mux, executed 485100 times, duration 4.673866196s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - _and, executed 1455300 times, duration 2.597090457s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - i_add, executed 485100 times, duration 0.827552852s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - seal, executed 1 times, duration 0.000149681s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:220] MPC profiling: total time 33.847582661000004
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - update_slice, executed 485100 times, duration 16.305582855s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - add_pp, executed 2910700 times, duration 2.964376123s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - extract_slice, executed 3396094 times, duration 2.585630088s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - reshape, executed 4366000 times, duration 2.440487449s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - negate_p, executed 2425600 times, duration 2.225913358s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - and_pp, executed 1455300 times, duration 1.603299254s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - broadcast, executed 2425603 times, duration 1.238344766s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - concatenate, executed 485296 times, duration 1.139825704s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - msb_p, executed 1455400 times, duration 1.137026837s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - add_aa, executed 970200 times, duration 0.70372609s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - make_p, executed 970200 times, duration 0.602565889s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - mul_ap, executed 485100 times, duration 0.353910826s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - transpose, executed 485100 times, duration 0.296557429s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - negate_a, executed 485100 times, duration 0.250191463s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:223] - p2a, executed 1 times, duration 0.00014453s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19.898] [info] [api.cc:233] Link details: total send bytes 0, recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 12:01:19,900] [ForkServerProcess-1] RunR: builtin_fetch_meta at node:0
    [2024-11-24 12:01:19,902] [ForkServerProcess-1] RunR: builtin_fetch_object at node:0
    [2024-11-24 12:01:19,903] [ForkServerProcess-2] RunR: builtin_fetch_object at node:1
    [2024-11-24 12:01:19,905] [ForkServerProcess-3] RunR: builtin_fetch_object at node:2
    [2024-11-24 12:01:19,907] Shutdown multiprocess cluster...
  2. Log output without batch_2 = batch_2.at[indices].set(batch_2_upper_triangle):

    
    [2024-11-24 11:24:13.057] [info] [api.cc:172] [Profiling] SPU execution floyd_opt completed, input processing took 4.11e-07s, execution took 0.00039479s, output processing took 1.413e-06s, total time 0.000396614s.
    [2024-11-24 11:24:13.057] [info] [api.cc:220] HLO profiling: total time 3.1869e-05
    [2024-11-24 11:24:13.057] [info] [api.cc:223] - pphlo.constant, executed 1 times, duration 3.1869e-05s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 11:24:13.057] [info] [api.cc:220] HAL profiling: total time 0
    [2024-11-24 11:24:13.057] [info] [api.cc:220] MPC profiling: total time 9.837e-06
    [2024-11-24 11:24:13.057] [info] [api.cc:223] - broadcast, executed 1 times, duration 9.837e-06s, send bytes 0 recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 11:24:13.057] [info] [api.cc:233] Link details: total send bytes 0, recv bytes 0, send actions 0, recv actions 0
    [2024-11-24 11:24:13,057] [ForkServerProcess-4] RunR: builtin_fetch_object at node:3
    [2024-11-24 11:24:13,061] [ForkServerProcess-1] RunR: builtin_fetch_meta at node:0
    [2024-11-24 11:24:13,062] [ForkServerProcess-1] RunR: builtin_fetch_object at node:0
    [2024-11-24 11:24:13,065] [ForkServerProcess-2] RunR: builtin_fetch_object at node:1
    [2024-11-24 11:24:13,066] [ForkServerProcess-3] RunR: builtin_fetch_object at node:2
    [2024-11-24 11:24:13,067] Shutdown multiprocess cluster...

**Code:**

    import jax.numpy as jnp
    import numpy as np

    # Import path assumed; adjust it to where emulation.py lives in your SPU checkout.
    import sml.utils.emulation as emulation


    def floyd_opt(dist):
        n = len(dist)
        batch_2 = dist

        for k in range(n):
            # Drop row k, remember column k, then drop column k.
            batch_2 = jnp.delete(batch_2, k, axis=0)
            col_k_without_dkk = batch_2[:, k]
            batch_2 = jnp.delete(batch_2, k, axis=1)
            dist_ik = jnp.zeros_like(batch_2)
            dist_kj = jnp.zeros_like(batch_2)

            # Row i of dist_ik holds d(i, k), skipping i == k.
            for i in range(n - 1):
                if i < k:
                    dist_ik = dist_ik.at[i].set(jnp.full(n - 1, dist[i][k]))
                else:
                    dist_ik = dist_ik.at[i].set(jnp.full(n - 1, dist[i + 1][k]))

            # Column j of dist_kj holds d(k, j), skipping j == k.
            for j in range(n - 1):
                if j < k:
                    dist_kj = dist_kj.at[:, j].set(jnp.full(n - 1, dist[k][j]))
                else:
                    dist_kj = dist_kj.at[:, j].set(jnp.full(n - 1, dist[k][j + 1]))

            # Take out the upper triangle for the computation.
            indices = np.triu_indices(batch_2.shape[0], k=1)
            batch_2_upper_triangle = batch_2[indices]
            dist_ik_upper_triangle = dist_ik[indices]
            dist_kj_upper_triangle = dist_kj[indices]

            # Put the upper triangle back.
            batch_2 = jnp.zeros_like(dist)
            # This line has a very large impact on the running time!
            batch_2 = batch_2.at[indices].set(batch_2_upper_triangle)

        return batch_2


    def emul_cpz(mode: emulation.Mode = emulation.Mode.MULTIPROCESS):
        try:
            # bandwidth and latency only work for docker mode
            emulator = emulation.Emulator(
                emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20
            )
            emulator.up()

            # Set the number of samples and the dimension.
            num_samples = 100
            Knn = np.random.rand(num_samples, num_samples)
            Knn = (Knn + Knn.T) / 2
            Knn[Knn == 0] = np.inf
            np.fill_diagonal(Knn, 0)
            Knn = emulator.seal(Knn)

            # floyd_opt
            shortest_paths_floyd = emulator.run(floyd_opt)(Knn)

        finally:
            emulator.down()


    if __name__ == "__main__":
        emul_cpz(emulation.Mode.MULTIPROCESS)


Standalone code to reproduce the issue

See above.

Relevant log output

See above.

deadlywing commented 1 week ago

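Hello. The main reason is that if that line is commented out, the compiler sees that the function returns a constant all-zero array. Only when it is not commented out does the function execute as you defined it. As for why it is so slow, the main reason is that the loops in the implementation are O(n^2).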
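
As an illustration of what removing those loops could look like, here is a minimal sketch in plain JAX. It assumes the intent of the two inner loops is to fill row i of dist_ik with d(i, k) and column j of dist_kj with d(k, j), skipping index k. The helper names `build_loop` and `build_broadcast` are hypothetical, and nothing here has been timed on SPU.

```python
import jax.numpy as jnp
import numpy as np


def build_loop(dist, k):
    """Loop form, mirroring the inner loops of floyd_opt."""
    n = dist.shape[0]
    dist_ik = jnp.zeros((n - 1, n - 1), dtype=dist.dtype)
    dist_kj = jnp.zeros((n - 1, n - 1), dtype=dist.dtype)
    for i in range(n - 1):
        src = i if i < k else i + 1  # skip row k
        dist_ik = dist_ik.at[i].set(jnp.full(n - 1, dist[src][k]))
    for j in range(n - 1):
        src = j if j < k else j + 1  # skip column k
        dist_kj = dist_kj.at[:, j].set(jnp.full(n - 1, dist[k][src]))
    return dist_ik, dist_kj


def build_broadcast(dist, k):
    """Same matrices, built with two deletes and two broadcasts."""
    n = dist.shape[0]
    col_k = jnp.delete(dist[:, k], k)  # d(i, k) for i != k
    row_k = jnp.delete(dist[k, :], k)  # d(k, j) for j != k
    dist_ik = jnp.broadcast_to(col_k[:, None], (n - 1, n - 1))
    dist_kj = jnp.broadcast_to(row_k[None, :], (n - 1, n - 1))
    return dist_ik, dist_kj
```

The broadcast form issues no `.at[].set()` calls at all, so the number of traced update operations no longer grows with n.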

zhou-pz commented 1 week ago

Hello, I commented out the two loops inside, but the time did not go down. It still takes 78 seconds (dist and batch_2 are only 100*100).

    for i in range(n-1):
        if(i < k):
            dist_ik = dist_ik.at[i].set(jnp.full(n-1, dist[i][k]))
        else:
            dist_ik = dist_ik.at[i].set(jnp.full(n-1, dist[i+1][k]))

    for j in range(n-1):
        if(j < k):
            dist_kj = dist_kj.at[:, j].set(jnp.full(n-1, dist[k][j]))
        else:
            dist_kj = dist_kj.at[:, j].set(jnp.full(n-1, dist[k][j+1]))
deadlywing commented 1 week ago

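My guess is that the main cause is still the .at[].set() statement, which generates a large number of ops. Although it is communication-free, it still costs a lot of CPU time.

I suggest thinking about whether there is an equivalent way to implement this.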
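
The op counts in the first log are consistent with that guess. A short check, assuming (this is an inference from the numbers, not something confirmed in the thread) that each of the 100 scatters is executed as one pphlo.while whose body updates a single element per step:

```python
import numpy as np

n = 100            # dist is 100 x 100
m = n - 1          # batch_2 is 99 x 99 after deleting row k and column k

# Number of strictly-upper-triangular entries written by one .at[indices].set().
rows, cols = np.triu_indices(m, k=1)
updates_per_scatter = rows.size        # m * (m - 1) // 2

# One scatter per outer iteration k.
total_updates = n * updates_per_scatter

print(updates_per_scatter, total_updates)
```

The total equals the 485100 executions reported for update_slice, _mux and i_add, and the 100 outer iterations equal the 100 executions of pphlo.while, which accounts for almost all of the 76 s.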

zhou-pz commented 6 days ago

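.at[].set() is indeed slow, and it is especially slow in the emulator:

(1) Running with the emulator takes 75s. (2) Running without the emulator (a pure Python program) takes only 10s. (3) Running without the emulator and replacing every jax.numpy .at[].set() operation with standard numpy indexed assignment takes only 0.03s.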

zhou-pz commented 6 days ago

If the assignment cannot be avoided, does SPU support any other, more efficient assignment operation?

deadlywing commented 6 days ago

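Judging only from the code you provided, it looks like each step extracts some data from the dist matrix to form a new matrix. Perhaps this can be done with some special index/gather method: take a look at jax.numpy.choose, jax.lax.gather, and similar functions.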
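
For the specific "put the upper triangle back" step, one scatter-free alternative is a mask plus zero padding. The sketch below is in plain JAX. It swaps in `jnp.triu` and `jnp.pad` rather than the functions named above, the helper names are hypothetical, and whether it is actually faster on SPU is an assumption that would need to be measured.

```python
import jax.numpy as jnp
import numpy as np


def put_back_scatter(batch_2, n):
    """Original form: scatter the upper triangle into an n x n zero matrix."""
    indices = np.triu_indices(batch_2.shape[0], k=1)
    out = jnp.zeros((n, n), dtype=batch_2.dtype)
    return out.at[indices].set(batch_2[indices])


def put_back_masked(batch_2, n):
    """Same result: zero everything on or below the diagonal, then pad to n x n."""
    upper = jnp.triu(batch_2, k=1)   # keeps entries with row < col
    pad = n - batch_2.shape[0]       # 1 in floyd_opt
    return jnp.pad(upper, ((0, pad), (0, pad)))
```

Both functions return the same matrix because the scatter writes each upper-triangular entry to the position it already occupies. The masked version only needs a select against a public mask, with no per-element update loop.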