Open zhou-pz opened 1 week ago
hello,主要是因为如果那一行被comment,则编译器发现这个函数的return为常量0的array; 如果没被comment,才会按你定义的函数执行,至于为什么这么慢,主要是因为实现里的循环是O(n^2)的
hello,我把里面的两个循环注释掉了时间没减少,还是需要78秒(dist和batch_2只是100*100的)
for i in range(n-1): if(i < k): dist_ik = dist_ik.at[i].set(jnp.full(n-1, dist[i][k])) else: dist_ik = dist_ik.at[i].set(jnp.full(n-1, dist[i+1][k])) for j in range(n-1): if(j < k): dist_kj = dist_kj.at[:, j].set(jnp.full(n-1, dist[k][j])) else: dist_kj = dist_kj.at[:, j].set(jnp.full(n-1, dist[k][j+1]))
猜测主要原因还是 .at[].set[] 这个语句会产生大量的ops;虽然comm. free,但是cpu time还是比较多的。
建议想想有没有等价的方式可以实现。。
.at[].set[]确实很慢,而且在emulator里格外慢:
(1) 我用emulator跑,需要75s。(2) 不用emulator跑(纯python程序),只需要10s。(3) 不用emulator跑,且把jax.numpy的.at[].set[]操作全部替换为numpy的标准索引赋值,只需要0.03s。
请问如果赋值操作无法避免的话,spu支持其它更高效的赋值操作吗?
只从你提供的代码来看,看上去是要从dist这个矩阵中每次提取一些数据出来组成新的矩阵,不知道有没有可能通过一些特殊的index/gather方法来得到:可以看看jax.numpy.choose
, jax.lax.gather
等方法
Issue Type
Performance
Modules Involved
MPC protocol, SPU runtime
Have you reproduced the bug with SPU HEAD?
Yes
Have you searched existing issues?
Yes
SPU Version
spu 0.9.3dev20241009
OS Platform and Distribution
Ubuntu 22.04
Python Version
3.9
Compiler Version
GCC 11.2.1
Current Behavior?
下面代码中floyd_opt函数里最后一句代码 “batch_2 = batch_2.at[indices].set(batch_2_upper_triangle)” 导致程序跑得很慢,去掉后程序跑得很快,但是这是简单的赋值操作,不会带来这么大开销,是什么原因呢?输出日志和代码如下。
有batch_2 = batch_2.at[indices].set(batch_2_upper_triangle)时输出日志:
去掉batch_2 = batch_2.at[indices].set(batch_2_upper_triangle)时输出日志:
def floyd_opt( dist ):
def emul_cpz(mode: emulation.Mode.MULTIPROCESS):
if name == "main": emul_cpz(emulation.Mode.MULTIPROCESS)
Relevant log output