pguihc closed this issue.
Please upgrade jax, run bazel clean --expunge,
and then run it again. Alternatively, as a temporary workaround, change jax.extend.linear_util in distributed.py
to jax.linear_util.
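If you would rather not pin one particular jax version, distributed.py could try both import locations instead of hard-coding one. This is a minimal sketch: `import_first` is a hypothetical helper (not part of SPU or jax), and it assumes, as the suggestion above implies, that newer jax exposes `jax.extend.linear_util` while older jax exposes `jax.linear_util`. Files under `bazel-out` are build outputs and may be regenerated, so such a change is more durable in `spu/utils/distributed.py` in the source tree.

```python
import importlib


def import_first(*module_names):
    """Import and return the first module in module_names that imports successfully."""
    errors = []
    for name in module_names:
        try:
            return importlib.import_module(name)
        except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
            errors.append(f"{name}: {exc}")
    raise ImportError("none of the candidate modules could be imported: " + "; ".join(errors))


# Hypothetical use inside distributed.py (kept commented so this sketch runs without jax):
# lu = import_first("jax.extend.linear_util", "jax.linear_util")
```

Listing the newer location first means the fallback is only used on older jax releases.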
I temporarily changed jax.extend.linear_util in /root/.cache/bazel/_bazel_root/c7c2256833a99c4ceaf0534f480b1c44/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed.py
and successfully ran bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/puma_gpt2_benchmarks/3pc.json up
However, running bazel run -c opt //examples/python/ml/puma_gpt2_benchmarks:puma_gpt2_benchmarks
produced the following error.
How can I fix it?
Sorry, I closed this by mistake above. After temporarily modifying /root/.cache/bazel/_bazel_root/c7c2256833a99c4ceaf0534f480b1c44/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed.py, I successfully ran bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/puma_gpt2_benchmarks/3pc.json up
but running bazel run -c opt //examples/python/ml/puma_gpt2_benchmarks:puma_gpt2_benchmarks produced the following error. How can I fix it?
Hi @pguihc
Is puma_gpt2_benchmarks something you added yourself?
@Ye-D
@anakinxc Hi, these are the steps I recently used to reproduce it. The server environment is as follows:
# Start the container
sudo docker run -d -it --name spu-dev-$(whoami) \
--mount type=bind,source="$(pwd)",target=/home/admin/dev/ \
-w /home/admin/dev \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--cap-add=NET_ADMIN \
--privileged=true \
--entrypoint="bash" \
secretflow/ubuntu-base-ci:latest
sudo docker exec -it spu-dev-$(whoami) bash
# Working directory defaults to /home/admin/dev; clone the spu project and copy the puma benchmarks into spu
git clone https://kkgithub.com/secretflow/spu.git
git clone https://hub.incept.pw/AntCPLab/puma_benchmarks.git
cp -r puma_benchmarks/puma_bert_benchmarks/ spu/examples/python/ml/
cp -r puma_benchmarks/puma_gpt2_benchmarks/ spu/examples/python/ml/
cd spu
# Set up the environment
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install -r requirements-dev.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install transformers[flax] -i https://pypi.tuna.tsinghua.edu.cn/simple
## Packages imported by puma_gpt2_benchmarks
pip install datasets -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install evaluate -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install torch torchvision torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple
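Because of earlier errors, and following other issues, I changed every 89XX port in 3pc.json to 600XX, and changed the three identical consecutive 60030 ports to 60030, 60031, and 60032.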
# Build puma_gpt2_benchmarks
Failed: bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/puma_gpt2_benchmarks/3pc.json up
So I temporarily modified /root/.cache/bazel/_bazel_root/c7c2256833a99c4ceaf0534f480b1c44/execroot/spulib/bazel-out/k8-opt/bin/examples/python/utils/nodectl.runfiles/spulib/spu/utils/distributed.py and built it again.
Succeeded: bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/puma_gpt2_benchmarks/3pc.json up
Then I opened a second terminal, entered spu/, and ran the second command:
bazel run -c opt //examples/python/ml/puma_gpt2_benchmarks:puma_gpt2_benchmarks
which finally produced this error.
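The 3pc.json port edit described in the steps above can be done and checked mechanically. This is a hedged sketch: the layout of 3pc.json is not shown here, so the helpers simply walk any nested JSON value. They assume that every string ending in a 4 or 5 digit port is a `host:port` address and that all such addresses must be distinct when every party runs on one host (which is why three identical 60030 ports would clash). `remap_89xx` and `duplicate_addresses` are hypothetical names.

```python
import re
from collections import Counter

# "host:89NN" -> captures host and the two trailing digits NN.
PORT_89XX = re.compile(r"^(?P<host>[^:]+):89(?P<nn>\d{2})$")
# Strings treated as addresses: "host:port" with a 4-5 digit port ("node:0" ids are skipped).
ADDRESS = re.compile(r"[\w.-]+:\d{4,5}")


def remap_89xx(obj):
    """Return a copy of a JSON-like value with every 'host:89NN' string rewritten to 'host:600NN'."""
    if isinstance(obj, dict):
        return {k: remap_89xx(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [remap_89xx(v) for v in obj]
    if isinstance(obj, str):
        return PORT_89XX.sub(r"\g<host>:600\g<nn>", obj)
    return obj


def addresses(obj):
    """Yield every 'host:port' string found anywhere in a JSON-like value (dict keys excluded)."""
    if isinstance(obj, dict):
        for v in obj.values():
            yield from addresses(v)
    elif isinstance(obj, list):
        for v in obj:
            yield from addresses(v)
    elif isinstance(obj, str) and ADDRESS.fullmatch(obj):
        yield obj


def duplicate_addresses(obj):
    """Return the addresses that occur more than once, sorted."""
    counts = Counter(addresses(obj))
    return sorted(addr for addr, n in counts.items() if n > 1)
```

Usage: load the file with `json.load`, pass it through `remap_89xx`, and make sure `duplicate_addresses` returns an empty list before writing it back with `json.dump`.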
@Ye-D Mind taking a look?
Hi, are you using the code from here: PUMA-benchmark?
@Ye-D Yes, I used git clone https://github.com/AntCPLab/puma_benchmarks.git.
Here are my steps:
https://github.com/secretflow/spu/issues/583#issuecomment-1975863212
The bugs are related to JAX's array shape conversion, not to SPU. Please check the dimensions of outputs.logits and related variables. I will try to reproduce this bug as soon as possible and fix it.
@Ye-D OK, thank you. In addition, does the PUMA-benchmark reproduction support GPU? If not, what is the minimum amount of RAM I need to reproduce puma_gpt2_benchmark?
It does not support GPU now. You need >= 64GB RAM for each ABY3 node.
@Ye-D Could you tell me the version of each Python package imported when PUMA was successfully reproduced?
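A quick pre-flight check before launching the nodes, as a hedged sketch. It assumes that "64GB per ABY3 node" means 64 GiB per party and that, when nodectl starts every party on the same machine, the requirements add up. `os.sysconf` with these names works on Linux and most POSIX systems, not on Windows; the function names are hypothetical.

```python
import os


def total_ram_gib():
    """Physical RAM of this machine in GiB (POSIX only, via os.sysconf)."""
    return os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") / 2**30


def enough_ram(parties_on_this_host, gib_per_party=64.0):
    """True if physical RAM covers gib_per_party for each MPC party started on this host."""
    return total_ram_gib() >= parties_on_this_host * gib_per_party
```

For example, `enough_ram(3)` asks whether three ABY3 parties on one host fit, i.e. roughly 192 GiB under the assumptions above.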
It should be ok to follow the readme.md of SPU.
@Ye-D OK, thank you. @anakinxc After building and installing the spu package myself, scripts such as puma_bert_benchmarks.py need to import spu. Do I also need to follow the steps in https://github.com/secretflow/spu/issues/401#issuecomment-1809718108?
Yes. Do not import spu from inside the spu directory.
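The reason, as far as I understand it: Python puts the script's directory (or, in an interactive session, the current directory) first on `sys.path`, so a `spu/` folder in the source checkout can be picked up instead of the installed, compiled package. A minimal sketch for detecting that, using only `importlib.util.find_spec`; `shadowed_by_directory` is a hypothetical helper, not part of SPU.

```python
import importlib.util
import os


def shadowed_by_directory(module_name, directory):
    """Return True if `import module_name` would currently resolve to a path under `directory`."""
    try:
        spec = importlib.util.find_spec(module_name)
    except ImportError:  # e.g. a parent package of a dotted name is missing
        return False
    if spec is None:
        return False
    # Regular modules/packages have a file origin; namespace packages only have search locations.
    # Built-in and frozen modules have neither, so they can never be shadowed this way.
    locations = [spec.origin] if spec.has_location else list(spec.submodule_search_locations or [])
    root = os.path.realpath(directory)
    for loc in locations:
        try:
            if os.path.commonpath([os.path.realpath(loc), root]) == root:
                return True
        except ValueError:  # paths on different drives (Windows)
            continue
    return False
```

Usage: at the top of a benchmark script, `shadowed_by_directory("spu", os.getcwd())` returning `True` suggests the run is picking up the source tree rather than the installed package.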
@anakinxc OK, thanks. @Ye-D Since the error seems to be related to the model, and batch_num (= 100) here is a constant, could the reshape failure be caused by an update to Hugging Face's FlaxGPT2LMHeadModel?
@Ye-D Hi, I found that the reshape error here occurs because the default vocab_size of the imported GPT-2 model is 50257, which is not divisible by the batch_num of 100 set in the code, so I don't think it is caused by JAX. Did you change the vocab_size of the imported GPT-2 model before running puma_gpt2_benchmarks, or was the size not 50257 at that time?
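The divisibility argument can be checked with a pure-Python sketch of the `-1` inference rule that NumPy/JAX `reshape` follows. This is an illustration, not the benchmark's code: it assumes the failing reshape splits a 50257-element axis into batch_num = 100 rows, and 50257 = 502 × 100 + 57 leaves a remainder. `infer_reshape` is a hypothetical name.

```python
import math


def infer_reshape(total_size, new_shape):
    """Resolve a single -1 in new_shape like NumPy/JAX reshape, or raise ValueError."""
    n_unknown = sum(1 for d in new_shape if d == -1)
    if n_unknown > 1:
        raise ValueError("can only specify one unknown dimension")
    known = math.prod(d for d in new_shape if d != -1)  # product of the fixed dims
    if n_unknown == 0:
        if known != total_size:
            raise ValueError(f"cannot reshape {total_size} elements into {tuple(new_shape)}")
        return tuple(new_shape)
    if known == 0 or total_size % known:
        raise ValueError(f"cannot reshape {total_size} elements into {tuple(new_shape)}")
    return tuple(total_size // known if d == -1 else d for d in new_shape)
```

`infer_reshape(50257, (100, -1))` raises `ValueError`, while `infer_reshape(50300, (100, -1))` gives `(100, 503)`; dropping the reshape, as suggested in the reply, avoids the constraint altogether.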
@pguihc Probably, we do not need the reshape operation. Hope it helps you.
@Ye-D I followed 基于Puma框架的GPT2安全推理 ("GPT-2 secure inference based on the Puma framework") and hijacked the gelu function by modifying modeling_flax_gpt2.py (changing self.act to jax.nn.gelu), but I can't find any softmax-related function in modeling_flax_gpt2.py. Is it unnecessary to modify modeling_flax_gpt2.py to hijack the softmax function?
@pguihc We do not need to modify the softmax-related source code in modeling_flax_gpt2.py; our hijack method catches the softmax function anyway. You can verify this by profiling the costs: 1) if you do not hijack softmax, you see about 12 calls of div; 2) when you enable the softmax hijack, the div is replaced by reciprocal. Hope this helps you.
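Why no edit to modeling_flax_gpt2.py is needed, as I understand it: the Hugging Face Flax GPT-2 attention delegates to `flax.linen.attention.dot_product_attention_weights`, which (in the Flax versions I know of) calls `jax.nn.softmax(...)` through the module attribute at call time. Replacing that attribute while the function is being traced therefore swaps in a different softmax everywhere, and 12 softmax calls match GPT-2 small's 12 layers. The sketch below shows the monkeypatching technique on a stand-in namespace so it runs without JAX; `nn`, `attention_weights`, `clipped_softmax`, and the -14 threshold are illustrative assumptions, not PUMA's implementation.

```python
import contextlib
import math
import types

# Hypothetical stand-in for a library namespace such as jax.nn.
nn = types.SimpleNamespace()


def exact_softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


nn.softmax = exact_softmax


def attention_weights(scores):
    # Library code that looks nn.softmax up at *call* time, like a call written as
    # jax.nn.softmax(...). A `from ... import softmax` binding would not see the patch.
    return nn.softmax(scores)


def clipped_softmax(xs, threshold=-14.0):
    # Illustrative MPC-friendly variant: after subtracting the max, inputs below
    # `threshold` are treated as exp(x) == 0. The threshold value is an assumption.
    m = max(xs)
    exps = [math.exp(x - m) if x - m > threshold else 0.0 for x in xs]
    total = sum(exps)  # >= 1, because the max element contributes exp(0)
    return [e / total for e in exps]


@contextlib.contextmanager
def patched(namespace, name, replacement, enabled=True):
    """Temporarily replace namespace.<name>, restoring it even if the body raises."""
    if not enabled:
        yield
        return
    original = getattr(namespace, name)
    setattr(namespace, name, replacement)
    try:
        yield
    finally:
        setattr(namespace, name, original)
```

Because the swap only lasts for the `with` block, the SPU call that traces the model has to happen inside it, which is the case in the run_on_puma function in the next comment.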
@Ye-D Hi, thank you very much for your reply. Is this result normal? I modified flax_gpt2.py to hijack the gelu and softmax functions, ran the modified code, compared it with the original flax_gpt2.py, and found that both generate a token in about the same time.
Please check the profiling time and communication costs. If you run the code on localhost, you might get similar running times.
@Ye-D Here are part of my code and the logs. The profiled time of the two functions is about the same. How can I measure the impact of hijacking gelu and softmax on token generation time?
import spu.utils.distributed as ppd

# tokenizer, pretrained_model, text_generation, copts, hack_softmax_context and
# hack_gelu_context are defined elsewhere in the script (not shown here).


def run_on_spu():
    # encode context the generation is conditioned on
    input_ids = tokenizer.encode(
        "I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my dog.\n\nI'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my dog.\n\nI'm not", return_tensors='jax'
    )
    input_ids = ppd.device("P1")(lambda x: x)(input_ids)
    params = ppd.device("P2")(lambda x: x)(pretrained_model.params)
    outputs_ids = ppd.device("SPU")(text_generation)(input_ids, params)
    outputs_ids = ppd.get(outputs_ids)
    return outputs_ids


def run_on_puma():
    # encode context the generation is conditioned on
    input_ids = tokenizer.encode(
        "I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my dog.\n\nI'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my dog.\n\nI'm not", return_tensors='jax'
    )
    # the hacks must be active while SPU traces and compiles text_generation
    with hack_softmax_context("hack exp of softmax", enabled=True), hack_gelu_context(
        "hack gelu", enabled=True
    ):
        params = ppd.device("P1")(lambda x: x)(pretrained_model.params)
        input_ids = ppd.device("P2")(lambda x: x)(input_ids)
        outputs_ids = ppd.device("SPU")(text_generation, copts=copts)(input_ids, params)
    outputs_ids = ppd.get(outputs_ids)
    return outputs_ids


if __name__ == '__main__':
    print('\n------\nRun on SPU')
    outputs_ids = run_on_spu()
    print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
    print('\n------\nRun on PUMA')
    outputs_ids = run_on_puma()
    print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
run_on_spu
[2024-03-25 13:02:37.348] [info] [thread_pool.cc:30] Create a fixed thread pool with size 7
[2024-03-25 13:05:31.695] [info] [api.cc:158] [Profiling] SPU execution text_generation completed, input processing took 0.000278298s, execution took 174.405490526s, output processing took 2.7001e-05s, total time 174.405795825s.
[2024-03-25 13:05:31.698] [info] [api.cc:191] HLO profiling: total time 0.0005930070000000002
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.add, executed 904 times, duration 6.4898e-05s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.and, executed 22 times, duration 1.6e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.broadcast, executed 177 times, duration 1.31e-05s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.concatenate, executed 1 times, duration 1e-07s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.constant, executed 21 times, duration 4.8e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.convert, executed 6 times, duration 4e-07s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.divide, executed 12 times, duration 6e-07s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.dot, executed 49 times, duration 4.2e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.dot_general, executed 24 times, duration 1.9e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.dynamic_slice, executed 320 times, duration 1.5298e-05s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.dynamic_update_slice, executed 160 times, duration 8.899e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.equal, executed 42 times, duration 3.8e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.exponential, executed 12 times, duration 5.01e-07s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.free, executed 2265 times, duration 0.000351006s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.greater, executed 133 times, duration 9.701e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.iota, executed 3 times, duration 2e-07s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.less, executed 186 times, duration 1.5399e-05s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.multiply, executed 259 times, duration 2e-05s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.not, executed 25 times, duration 2.3e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.or, executed 42 times, duration 3.4e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.reduce, executed 75 times, duration 4.8e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.reshape, executed 611 times, duration 3.7702e-05s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.rsqrt, executed 25 times, duration 1.4e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.select, executed 130 times, duration 1.1902e-05s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.slice, executed 38 times, duration 2.599e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.subtract, executed 62 times, duration 4.6e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.tanh, executed 12 times, duration 8e-07s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.transpose, executed 109 times, duration 6.902e-06s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pphlo.while, executed 2 times, duration 2e-07s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:191] HAL profiling: total time 173.93358367100004
[2024-03-25 13:05:31.698] [info] [api.cc:194] - _and, executed 64 times, duration 0.019365205s, send bytes 150851
[2024-03-25 13:05:31.698] [info] [api.cc:194] - _mux, executed 290 times, duration 1.253648811s, send bytes 34430800
[2024-03-25 13:05:31.698] [info] [api.cc:194] - _xor, executed 84 times, duration 0.00446952s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - f_add, executed 743 times, duration 2.770239202s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - f_div, executed 12 times, duration 7.521587431s, send bytes 608870400
[2024-03-25 13:05:31.698] [info] [api.cc:194] - f_equal, executed 42 times, duration 0.108638484s, send bytes 2327436
[2024-03-25 13:05:31.698] [info] [api.cc:194] - f_exp, executed 12 times, duration 7.961649317s, send bytes 608256000
[2024-03-25 13:05:31.698] [info] [api.cc:194] - f_less, executed 131 times, duration 2.90805605s, send bytes 34644132
[2024-03-25 13:05:31.698] [info] [api.cc:194] - f_mmul, executed 337 times, duration 91.965834275s, send bytes 415388800
[2024-03-25 13:05:31.698] [info] [api.cc:194] - f_mul, executed 234 times, duration 4.458886503s, send bytes 609134720
[2024-03-25 13:05:31.698] [info] [api.cc:194] - f_rsqrt, executed 25 times, duration 0.411716319s, send bytes 1214400
[2024-03-25 13:05:31.698] [info] [api.cc:194] - f_sub, executed 62 times, duration 0.526863849s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - f_tanh, executed 12 times, duration 16.010052817s, send bytes 1266155520
[2024-03-25 13:05:31.698] [info] [api.cc:194] - i_add, executed 161 times, duration 0.000603631s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - i_equal, executed 80 times, duration 1.763419343s, send bytes 80411200
[2024-03-25 13:05:31.698] [info] [api.cc:194] - i_less, executed 348 times, duration 0.827765763s, send bytes 1829412
[2024-03-25 13:05:31.698] [info] [api.cc:194] - int2fxp, executed 1 times, duration 0.000616487s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - logical_not, executed 25 times, duration 0.017349179s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - mixed_mmul, executed 80 times, duration 35.372259685s, send bytes 42707400
[2024-03-25 13:05:31.698] [info] [api.cc:194] - mixed_mul, executed 25 times, duration 0.025843804s, send bytes 48000
[2024-03-25 13:05:31.698] [info] [api.cc:194] - seal, executed 165 times, duration 0.004717996s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:191] MPC profiling: total time 172.160406216
[2024-03-25 13:05:31.698] [info] [api.cc:194] - a2b, executed 49 times, duration 2.366681761s, send bytes 206662400
[2024-03-25 13:05:31.698] [info] [api.cc:194] - add_aa, executed 1591 times, duration 3.992854534s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - add_ap, executed 1578 times, duration 1.436106414s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - add_pp, executed 887 times, duration 0.219816892s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - and_bb, executed 286 times, duration 0.277770448s, send bytes 40789251
[2024-03-25 13:05:31.698] [info] [api.cc:194] - and_bp, executed 510 times, duration 0.063622429s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - b2a, executed 305 times, duration 5.648715788s, send bytes 439475848
[2024-03-25 13:05:31.698] [info] [api.cc:194] - bitrev_b, executed 62 times, duration 0.205028027s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - broadcast, executed 275 times, duration 0.001367003s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - cast_type_b, executed 80 times, duration 0.074115073s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - common_type_b, executed 80 times, duration 0.000132293s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - concatenate, executed 117 times, duration 0.558113975s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - extract_slice, executed 3902 times, duration 0.016650643s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - lshift_b, executed 162 times, duration 0.010455848s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - lshift_p, executed 1 times, duration 0.000614188s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - make_p, executed 1896 times, duration 0.009247799s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - mmul_aa, executed 598 times, duration 124.408608752s, send bytes 99011200
[2024-03-25 13:05:31.698] [info] [api.cc:194] - mmul_ap, executed 516 times, duration 0.150444752s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - msb_a2b, executed 374 times, duration 15.087066604s, send bytes 348342984
[2024-03-25 13:05:31.698] [info] [api.cc:194] - msb_p, executed 165 times, duration 0.006847595s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - mul_a1b, executed 331 times, duration 2.178083281s, send bytes 231287040
[2024-03-25 13:05:31.698] [info] [api.cc:194] - mul_aa, executed 652 times, duration 4.286721805s, send bytes 453606032
[2024-03-25 13:05:31.698] [info] [api.cc:194] - mul_ap, executed 417 times, duration 0.48041827s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - mul_pp, executed 26 times, duration 0.000254788s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - not_a, executed 685 times, duration 1.702592291s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - not_p, executed 533 times, duration 0.150410567s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - p2a, executed 85 times, duration 0.003604587s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - p2b, executed 80 times, duration 0.00050946s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - pad, executed 80 times, duration 0.076798846s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - reshape, executed 2280 times, duration 0.009890627s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - rshift_b, executed 618 times, duration 0.147587915s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - transpose, executed 1733 times, duration 0.004743211s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - trunc_a, executed 1046 times, duration 8.314012637s, send bytes 1803655680
[2024-03-25 13:05:31.698] [info] [api.cc:194] - update_slice, executed 160 times, duration 0.064519661s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - xor_bb, executed 965 times, duration 0.178504496s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:194] - xor_bp, executed 72 times, duration 0.027492956s, send bytes 0
[2024-03-25 13:05:31.698] [info] [api.cc:204] Link details: total send bytes 3705569071, send actions 9952
[2024-03-25 13:05:31,735] [ForkServerProcess-1] RunR: builtin_fetch_meta at node:0
[2024-03-25 13:05:31,757] [ForkServerProcess-1] RunR: builtin_gc at node:0
[2024-03-25 13:05:31,763] [ForkServerProcess-2] RunR: builtin_gc at node:1
run_on_puma
[2024-03-25 13:06:13,149] [ForkServerProcess-4] RunR: builtin_fetch_object at node:3
[2024-03-25 13:06:13,271] [ForkServerProcess-4] RunR: builtin_fetch_object at node:3
[2024-03-25 13:09:21.502] [info] [api.cc:158] [Profiling] SPU execution text_generation completed, input processing took 0.000600511s, execution took 175.1559796s, output processing took 1.8585e-05s, total time 175.156598696s.
[2024-03-25 13:09:21.506] [info] [api.cc:191] HLO profiling: total time 0.0004551820000000001
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.add, executed 988 times, duration 6.9699e-05s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.and, executed 22 times, duration 1.897e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.broadcast, executed 177 times, duration 1.3999e-05s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.concatenate, executed 1 times, duration 1e-07s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.constant, executed 31 times, duration 2.4e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.convert, executed 6 times, duration 2e-07s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.dot, executed 49 times, duration 5.3e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.dot_general, executed 24 times, duration 2.1e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.dynamic_slice, executed 320 times, duration 2.71e-05s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.dynamic_update_slice, executed 160 times, duration 1.12e-05s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.equal, executed 42 times, duration 3.496e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.exponential, executed 12 times, duration 1.1e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.free, executed 2551 times, duration 0.000177293s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.greater, executed 157 times, duration 9.999e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.iota, executed 3 times, duration 2e-07s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.less, executed 210 times, duration 1.79e-05s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.multiply, executed 379 times, duration 3.04e-05s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.not, executed 25 times, duration 2.799e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.or, executed 42 times, duration 3.296e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.reciprocal, executed 12 times, duration 8e-07s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.reduce, executed 75 times, duration 6e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.reshape, executed 611 times, duration 3.8503e-05s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.rsqrt, executed 25 times, duration 1e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.select, executed 130 times, duration 1.0901e-05s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.slice, executed 38 times, duration 2e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.subtract, executed 62 times, duration 5e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.transpose, executed 109 times, duration 7.7e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.while, executed 2 times, duration 3e-07s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pphlo.xor, executed 36 times, duration 2.5e-06s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:191] HAL profiling: total time 174.62620178900005
[2024-03-25 13:09:21.506] [info] [api.cc:194] - _and, executed 64 times, duration 0.024339881s, send bytes 150851
[2024-03-25 13:09:21.506] [info] [api.cc:194] - _mux, executed 290 times, duration 1.126962648s, send bytes 31088872
[2024-03-25 13:09:21.506] [info] [api.cc:194] - _xor, executed 120 times, duration 0.081555781s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - f_add, executed 827 times, duration 3.010743337s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - f_equal, executed 42 times, duration 0.120436713s, send bytes 2254396
[2024-03-25 13:09:21.506] [info] [api.cc:194] - f_exp, executed 12 times, duration 7.366235308s, send bytes 592896000
[2024-03-25 13:09:21.506] [info] [api.cc:194] - f_less, executed 179 times, duration 14.7536218s, send bytes 386326692
[2024-03-25 13:09:21.506] [info] [api.cc:194] - f_mmul, executed 337 times, duration 104.164790788s, send bytes 340963200
[2024-03-25 13:09:21.506] [info] [api.cc:194] - f_mul, executed 306 times, duration 5.59865876s, send bytes 984654720
[2024-03-25 13:09:21.506] [info] [api.cc:194] - f_reciprocal, executed 12 times, duration 0.299107627s, send bytes 7188480
[2024-03-25 13:09:21.506] [info] [api.cc:194] - f_rsqrt, executed 25 times, duration 0.351868996s, send bytes 1217600
[2024-03-25 13:09:21.506] [info] [api.cc:194] - f_sub, executed 62 times, duration 0.481107559s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - i_add, executed 161 times, duration 0.000468001s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - i_equal, executed 80 times, duration 1.663672927s, send bytes 85637928
[2024-03-25 13:09:21.506] [info] [api.cc:194] - i_less, executed 348 times, duration 1.363537817s, send bytes 1829412
[2024-03-25 13:09:21.506] [info] [api.cc:194] - int2fxp, executed 1 times, duration 0.00078261s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - logical_not, executed 25 times, duration 0.018210554s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - mixed_mmul, executed 80 times, duration 32.45519391s, send bytes 44315624
[2024-03-25 13:09:21.506] [info] [api.cc:194] - mixed_mul, executed 73 times, duration 1.741728939s, send bytes 234503040
[2024-03-25 13:09:21.506] [info] [api.cc:194] - seal, executed 165 times, duration 0.003177833s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:191] MPC profiling: total time 173.05019504800003
[2024-03-25 13:09:21.506] [info] [api.cc:194] - a2b, executed 49 times, duration 1.106317642s, send bytes 104733440
[2024-03-25 13:09:21.506] [info] [api.cc:194] - add_aa, executed 1567 times, duration 3.63389891s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - add_ap, executed 1482 times, duration 0.833936012s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - add_pp, executed 851 times, duration 0.140068137s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - and_bb, executed 286 times, duration 0.076723369s, send bytes 745731
[2024-03-25 13:09:21.506] [info] [api.cc:194] - and_bp, executed 510 times, duration 0.062551234s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - b2a, executed 305 times, duration 3.083967814s, send bytes 212071104
[2024-03-25 13:09:21.506] [info] [api.cc:194] - bitrev_b, executed 62 times, duration 0.004739038s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - broadcast, executed 285 times, duration 0.001233996s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - cast_type_b, executed 80 times, duration 0.068575986s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - common_type_b, executed 80 times, duration 0.0001106s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - concatenate, executed 105 times, duration 0.053056988s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - extract_slice, executed 2870 times, duration 0.013249695s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - lshift_b, executed 162 times, duration 0.009771627s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - lshift_p, executed 1 times, duration 0.00077971s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - make_p, executed 1836 times, duration 0.00725682s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - mmul_aa, executed 598 times, duration 133.177800605s, send bytes 99011200
[2024-03-25 13:09:21.506] [info] [api.cc:194] - msb_a2b, executed 398 times, duration 17.87998366s, send bytes 454926024
[2024-03-25 13:09:21.506] [info] [api.cc:194] - msb_p, executed 165 times, duration 0.004695157s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - mul_a1b, executed 355 times, duration 2.653775472s, send bytes 324184320
[2024-03-25 13:09:21.506] [info] [api.cc:194] - mul_aa, executed 556 times, duration 2.144158013s, send bytes 206617232
[2024-03-25 13:09:21.506] [info] [api.cc:194] - mul_ap, executed 441 times, duration 0.473338511s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - mul_pp, executed 26 times, duration 0.000280499s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - not_a, executed 613 times, duration 0.960996483s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - not_p, executed 497 times, duration 0.100604894s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - p2a, executed 85 times, duration 0.002137337s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - p2b, executed 80 times, duration 0.000489899s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - pad, executed 80 times, duration 0.069567138s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - reshape, executed 2256 times, duration 0.006727151s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - rshift_b, executed 618 times, duration 0.076920945s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - transpose, executed 185 times, duration 0.058945252s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - trunc_a, executed 974 times, duration 6.171933394s, send bytes 1222845440
[2024-03-25 13:09:21.506] [info] [api.cc:194] - update_slice, executed 160 times, duration 0.071543634s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - xor_bb, executed 989 times, duration 0.065685326s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:194] - xor_bp, executed 84 times, duration 0.0343741s, send bytes 0
[2024-03-25 13:09:21.506] [info] [api.cc:204] Link details: total send bytes 2713026815, send actions 10006
[2024-03-25 13:09:21,569] [ForkServerProcess-1] RunR: builtin_fetch_meta at node:0
[2024-03-25 13:09:21,595] [ForkServerProcess-1] RunR: builtin_gc at node:0
[2024-03-25 13:09:21,599] [ForkServerProcess-2] RunR: builtin_gc at node:1
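To compare the two profiles above op by op (for example f_div, f_exp and f_tanh in the SPU run against f_reciprocal, f_less and f_mul in the PUMA run), a small parser helps. This is a sketch whose regular expressions are written against the log lines shown in this thread; other SPU versions may format them differently. `parse_profile` and `compare` are hypothetical names.

```python
import re

SECTION_RE = re.compile(r"\b(HLO|HAL|MPC) profiling")
OP_RE = re.compile(
    r"-\s+([^,\s]+), executed (\d+) times, duration ([0-9.eE+-]+)s, send bytes (\d+)"
)
LINK_RE = re.compile(r"total send bytes (\d+)")


def parse_profile(text):
    """Parse SPU profiling log text.

    Returns ({(section, op): (count, seconds, send_bytes)}, total_send_bytes or None).
    """
    ops, link_bytes, section = {}, None, None
    for line in text.splitlines():
        m = SECTION_RE.search(line)
        if m:  # "HLO/HAL/MPC profiling: total time ..." starts a new section
            section = m.group(1)
            continue
        m = OP_RE.search(line)
        if m and section is not None:
            name, count, secs, sent = m.groups()
            ops[(section, name)] = (int(count), float(secs), int(sent))
            continue
        m = LINK_RE.search(line)
        if m:  # "Link details: total send bytes N, ..."
            link_bytes = int(m.group(1))
    return ops, link_bytes


def compare(a, b, section="HAL"):
    """Rows (op, seconds_a, seconds_b, bytes_a, bytes_b); an op missing from one run counts as 0."""
    names = sorted({op for sec, op in a if sec == section} | {op for sec, op in b if sec == section})
    rows = []
    for op in names:
        _, ta, ba = a.get((section, op), (0, 0.0, 0))
        _, tb, bb = b.get((section, op), (0, 0.0, 0))
        rows.append((op, ta, tb, ba, bb))
    return rows
```

Usage: save each run's log to a file, call `parse_profile` on both texts, then sort `compare(spu_ops, puma_ops)` by the byte columns. On localhost the byte columns are the more telling ones, because transmission is nearly free there.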
You can see that mmul and mixed_mmul consume most of the running time. Since PUMA and SPU share the same implementation of these ops, the running times are at the same level in your setup.
@Ye-D Hi, thank you very much for your reply. Do you mean that the current SPU has already optimized the relevant internal computations, so that hijacking only softmax and gelu makes no noticeable difference in token generation time?
The difference in communication is obvious, so you should run the experiments in a WAN setting.
Best wishes,
Ye Dong https://ye-d.github.io/
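A back-of-the-envelope check of both points, using only numbers from the two localhost logs above: the share of HAL time spent in f_mmul plus mixed_mmul, the traffic saved according to the "Link details" lines, and a rough transfer time at 300 Mbit/s (the tc rate used later in this thread). The last line assumes the logged bytes go serially over a single such link and ignores latency, so it is only an order-of-magnitude estimate.

$$
\begin{aligned}
\text{SPU matmul share:}\quad & \frac{91.966 + 35.372}{173.934} \approx 0.732 \\
\text{PUMA matmul share:}\quad & \frac{104.165 + 32.455}{174.626} \approx 0.782 \\[4pt]
\text{traffic saved by PUMA:}\quad & 1 - \frac{2\,713\,026\,815}{3\,705\,569\,071} \approx 0.268 \\[4pt]
\text{transfer at } 300\ \text{Mbit/s:}\quad & \frac{3\,705\,569\,071 \times 8}{300 \times 10^{6}} \approx 98.8\ \text{s}, \qquad \frac{2\,713\,026\,815 \times 8}{300 \times 10^{6}} \approx 72.3\ \text{s}
\end{aligned}
$$

On localhost the roughly 27% traffic reduction is hidden behind the local matmul time; under a bandwidth limit it becomes tens of seconds per token.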
@Ye-D Hi, thank you so much for your continued replies. Following the USENIX ATC '23 Artifact Evaluation instructions, I deployed SPU in two Docker containers and ran PUMA_GPT2. Because my machine's performance is limited, I modified 3pc.json to replace the ABY3 protocol with SEMI2K. The bandwidth limits I used are as follows:
tc qdisc del dev eth0 root
tc qdisc add dev eth0 root handle 1: tbf rate 300mbit burst 256kb latency 800ms
tc qdisc add dev eth0 parent 1:1 handle 10: netem delay 20msec limit 8000
Here are some of the logs I got. The running time in the WAN setting is stable, and the time difference between SPU and PUMA is obvious. Each run below generates one token. Run on SPU:
HAL profiling: total time 443.70149710000004
- f_mmul, executed 337 times, duration 228.7568215s, send bytes 2854566584 recv bytes 2854566584
- mixed_mmul, executed 7 times, duration 111.0405614s, send bytes 2167081840 recv bytes 2167081840
- f_tanh, executed 12 times, duration 17.9802313s, send bytes 101154816 recv bytes 101154816
- f_rsqrt, executed 25 times, duration 16.8606845s, send bytes 82600 recv bytes 82600
- f_less, executed 95 times, duration 15.2914202s, send bytes 3162880 recv bytes 3162880
- f_div, executed 12 times, duration 10.8492046s, send bytes 4064256 recv bytes 4064256
- f_mul, executed 234 times, duration 10.0562511s, send bytes 34972504 recv bytes 34972504
- i_less, executed 56 times, duration 9.3426555s, send bytes 2817136 recv bytes 2817136
- f_equal, executed 42 times, duration 7.1363435s, send bytes 10453456 recv bytes 10453456
- _mux, executed 108 times, duration 7.1015625s, send bytes 2687352 recv bytes 2687352
- i_equal, executed 7 times, duration 3.3231174s, send bytes 36587096 recv bytes 36587096
- f_exp, executed 12 times, duration 3.0741087s, send bytes 903168 recv bytes 903168
- _and, executed 64 times, duration 1.4644008s, send bytes 301556 recv bytes 301556
- mixed_mul, executed 25 times, duration 1.1255127s, send bytes 4200 recv bytes 4200
- f_add, executed 707 times, duration 0.2250496s, send bytes 0 recv bytes 0
- f_sub, executed 62 times, duration 0.0435011s, send bytes 0 recv bytes 0
- logical_not, executed 25 times, duration 0.0194738s, send bytes 0 recv bytes 0
- seal, executed 19 times, duration 0.0067897s, send bytes 0 recv bytes 0
- _xor, executed 84 times, duration 0.0037242s, send bytes 0 recv bytes 0
- i_add, executed 15 times, duration 7.44e-05s, send bytes 0 recv bytes 0
- int2fxp, executed 1 times, duration 8.6e-06s, send bytes 0 recv bytes 0
MPC profiling: total time 443.6329827
- mmul_aa, executed 413 times, duration 317.7271854s, send bytes 5010802232 recv bytes 5010802232
- msb_a2b, executed 168 times, duration 34.4633s, send bytes 35276528 recv bytes 35276528
- trunc_a, executed 1034 times, duration 32.6285831s, send bytes 47615848 recv bytes 47615848
- mul_aa, executed 717 times, duration 17.2310869s, send bytes 67648464 recv bytes 67648464
- b2a, executed 280 times, duration 16.9884604s, send bytes 8013888 recv bytes 8013888
- equal_ss, executed 42 times, duration 7.1359662s, send bytes 10453456 recv bytes 10453456
- and_bb, executed 286 times, duration 6.5183434s, send bytes 937884 recv bytes 937884
- a2b, executed 37 times, duration 6.2237153s, send bytes 1504048 recv bytes 1504048
- equal_sp, executed 7 times, duration 3.3230561s, send bytes 36587096 recv bytes 36587096
- add_aa, executed 1225 times, duration 0.3206567s, send bytes 0 recv bytes 0
- not_a, executed 444 times, duration 0.2628317s, send bytes 0 recv bytes 0
- add_pp, executed 194 times, duration 0.1867729s, send bytes 0 recv bytes 0
- add_ap, executed 792 times, duration 0.173326s, send bytes 0 recv bytes 0
- not_p, executed 132 times, duration 0.13884s, send bytes 0 recv bytes 0
- concatenate, executed 44 times, duration 0.0987306s, send bytes 0 recv bytes 0
- mul_ap, executed 285 times, duration 0.0898253s, send bytes 0 recv bytes 0
- extract_slice, executed 2077 times, duration 0.0280824s, send bytes 0 recv bytes 0
- make_p, executed 1110 times, duration 0.0221769s, send bytes 0 recv bytes 0
- pad, executed 7 times, duration 0.0127266s, send bytes 0 recv bytes 0
- msb_p, executed 19 times, duration 0.0108503s, send bytes 0 recv bytes 0
- xor_bb, executed 965 times, duration 0.0107985s, send bytes 0 recv bytes 0
- mmul_ap, executed 48 times, duration 0.0094578s, send bytes 0 recv bytes 0
- p2a, executed 12 times, duration 0.0064864s, send bytes 0 recv bytes 0
- reshape, executed 1550 times, duration 0.0064779s, send bytes 0 recv bytes 0
- rshift_b, executed 534 times, duration 0.0030746s, send bytes 0 recv bytes 0
- broadcast, executed 202 times, duration 0.0026712s, send bytes 0 recv bytes 0
- update_slice, executed 14 times, duration 0.0021295s, send bytes 0 recv bytes 0
- xor_bp, executed 60 times, duration 0.0019978s, send bytes 0 recv bytes 0
- bitrev_b, executed 62 times, duration 0.0015003s, send bytes 0 recv bytes 0
- and_bp, executed 450 times, duration 0.0014981s, send bytes 0 recv bytes 0
- transpose, executed 329 times, duration 0.0014867s, send bytes 0 recv bytes 0
- lshift_b, executed 150 times, duration 0.0004722s, send bytes 0 recv bytes 0
- p2b, executed 7 times, duration 0.0002253s, send bytes 0 recv bytes 0
- mul_pp, executed 26 times, duration 0.0001847s, send bytes 0 recv bytes 0
- lshift_p, executed 1 times, duration 5.5e-06s, send bytes 0 recv bytes 0
Link details: total send bytes 5218839444, recv bytes 5218839444, send actions 4508
builtin_fetch_meta at node:0
Run on PUMA:
HAL profiling: total time 103.38578369999999
- f_mmul, executed 337 times, duration 64.1608951s, send bytes 2854566584 recv bytes 2854566584
- mixed_mmul, executed 7 times, duration 29.109483s, send bytes 2167081840 recv bytes 2167081840
- f_tanh, executed 12 times, duration 3.9648723s, send bytes 101154816 recv bytes 101154816
- f_mul, executed 234 times, duration 1.1794851s, send bytes 34972504 recv bytes 34972504
- i_less, executed 56 times, duration 1.0440972s, send bytes 2817136 recv bytes 2817136
- f_less, executed 107 times, duration 0.847822s, send bytes 3558016 recv bytes 3558016
- _mux, executed 108 times, duration 0.7910354s, send bytes 2687352 recv bytes 2687352
- i_equal, executed 7 times, duration 0.6048745s, send bytes 36587096 recv bytes 36587096
- f_div, executed 12 times, duration 0.4905115s, send bytes 4064256 recv bytes 4064256
- f_rsqrt, executed 25 times, duration 0.4837378s, send bytes 82600 recv bytes 82600
- f_equal, executed 42 times, duration 0.3145577s, send bytes 10453456 recv bytes 10453456
- f_add, executed 707 times, duration 0.1307882s, send bytes 0 recv bytes 0
- f_exp, executed 12 times, duration 0.0876748s, send bytes 903168 recv bytes 903168
- mixed_mul, executed 37 times, duration 0.0672976s, send bytes 173544 recv bytes 173544
- _and, executed 64 times, duration 0.0540071s, send bytes 301556 recv bytes 301556
- f_sub, executed 62 times, duration 0.0288207s, send bytes 0 recv bytes 0
- logical_not, executed 25 times, duration 0.0187574s, send bytes 0 recv bytes 0
- _xor, executed 84 times, duration 0.0037669s, send bytes 0 recv bytes 0
- seal, executed 19 times, duration 0.0032152s, send bytes 0 recv bytes 0
- i_add, executed 15 times, duration 7.58e-05s, send bytes 0 recv bytes 0
- int2fxp, executed 1 times, duration 8.4e-06s, send bytes 0 recv bytes 0
MPC profiling: total time 103.3568216
- mmul_aa, executed 413 times, duration 88.8471574s, send bytes 5010802232 recv bytes 5010802232
- b2a, executed 292 times, duration 4.2603825s, send bytes 8070336 recv bytes 8070336
- msb_a2b, executed 180 times, duration 4.0227752s, send bytes 35671664 recv bytes 35671664
- trunc_a, executed 1034 times, duration 2.7773329s, send bytes 47615848 recv bytes 47615848
- mul_aa, executed 729 times, duration 1.426478s, send bytes 67761360 recv bytes 67761360
- equal_sp, executed 7 times, duration 0.6044354s, send bytes 36587096 recv bytes 36587096
- equal_ss, executed 42 times, duration 0.3142665s, send bytes 10453456 recv bytes 10453456
- and_bb, executed 286 times, duration 0.2071461s, send bytes 937884 recv bytes 937884
- a2b, executed 37 times, duration 0.1926027s, send bytes 1504048 recv bytes 1504048
- add_aa, executed 1225 times, duration 0.1651116s, send bytes 0 recv bytes 0
- not_a, executed 456 times, duration 0.126297s, send bytes 0 recv bytes 0
- add_pp, executed 194 times, duration 0.1226415s, send bytes 0 recv bytes 0
- not_p, executed 132 times, duration 0.080359s, send bytes 0 recv bytes 0
- add_ap, executed 816 times, duration 0.0779336s, send bytes 0 recv bytes 0
- mul_ap, executed 285 times, duration 0.0305726s, send bytes 0 recv bytes 0
- concatenate, executed 44 times, duration 0.0297125s, send bytes 0 recv bytes 0
- extract_slice, executed 2077 times, duration 0.0176818s, send bytes 0 recv bytes 0
- xor_bb, executed 965 times, duration 0.0109251s, send bytes 0 recv bytes 0
- pad, executed 7 times, duration 0.010752s, send bytes 0 recv bytes 0
- reshape, executed 1550 times, duration 0.0050017s, send bytes 0 recv bytes 0
- make_p, executed 1122 times, duration 0.0043944s, send bytes 0 recv bytes 0
- mmul_ap, executed 48 times, duration 0.0043419s, send bytes 0 recv bytes 0
- msb_p, executed 19 times, duration 0.0036313s, send bytes 0 recv bytes 0
- p2a, executed 12 times, duration 0.0028051s, send bytes 0 recv bytes 0
- rshift_b, executed 534 times, duration 0.0027121s, send bytes 0 recv bytes 0
- bitrev_b, executed 62 times, duration 0.0021476s, send bytes 0 recv bytes 0
- transpose, executed 329 times, duration 0.0017387s, send bytes 0 recv bytes 0
- and_bp, executed 450 times, duration 0.0015348s, send bytes 0 recv bytes 0
- xor_bp, executed 60 times, duration 0.0013668s, send bytes 0 recv bytes 0
- update_slice, executed 14 times, duration 0.0009198s, send bytes 0 recv bytes 0
- broadcast, executed 203 times, duration 0.0006957s, send bytes 0 recv bytes 0
- lshift_b, executed 150 times, duration 0.0005044s, send bytes 0 recv bytes 0
- p2b, executed 7 times, duration 0.0003276s, send bytes 0 recv bytes 0
- mul_pp, executed 26 times, duration 0.0001304s, send bytes 0 recv bytes 0
- lshift_p, executed 1 times, duration 5.9e-06s, send bytes 0 recv bytes 0
Link details: total send bytes 5219403924, recv bytes 5219403924, send actions 4616
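To compare the two runs op by op, one can put the per-op HAL durations side by side and take the ratio. The sketch below hard-codes five rows copied from the two logs above; the helper `time_ratios` is hypothetical. For several matching ops (for example f_mmul, mixed_mmul and f_tanh) the send/recv byte counts are identical in both runs, so the difference shows up in duration rather than in traffic.

```python
# Per-op HAL durations in seconds, copied from the two logs above.
SPU_HAL = {
    "f_mmul": 228.7568215,
    "mixed_mmul": 111.0405614,
    "f_tanh": 17.9802313,
    "f_rsqrt": 16.8606845,
    "f_div": 10.8492046,
}
PUMA_HAL = {
    "f_mmul": 64.1608951,
    "mixed_mmul": 29.109483,
    "f_tanh": 3.9648723,
    "f_rsqrt": 0.4837378,
    "f_div": 0.4905115,
}


def time_ratios(slow: dict, fast: dict) -> dict:
    """Return slow/fast duration ratio for every op present in both profiles."""
    return {op: slow[op] / fast[op]
            for op in slow.keys() & fast.keys() if fast[op] > 0}


if __name__ == "__main__":
    ratios = time_ratios(SPU_HAL, PUMA_HAL)
    # Print ops from the largest slowdown to the smallest.
    for op, r in sorted(ratios.items(), key=lambda kv: kv[1], reverse=True):
        print(f"{op:12s} {r:6.1f}x")
```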
Finally, I would like to ask: if I want to explore how different activation functions and softmax functions affect efficiency, which metrics should I look at first (e.g., HAL profiling or MPC profiling; mmul_aa or mmul_ap)?
Basically, it is difficult to see the cost of a specific function, since it has been broken down into smaller operations.
But your latest PUMA run seems weird: it should not contain f_div.
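Because a function such as softmax or GeLU is lowered into many smaller HAL/MPC ops, one practical approach is to rank the ops in a profile by their share of total time and compare two runs that differ only in the function being swapped. Below is a small parser for profiling lines in the format printed above; the regex and the helpers `parse_profile` and `time_shares` are my own, not SPU utilities. Shares are relative to the sum of the parsed durations, which assumes the listed ops do not overlap.

```python
import re

# Matches lines such as
# "- f_mmul, executed 337 times, duration 228.7568215s, send bytes 2854566584 recv bytes 2854566584"
# re.search also tolerates a leading "[time] [info] [api.cc:212]" prefix.
LINE_RE = re.compile(
    r"- (?P<op>[\w.]+), executed (?P<count>\d+) times, "
    r"duration (?P<dur>[0-9.eE+-]+)s, "
    r"send bytes (?P<send>\d+) recv bytes (?P<recv>\d+)"
)


def parse_profile(text: str) -> dict:
    """Map op name -> (count, duration_s, send_bytes, recv_bytes)."""
    ops = {}
    for line in text.splitlines():
        m = LINE_RE.search(line)
        if m:
            ops[m["op"]] = (int(m["count"]), float(m["dur"]),
                            int(m["send"]), int(m["recv"]))
    return ops


def time_shares(ops: dict) -> list:
    """Return (op, fraction of summed duration), largest first."""
    total = sum(dur for _, dur, _, _ in ops.values())
    return sorted(((op, dur / total) for op, (_, dur, _, _) in ops.items()),
                  key=lambda kv: kv[1], reverse=True)


if __name__ == "__main__":
    sample = """\
- f_mmul, executed 337 times, duration 228.7568215s, send bytes 2854566584 recv bytes 2854566584
- mixed_mmul, executed 7 times, duration 111.0405614s, send bytes 2167081840 recv bytes 2167081840
- f_tanh, executed 12 times, duration 17.9802313s, send bytes 101154816 recv bytes 101154816"""
    for op, share in time_shares(parse_profile(sample)):
        print(f"{op:12s} {share:6.1%}")
```

Running it on the HAL section of the SPU log, for example, lists f_mmul first.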
@fionser @Ye-D
In puma_gpt2.py: copts.enable_optimize_denominator_with_broadcast = True
In both docker containers, modeling_flax_gpt2.py:
But there's still f_div in the log. What does it mean to have f_div in the log?
In your old log, f_div is never called; it correctly calls f_reciprocal.
But in the latter log, f_div pops up.
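For context on why f_reciprocal should replace f_div: as I understand it, the option name enable_optimize_denominator_with_broadcast refers to a compiler rewrite of x / broadcast(y) into x * broadcast(1 / y), so a softmax denominator of shape (rows,) is inverted once per row instead of dividing every element. The plain-Python sketch below only illustrates that arithmetic and counts the expensive operations; it is not SPU's implementation, and both function names are hypothetical.

```python
import math


def softmax_rows_div(x):
    """Row-wise softmax using one division per element (x / broadcast(y))."""
    out, divisions = [], 0
    for row in x:
        m = max(row)                          # subtract max for numerical stability
        e = [math.exp(v - m) for v in row]
        s = sum(e)
        out.append([v / s for v in e])        # element-wise division
        divisions += len(e)
    return out, divisions


def softmax_rows_reciprocal(x):
    """Row-wise softmax using one reciprocal per row (x * broadcast(1 / y))."""
    out, reciprocals = [], 0
    for row in x:
        m = max(row)
        e = [math.exp(v - m) for v in row]
        inv = 1.0 / sum(e)                    # one reciprocal per row
        reciprocals += 1
        out.append([v * inv for v in e])      # element-wise multiplication
    return out, reciprocals


if __name__ == "__main__":
    x = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]
    a, n_div = softmax_rows_div(x)
    b, n_rec = softmax_rows_reciprocal(x)
    assert all(math.isclose(p, q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))
    print(n_div, n_rec)  # prints: 6 2
```

In the SPU log above, 12 f_div calls took about 10.8 s while 234 f_mul calls took about 10.1 s, which hints at why replacing element-wise secret division with a per-row reciprocal matters.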
@Ye-D @fionser Thank you for your patient reply. After checking the code, I found that I had not passed copts to the SPU device call:
with hack_softmax_context("hijack jax softmax", enabled=True), hack_gelu_context("hijack jax gelu", enabled=True):
    input_ids = ppd.device("P1")(lambda x: x)(inputs_ids)
    params = ppd.device("P2")(lambda x: x)(pretrained_model.params)
    # wrong, copts not passed: outputs_ids = ppd.device("SPU")(text_generation,)(input_ids, params)
    # fixed: pass copts so the compiler options (e.g. enable_optimize_denominator_with_broadcast) take effect
    outputs_ids = ppd.device("SPU")(text_generation, copts=copts)(input_ids, params)
    outputs_ids = ppd.get(outputs_ids)
After modifying the code, the log is as follows (ABY3 + PUMA, generating one token):
[09:20:21,732] [MainProcess] Run : builtin_spu_run at node:0
[09:20:50.723] [info] [thread_pool.cc:30] Create a fixed thread pool with size 7
[09:23:18.501] [info] [api.cc:163] [Profiling] SPU execution text_generation completed, input processing took 7.4e-05s, execution took 148.4880594s, output processing took 4.1e-06s, total time 148.4881375s.
[09:23:18.514] [info] [api.cc:209] HLO profiling: total time 0.0013653000000000003
[09:23:18.521] [info] [api.cc:212] - pphlo.constant, executed 31 times, duration 0.0008638s, send bytes 0 recv bytes 0
[09:23:18.523] [info] [api.cc:212] - pphlo.free, executed 1931 times, duration 0.0002309s, send bytes 0 recv bytes 0
[09:23:18.525] [info] [api.cc:212] - pphlo.add, executed 806 times, duration 8.27e-05s, send bytes 0 recv bytes 0
[09:23:18.527] [info] [api.cc:212] - pphlo.multiply, executed 379 times, duration 4.22e-05s, send bytes 0 recv bytes 0
[09:23:18.528] [info] [api.cc:212] - pphlo.reshape, executed 319 times, duration 3.11e-05s, send bytes 0 recv bytes 0
[09:23:18.529] [info] [api.cc:212] - pphlo.broadcast, executed 177 times, duration 1.76e-05s, send bytes 0 recv bytes 0
[09:23:18.529] [info] [api.cc:212] - pphlo.greater, executed 121 times, duration 1.38e-05s, send bytes 0 recv bytes 0
[09:23:18.529] [info] [api.cc:212] - pphlo.select, executed 94 times, duration 1.14e-05s, send bytes 0 recv bytes 0
[09:23:18.530] [info] [api.cc:212] - pphlo.transpose, executed 109 times, duration 1.08e-05s, send bytes 0 recv bytes 0
[09:23:18.530] [info] [api.cc:212] - pphlo.reduce, executed 75 times, duration 8.3e-06s, send bytes 0 recv bytes 0
[09:23:18.531] [info] [api.cc:212] - pphlo.less, executed 64 times, duration 7.6e-06s, send bytes 0 recv bytes 0
[09:23:18.532] [info] [api.cc:212] - pphlo.equal, executed 42 times, duration 5.5e-06s, send bytes 0 recv bytes 0
[09:23:18.532] [info] [api.cc:212] - pphlo.subtract, executed 62 times, duration 5.5e-06s, send bytes 0 recv bytes 0
[09:23:18.537] [info] [api.cc:212] - pphlo.xor, executed 36 times, duration 4.8e-06s, send bytes 0 recv bytes 0
[09:23:18.538] [info] [api.cc:212] - pphlo.dot, executed 49 times, duration 4.6e-06s, send bytes 0 recv bytes 0
[09:23:18.539] [info] [api.cc:212] - pphlo.slice, executed 38 times, duration 3.9e-06s, send bytes 0 recv bytes 0
[09:23:18.539] [info] [api.cc:212] - pphlo.or, executed 42 times, duration 3.9e-06s, send bytes 0 recv bytes 0
[09:23:18.539] [info] [api.cc:212] - pphlo.rsqrt, executed 25 times, duration 2.7e-06s, send bytes 0 recv bytes 0
[09:23:18.540] [info] [api.cc:212] - pphlo.not, executed 25 times, duration 2.7e-06s, send bytes 0 recv bytes 0
[09:23:18.541] [info] [api.cc:212] - pphlo.and, executed 22 times, duration 2.4e-06s, send bytes 0 recv bytes 0
[09:23:18.541] [info] [api.cc:212] - pphlo.dynamic_slice, executed 28 times, duration 2.2e-06s, send bytes 0 recv bytes 0
[09:23:18.543] [info] [api.cc:212] - pphlo.dot_general, executed 24 times, duration 2.2e-06s, send bytes 0 recv bytes 0
[09:23:18.543] [info] [api.cc:212] - pphlo.dynamic_update_slice, executed 14 times, duration 1.5e-06s, send bytes 0 recv bytes 0
[09:23:18.544] [info] [api.cc:212] - pphlo.reciprocal, executed 12 times, duration 1.1e-06s, send bytes 0 recv bytes 0
[09:23:18.546] [info] [api.cc:212] - pphlo.exponential, executed 12 times, duration 9e-07s, send bytes 0 recv bytes 0
[09:23:18.547] [info] [api.cc:212] - pphlo.convert, executed 6 times, duration 6e-07s, send bytes 0 recv bytes 0
[09:23:18.547] [info] [api.cc:212] - pphlo.iota, executed 3 times, duration 3e-07s, send bytes 0 recv bytes 0
[09:23:18.548] [info] [api.cc:212] - pphlo.while, executed 2 times, duration 2e-07s, send bytes 0 recv bytes 0
[09:23:18.549] [info] [api.cc:212] - pphlo.concatenate, executed 1 times, duration 1e-07s, send bytes 0 recv bytes 0
[09:23:18.551] [info] [api.cc:209] HAL profiling: total time 147.88829149999995
[09:23:18.552] [info] [api.cc:212] - f_mmul, executed 337 times, duration 117.4234551s, send bytes 27935320 recv bytes 32479160
[09:23:18.562] [info] [api.cc:212] - f_less, executed 143 times, duration 9.1182813s, send bytes 30156480 recv bytes 30156480
[09:23:18.563] [info] [api.cc:212] - mixed_mmul, executed 7 times, duration 6.6586203s, send bytes 3259456 recv bytes 6475904
[09:23:18.567] [info] [api.cc:212] - f_mul, executed 306 times, duration 3.4422003s, send bytes 84634424 recv bytes 77438648
[09:23:18.569] [info] [api.cc:212] - f_rsqrt, executed 25 times, duration 1.9667374s, send bytes 105252 recv bytes 109284
[09:23:18.570] [info] [api.cc:212] - f_exp, executed 12 times, duration 1.7966096s, send bytes 4619328 recv bytes 4299456
[09:23:18.570] [info] [api.cc:212] - i_less, executed 56 times, duration 1.6316855s, send bytes 1811016 recv bytes 1811016
[09:23:18.571] [info] [api.cc:212] - mixed_mul, executed 73 times, duration 1.4201196s, send bytes 18753000 recv bytes 6251000
[09:23:18.573] [info] [api.cc:212] - f_reciprocal, executed 12 times, duration 1.1650202s, send bytes 696864 recv bytes 717024
[09:23:18.573] [info] [api.cc:212] - f_add, executed 791 times, duration 1.0291372s, send bytes 0 recv bytes 0
[09:23:18.574] [info] [api.cc:212] - _mux, executed 108 times, duration 0.6971559s, send bytes 7555344 recv bytes 5422632
[09:23:18.577] [info] [api.cc:212] - i_equal, executed 7 times, duration 0.5983752s, send bytes 6734438 recv bytes 8342662
[09:23:18.578] [info] [api.cc:212] - f_equal, executed 42 times, duration 0.5677444s, send bytes 2290652 recv bytes 1436508
[09:23:18.578] [info] [api.cc:212] - _and, executed 64 times, duration 0.1560499s, send bytes 150778 recv bytes 150778
[09:23:18.579] [info] [api.cc:212] - f_sub, executed 62 times, duration 0.0981315s, send bytes 0 recv bytes 0
[09:23:18.582] [info] [api.cc:212] - logical_not, executed 25 times, duration 0.0650129s, send bytes 0 recv bytes 0
[09:23:18.584] [info] [api.cc:212] - _xor, executed 120 times, duration 0.0530377s, send bytes 0 recv bytes 0
[09:23:18.586] [info] [api.cc:212] - seal, executed 19 times, duration 0.0008418s, send bytes 0 recv bytes 0
[09:23:18.587] [info] [api.cc:212] - i_add, executed 15 times, duration 6.64e-05s, send bytes 0 recv bytes 0
[09:23:18.587] [info] [api.cc:212] - int2fxp, executed 1 times, duration 9.3e-06s, send bytes 0 recv bytes 0
[09:23:18.589] [info] [api.cc:209] MPC profiling: total time 147.82688670000002
[09:23:18.590] [info] [api.cc:212] - mmul_aa, executed 434 times, duration 122.6240956s, send bytes 8074808 recv bytes 8074808
[09:23:18.591] [info] [api.cc:212] - msb_a2b, executed 216 times, duration 10.7372751s, send bytes 32511816 recv bytes 32511816
[09:23:18.592] [info] [api.cc:212] - trunc_a, executed 998 times, duration 4.082622s, send bytes 94263232 recv bytes 91568064
[09:23:18.593] [info] [api.cc:212] - mul_aa, executed 580 times, duration 2.1596241s, send bytes 12942168 recv bytes 12942168
[09:23:18.594] [info] [api.cc:212] - mul_a1b, executed 173 times, duration 1.9367161s, send bytes 19536048 recv bytes 6512016
[09:23:18.595] [info] [api.cc:212] - b2a, executed 232 times, duration 1.8543133s, send bytes 11223592 recv bytes 12576912
[09:23:18.596] [info] [api.cc:212] - add_aa, executed 1350 times, duration 0.9763371s, send bytes 0 recv bytes 0
[09:23:18.596] [info] [api.cc:212] - a2b, executed 49 times, duration 0.6459768s, send bytes 922768 recv bytes 922768
[09:23:18.597] [info] [api.cc:212] - equal_sp, executed 7 times, duration 0.5979153s, send bytes 6734438 recv bytes 8342662
[09:23:18.599] [info] [api.cc:212] - equal_ss, executed 42 times, duration 0.567353s, send bytes 2290652 recv bytes 1436508
[09:23:18.601] [info] [api.cc:212] - and_bb, executed 286 times, duration 0.4620119s, send bytes 202830 recv bytes 202830
[09:23:18.602] [info] [api.cc:212] - add_pp, executed 194 times, duration 0.2800023s, send bytes 0 recv bytes 0
[09:23:18.603] [info] [api.cc:212] - mul_ap, executed 441 times, duration 0.2071876s, send bytes 0 recv bytes 0
[09:23:18.610] [info] [api.cc:212] - not_p, executed 132 times, duration 0.2032027s, send bytes 0 recv bytes 0
[09:23:18.611] [info] [api.cc:212] - add_ap, executed 984 times, duration 0.1868088s, send bytes 0 recv bytes 0
[09:23:18.613] [info] [api.cc:212] - not_a, executed 468 times, duration 0.1229786s, send bytes 0 recv bytes 0
[09:23:18.614] [info] [api.cc:212] - xor_bb, executed 989 times, duration 0.0377837s, send bytes 0 recv bytes 0
[09:23:18.616] [info] [api.cc:212] - extract_slice, executed 2009 times, duration 0.0271404s, send bytes 0 recv bytes 0
[09:23:18.617] [info] [api.cc:212] - xor_bp, executed 84 times, duration 0.0210546s, send bytes 0 recv bytes 0
[09:23:18.617] [info] [api.cc:212] - cast_type_b, executed 7 times, duration 0.0169029s, send bytes 0 recv bytes 0
[09:23:18.618] [info] [api.cc:212] - pad, executed 7 times, duration 0.0168406s, send bytes 0 recv bytes 0
[09:23:18.619] [info] [api.cc:212] - update_slice, executed 14 times, duration 0.0117053s, send bytes 0 recv bytes 0
[09:23:18.620] [info] [api.cc:212] - msb_p, executed 19 times, duration 0.0106217s, send bytes 0 recv bytes 0
[09:23:18.621] [info] [api.cc:212] - reshape, executed 1526 times, duration 0.0096493s, send bytes 0 recv bytes 0
[09:23:18.621] [info] [api.cc:212] - broadcast, executed 212 times, duration 0.008624s, send bytes 0 recv bytes 0
[09:23:18.622] [info] [api.cc:212] - make_p, executed 1326 times, duration 0.0075926s, send bytes 0 recv bytes 0
[09:23:18.624] [info] [api.cc:212] - and_bp, executed 510 times, duration 0.0041056s, send bytes 0 recv bytes 0
[09:23:18.625] [info] [api.cc:212] - rshift_b, executed 618 times, duration 0.0029436s, send bytes 0 recv bytes 0
[09:23:18.626] [info] [api.cc:212] - transpose, executed 185 times, duration 0.0025271s, send bytes 0 recv bytes 0
[09:23:18.629] [info] [api.cc:212] - concatenate, executed 32 times, duration 0.0024656s, send bytes 0 recv bytes 0
[09:23:18.629] [info] [api.cc:212] - bitrev_b, executed 62 times, duration 0.0009247s, send bytes 0 recv bytes 0
[09:23:18.632] [info] [api.cc:212] - p2a, executed 12 times, duration 0.000713s, send bytes 0 recv bytes 0
[09:23:18.635] [info] [api.cc:212] - lshift_b, executed 162 times, duration 0.0005165s, send bytes 0 recv bytes 0
[09:23:18.636] [info] [api.cc:212] - mul_pp, executed 26 times, duration 0.0002821s, send bytes 0 recv bytes 0
[09:23:18.645] [info] [api.cc:212] - p2b, executed 7 times, duration 5.01e-05s, send bytes 0 recv bytes 0
[09:23:18.648] [info] [api.cc:212] - common_type_b, executed 7 times, duration 1.61e-05s, send bytes 0 recv bytes 0
[09:23:18.649] [info] [api.cc:212] - lshift_p, executed 1 times, duration 6.9e-06s, send bytes 0 recv bytes 0
[09:23:18.650] [info] [api.cc:222] Link details: total send bytes 188702352, recv bytes 175090552, send actions 7259
[09:23:18,983] [MainProcess] RunR: builtin_fetch_meta at node:0
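A quick way to confirm that passing copts took effect is to compare the set of HAL ops in two logs: after the fix, f_div should be gone and f_reciprocal should appear. The sketch below extracts op names with a regex of my own (not an SPU utility) and reports the difference; the two sample strings are abbreviated excerpts of the PUMA log and the ABY3 log above.

```python
import re

# Captures the op name in lines like "- f_div, executed 12 times, ...",
# with or without a leading timestamp/logger prefix.
OP_RE = re.compile(r"- ([\w.]+), executed \d+ times")


def op_names(log_text: str) -> set:
    """Collect op names from SPU profiling lines."""
    return set(OP_RE.findall(log_text))


def op_diff(before: str, after: str) -> tuple:
    """Return (ops only in `before`, ops only in `after`), each sorted."""
    b, a = op_names(before), op_names(after)
    return sorted(b - a), sorted(a - b)


if __name__ == "__main__":
    before = """\
- f_div, executed 12 times, duration 0.4905115s, send bytes 4064256 recv bytes 4064256
- f_exp, executed 12 times, duration 0.0876748s, send bytes 903168 recv bytes 903168"""
    after = """\
[09:23:18.570] [info] [api.cc:212] - f_exp, executed 12 times, duration 1.7966096s, send bytes 4619328 recv bytes 4299456
[09:23:18.573] [info] [api.cc:212] - f_reciprocal, executed 12 times, duration 1.1650202s, send bytes 696864 recv bytes 717024"""
    removed, added = op_diff(before, after)
    print("removed:", removed)  # removed: ['f_div']
    print("added:", added)      # added: ['f_reciprocal']
```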
Issue Type: Build/Install
Modules Involved: SPU runtime
Have you reproduced the bug with SPU HEAD? Yes
Have you searched existing issues? Yes
SPU Version: from latest source code
OS Platform and Distribution: Linux Ubuntu 22.04
Python Version: 3.10
Compiler Version: GCC 11.4.0
Current Behavior?
I am trying to reproduce puma_gpt2_benchmark. Reproduction steps:
1) Following https://github.com/secretflow/spu/blob/main/CONTRIBUTING.md#prerequisite, pull secretflow/ubuntu-base-ci:latest and start the spu-dev-$(whoami) container.
2) git clone https://github.com/secretflow/spu.git && cd spu to enter the spu folder.
3) python3 -m pip install -r requirements.txt and python3 -m pip install -r requirements-dev.txt
4) Copy the puma_gpt2_benchmarks folder into spu/examples/python/ml/.
5) pip install 'transformers[flax]'
Everything works up to this step. The pip list inside the docker container is as follows:
(base) root@9022304ba02a:/home/admin/dev# pip list
Package Version
absl-py 1.4.0
aiohttp 3.9.3
aiosignal 1.3.1
archspec 0.2.1
array-record 0.5.0
async-timeout 4.0.3
attrs 23.2.0
beautifulsoup4 4.12.3
boltons 23.0.0
Brotli 1.0.9
cachetools 5.3.3
certifi 2023.7.22
cffi 1.15.1
charset-normalizer 2.0.4
chex 0.1.82
click 8.1.7
cloudpickle 3.0.0
conda 23.10.0
conda-content-trust 0.2.0
conda-libmamba-solver 23.11.1
conda-package-handling 2.2.0
conda_package_streaming 0.9.0
cryptography 41.0.3
datasets 2.18.0
dill 0.3.8
dm-tree 0.1.8
etils 1.7.0
evaluate 0.4.1
filelock 3.13.1
flax 0.7.0
frozenlist 1.4.1
fsspec 2024.2.0
googleapis-common-protos 1.62.0
grpcio 1.62.0
huggingface-hub 0.21.3
idna 3.4
importlib_resources 6.1.2
jax 0.4.13
jaxlib 0.4.13
Jinja2 3.1.3
jsonpatch 1.32
jsonpointer 2.1
libmambapy 1.5.3
markdown-it-py 3.0.0
MarkupSafe 2.1.5
mdurl 0.1.2
ml-dtypes 0.3.2
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
multiprocess 0.70.16
nest-asyncio 1.6.0
networkx 3.2.1
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.3.101
nvidia-nvtx-cu12 12.1.105
opt-einsum 3.3.0
optax 0.1.4
orbax-checkpoint 0.5.3
packaging 23.1
pandas 2.2.1
pillow 10.2.0
pip 23.3
pluggy 1.0.0
promise 2.3
protobuf 3.20.3
psutil 5.9.8
pyarrow 15.0.0
pyarrow-hotfix 0.6
pycosat 0.6.6
pycparser 2.21
Pygments 2.17.2
pyOpenSSL 23.2.0
PySocks 1.7.1
python-dateutil 2.9.0.post0
pytz 2024.1
PyYAML 6.0.1
regex 2023.12.25
requests 2.31.0
responses 0.18.0
rich 13.7.1
ruamel.yaml 0.17.21
ruamel.yaml.clib 0.2.6
safetensors 0.4.2
scipy 1.12.0
setuptools 68.0.0
six 1.16.0
soupsieve 2.5
sympy 1.12
tensorflow-datasets 4.9.4
tensorflow-metadata 1.14.0
tensorstore 0.1.54
termcolor 2.4.0
tokenizers 0.15.2
toml 0.10.2
toolz 0.12.1
torch 2.2.1
torchaudio 2.2.1
torchvision 0.17.1
tqdm 4.65.0
transformers 4.38.2
triton 2.2.0
truststore 0.8.0
typing_extensions 4.10.0
tzdata 2024.1
urllib3 1.26.18
wheel 0.41.2
wrapt 1.16.0
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0
zstandard 0.19.0
6) bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/puma_bert_benchmarks/3pc.json up
Standalone code to reproduce the issue
Relevant log output