secretflow / spu

SPU (Secure Processing Unit) aims to be a provable, measurable secure computation device, which provides computation ability while keeping your private data protected.
https://www.secretflow.org.cn/docs/spu/en/
Apache License 2.0

[Question]: Evaluating the model with significantly reduced number of ReLUs receives limited communication savings #701

Closed: warpoons closed this issue 3 months ago

warpoons commented 4 months ago

Issue Type

Performance

Modules Involved

SPU runtime

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

spu 0.9.0.dev20240311

OS Platform and Distribution

Ubuntu 18.04.6 LTS by WSL

Python Version

3.10

Compiler Version

GCC 11.3.0

Current Behavior?

Hi! Dear SPU developers, thanks for your great work. I am interested in PPML optimization and have measured the communication and time costs of private ResNet18 on cifar100 using SPU. Since non-linear computations are always the difficult part of MPC, several existing papers, such as DeepReDuce, SNL, SENet and AutoReP, improve PPML efficiency by reducing the number of ReLUs. Therefore, I also measured the communication and time costs of private ResNet18 on cifar10, but with significantly fewer ReLUs (only 50000). However, I observed only a limited communication reduction and an increase in inference latency. The original ResNet18 on cifar100 has 491520 ReLUs, so reducing the ReLU count to 50000 (~10x fewer) did not yield the expected efficiency improvement. I am wondering what the reason is.

In addition, the inference latency increased from approximately 0.5 seconds to 3.5 seconds. Is this because replacing ReLUs with the identity function at individual pixel positions breaks the parallelism of the original computation?

Thanks!

Standalone code to reproduce the issue

N/A

Relevant log output

Logs for the optimized and original models:
========== Optimized model ==========
[2024-05-23 14:30:05.280] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 1.6538e-05s, execution took 3.466021983s, output processing took 4.8568e-05s, total time 3.466087089s.
[2024-05-23 14:30:05.367] [info] [api.cc:209] HLO profiling: total time 0.013780189999999998
[2024-05-23 14:30:05.367] [info] [api.cc:212] - pphlo.free, executed 176196 times, duration 0.006073516s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.367] [info] [api.cc:212] - pphlo.reshape, executed 34541 times, duration 0.001168313s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.367] [info] [api.cc:212] - pphlo.dynamic_slice, executed 34740 times, duration 0.001163929s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.367] [info] [api.cc:212] - pphlo.and, executed 23339 times, duration 0.000880431s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.367] [info] [api.cc:212] - pphlo.less, executed 23252 times, duration 0.00080472s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.367] [info] [api.cc:212] - pphlo.not, executed 23196 times, duration 0.000793692s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.367] [info] [api.cc:212] - pphlo.add, executed 20052 times, duration 0.000715629s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.367] [info] [api.cc:212] - pphlo.slice, executed 19326 times, duration 0.000653541s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.367] [info] [api.cc:212] - pphlo.select, executed 11723 times, duration 0.000400804s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.367] [info] [api.cc:212] - pphlo.greater, executed 11629 times, duration 0.000395687s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.dynamic_update_slice, executed 11580 times, duration 0.000394821s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.concatenate, executed 3980 times, duration 0.000137244s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.reduce, executed 3930 times, duration 0.000133821s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.constant, executed 197 times, duration 9.532e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.multiply, executed 226 times, duration 9.438e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.shift_right_logical, executed 176 times, duration 6.209e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.xor, executed 166 times, duration 5.701e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.shift_left, executed 160 times, duration 5.632e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.or, executed 160 times, duration 5.604e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.pad, executed 151 times, duration 5.449e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.broadcast, executed 66 times, duration 2.88e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.negate, executed 59 times, duration 2.101e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.subtract, executed 40 times, duration 1.8e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.equal, executed 40 times, duration 1.479e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.while, executed 24 times, duration 1.103e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.rsqrt, executed 20 times, duration 9.9e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.prefer_a, executed 26 times, duration 9.84e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.transpose, executed 21 times, duration 9.77e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.convolution, executed 20 times, duration 9.12e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.custom_call: pphlo.gather, executed 16 times, duration 8.25e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.convert, executed 18 times, duration 7.11e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.sign, executed 13 times, duration 5.19e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.reverse, executed 11 times, duration 4.52e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.iota, executed 9 times, duration 3.89e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.simple_sort, executed 6 times, duration 2.66e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.reduce_window, executed 1 times, duration 4.7e-08s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pphlo.dot, executed 1 times, duration 4.2e-08s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:209] HAL profiling: total time 2.0239852999999997
[2024-05-23 14:30:05.368] [info] [api.cc:212] - i_less, executed 220157 times, duration 1.274664774s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - f_tensordot, executed 20 times, duration 0.208780164s, send bytes 47269376 recv bytes 47269376
[2024-05-23 14:30:05.368] [info] [api.cc:212] - logical_not, executed 23196 times, duration 0.114765134s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - f_rsqrt, executed 20 times, duration 0.108781809s, send bytes 2265600 recv bytes 2265600
[2024-05-23 14:30:05.368] [info] [api.cc:212] - _mux, executed 11736 times, duration 0.096550162s, send bytes 786432 recv bytes 786432
[2024-05-23 14:30:05.368] [info] [api.cc:212] - f_less, executed 41 times, duration 0.090282452s, send bytes 3196032 recv bytes 3196032
[2024-05-23 14:30:05.368] [info] [api.cc:212] - f_mul, executed 185 times, duration 0.050956911s, send bytes 2791424 recv bytes 2791424
[2024-05-23 14:30:05.368] [info] [api.cc:212] - _and, executed 23499 times, duration 0.029725783s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - i_add, executed 19835 times, duration 0.025670492s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - mixed_mul, executed 37 times, duration 0.013260993s, send bytes 583296 recv bytes 583296
[2024-05-23 14:30:05.368] [info] [api.cc:212] - f_add, executed 243 times, duration 0.002835047s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - f_sub, executed 40 times, duration 0.002320518s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - _xor, executed 486 times, duration 0.002201032s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - f_mmul, executed 1 times, duration 0.001431062s, send bytes 414496 recv bytes 414496
[2024-05-23 14:30:05.368] [info] [api.cc:212] - _rshift, executed 176 times, duration 0.000546018s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - _lshift, executed 160 times, duration 0.000530978s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - i_negate, executed 59 times, duration 0.000271372s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - _sign, executed 13 times, duration 0.000177216s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - i_equal, executed 49 times, duration 0.000125857s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - f_equal, executed 4 times, duration 3.939e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - i_mul, executed 4 times, duration 3.8875e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - seal, executed 2 times, duration 2.9261e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:209] MPC profiling: total time 1.6191874379999998
[2024-05-23 14:30:05.368] [info] [api.cc:212] - extract_slice, executed 815914 times, duration 0.333899862s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - add_pp, executed 532502 times, duration 0.332938109s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - mmul_aa, executed 21 times, duration 0.171460546s, send bytes 47277568 recv bytes 47277568
[2024-05-23 14:30:05.368] [info] [api.cc:212] - not_p, executed 252057 times, duration 0.136484737s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - msb_p, executed 220170 times, duration 0.10026152s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - make_p, executed 278894 times, duration 0.100209974s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - trunc_a, executed 346 times, duration 0.093930745s, send bytes 1731872 recv bytes 1731872
[2024-05-23 14:30:05.368] [info] [api.cc:212] - msb_a2b, executed 41 times, duration 0.087471444s, send bytes 3196032 recv bytes 3196032
[2024-05-23 14:30:05.368] [info] [api.cc:212] - mul_aa, executed 236 times, duration 0.042093378s, send bytes 3108608 recv bytes 3108608
[2024-05-23 14:30:05.368] [info] [api.cc:212] - update_slice, executed 11580 times, duration 0.040055488s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - concatenate, executed 3996 times, duration 0.028051726s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - a2b, executed 20 times, duration 0.026938062s, send bytes 998400 recv bytes 998400
[2024-05-23 14:30:05.368] [info] [api.cc:212] - b2a, executed 101 times, duration 0.023378065s, send bytes 571776 recv bytes 571776
[2024-05-23 14:30:05.368] [info] [api.cc:212] - and_pp, executed 23499 times, duration 0.019201004s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - and_bb, executed 120 times, duration 0.01835421s, send bytes 422400 recv bytes 422400
[2024-05-23 14:30:05.368] [info] [api.cc:212] - reshape, executed 51196 times, duration 0.01797403s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - add_aa, executed 6575 times, duration 0.009594719s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - not_a, executed 3205 times, duration 0.008131218s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - pad, executed 151 times, duration 0.00648642s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - mul_pp, executed 8649 times, duration 0.005604596s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - add_ap, executed 3322 times, duration 0.003481125s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - mul_ap, executed 3350 times, duration 0.002645787s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - transpose, executed 4015 times, duration 0.002140406s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - xor_pp, executed 486 times, duration 0.001986499s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - xor_bb, executed 580 times, duration 0.0014216s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - broadcast, executed 4201 times, duration 0.001348527s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - bitrev_b, executed 40 times, duration 0.000842458s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - and_bp, executed 360 times, duration 0.000716466s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - rshift_b, executed 360 times, duration 0.000634214s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - rshift_p, executed 176 times, duration 0.000483398s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - lshift_p, executed 160 times, duration 0.000474628s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - lshift_b, executed 120 times, duration 0.000206007s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - equal_pp, executed 53 times, duration 0.000125038s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - reverse, executed 11 times, duration 7.3523e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - xor_bp, executed 20 times, duration 6.183e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:212] - p2a, executed 2 times, duration 2.6079e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:30:05.368] [info] [api.cc:222] Link details: total send bytes 57306656, recv bytes 57306656, send actions 1251

========== Original model ==========
[2024-05-23 14:18:09.433] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 0.000332336s, execution took 0.52443393s, output processing took 3.6456e-05s, total time 0.524802722s.
[2024-05-23 14:18:09.433] [info] [api.cc:209] HLO profiling: total time 5.843199999999999e-05
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.free, executed 574 times, duration 2.3527e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.add, executed 243 times, duration 9.884e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.multiply, executed 222 times, duration 9.255e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.constant, executed 30 times, duration 2.571e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.broadcast, executed 45 times, duration 2.409e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.subtract, executed 40 times, duration 1.84e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.greater, executed 41 times, duration 1.797e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.reduce, executed 30 times, duration 1.427e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.reshape, executed 27 times, duration 1.169e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.transpose, executed 21 times, duration 1.004e-06s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.rsqrt, executed 20 times, duration 9.95e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.convolution, executed 20 times, duration 9.67e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.pad, executed 15 times, duration 7e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.reverse, executed 11 times, duration 4.69e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.select, executed 4 times, duration 2.31e-07s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.convert, executed 2 times, duration 9.5e-08s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.reduce_window, executed 1 times, duration 4.9e-08s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.dot, executed 1 times, duration 4.3e-08s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:209] HAL profiling: total time 0.47888425500000004
[2024-05-23 14:18:09.433] [info] [api.cc:212] - f_tensordot, executed 20 times, duration 0.185054774s, send bytes 47269376 recv bytes 47269376
[2024-05-23 14:18:09.433] [info] [api.cc:212] - f_less, executed 41 times, duration 0.113789016s, send bytes 4741632 recv bytes 4741632
[2024-05-23 14:18:09.433] [info] [api.cc:212] - f_rsqrt, executed 20 times, duration 0.107866579s, send bytes 2265600 recv bytes 2265600
[2024-05-23 14:18:09.433] [info] [api.cc:212] - f_mul, executed 185 times, duration 0.04622714s, send bytes 2791424 recv bytes 2791424
[2024-05-23 14:18:09.433] [info] [api.cc:212] - mixed_mul, executed 37 times, duration 0.013317287s, send bytes 1245696 recv bytes 1245696
[2024-05-23 14:18:09.433] [info] [api.cc:212] - _mux, executed 4 times, duration 0.005759732s, send bytes 786432 recv bytes 786432
[2024-05-23 14:18:09.433] [info] [api.cc:212] - f_add, executed 243 times, duration 0.002756277s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - f_sub, executed 40 times, duration 0.002638584s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - f_mmul, executed 1 times, duration 0.001443334s, send bytes 414496 recv bytes 414496
[2024-05-23 14:18:09.433] [info] [api.cc:212] - seal, executed 2 times, duration 3.1532e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:209] MPC profiling: total time 0.5065149670000001
[2024-05-23 14:18:09.433] [info] [api.cc:212] - mmul_aa, executed 21 times, duration 0.143095114s, send bytes 47277568 recv bytes 47277568
[2024-05-23 14:18:09.433] [info] [api.cc:212] - msb_a2b, executed 41 times, duration 0.110412673s, send bytes 4741632 recv bytes 4741632
[2024-05-23 14:18:09.433] [info] [api.cc:212] - trunc_a, executed 346 times, duration 0.09560603s, send bytes 1731872 recv bytes 1731872
[2024-05-23 14:18:09.433] [info] [api.cc:212] - mul_aa, executed 236 times, duration 0.042655826s, send bytes 3550208 recv bytes 3550208
[2024-05-23 14:18:09.433] [info] [api.cc:212] - a2b, executed 20 times, duration 0.024979522s, send bytes 998400 recv bytes 998400
[2024-05-23 14:18:09.433] [info] [api.cc:212] - b2a, executed 101 times, duration 0.022307338s, send bytes 792576 recv bytes 792576
[2024-05-23 14:18:09.433] [info] [api.cc:212] - concatenate, executed 16 times, duration 0.019879165s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - and_bb, executed 120 times, duration 0.017747497s, send bytes 422400 recv bytes 422400
[2024-05-23 14:18:09.433] [info] [api.cc:212] - pad, executed 15 times, duration 0.006463843s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - extract_slice, executed 4968 times, duration 0.004837987s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - add_aa, executed 335 times, duration 0.004792714s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - not_a, executed 85 times, duration 0.004279971s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - reshape, executed 4934 times, duration 0.002338883s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - xor_bb, executed 580 times, duration 0.00150218s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - add_ap, executed 202 times, duration 0.001281746s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - bitrev_b, executed 40 times, duration 0.000841452s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - and_bp, executed 360 times, duration 0.00072102s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - mul_ap, executed 230 times, duration 0.000642717s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - rshift_b, executed 360 times, duration 0.000631263s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.433] [info] [api.cc:212] - transpose, executed 91 times, duration 0.000472091s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.434] [info] [api.cc:212] - make_p, executed 515 times, duration 0.000379443s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.434] [info] [api.cc:212] - lshift_b, executed 120 times, duration 0.000224111s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.434] [info] [api.cc:212] - broadcast, executed 103 times, duration 0.00012092s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.434] [info] [api.cc:212] - add_pp, executed 40 times, duration 8.828e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.434] [info] [api.cc:212] - xor_bp, executed 20 times, duration 6.5605e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.434] [info] [api.cc:212] - mul_pp, executed 20 times, duration 6.5301e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.434] [info] [api.cc:212] - not_p, executed 20 times, duration 4.1226e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.434] [info] [api.cc:212] - p2a, executed 2 times, duration 2.8136e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.434] [info] [api.cc:212] - reverse, executed 11 times, duration 1.2913e-05s, send bytes 0 recv bytes 0
[2024-05-23 14:18:09.434] [info] [api.cc:222] Link details: total send bytes 59514656, recv bytes 59514656, send actions 1251
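
To see where the traffic goes, the per-op send bytes can be pulled out of the MPC profiling lines and divided by the "Link details" total. A minimal sketch in plain Python; the regex is an assumption based only on the line format visible in the logs above, not on any documented SPU log schema, and parse_profile is a hypothetical helper.

```python
import re

# Assumed line shape, taken from the logs above:
#   "... - <op>, executed <n> times, duration <t>s, send bytes <s> recv bytes <r>"
PROFILE_LINE = re.compile(
    r"- (?P<op>[\w.: ]+), executed (?P<n>\d+) times, "
    r"duration (?P<t>[0-9.e+-]+)s, send bytes (?P<send>\d+) recv bytes (?P<recv>\d+)"
)


def parse_profile(lines):
    """Return {op_name: (executions, send_bytes)} for every matching line."""
    stats = {}
    for line in lines:
        match = PROFILE_LINE.search(line)
        if match:
            stats[match["op"]] = (int(match["n"]), int(match["send"]))
    return stats


# Two MPC-level lines and the link total copied from the ORIGINAL model log.
original_lines = [
    "[api.cc:212] - mmul_aa, executed 21 times, duration 0.143095114s, "
    "send bytes 47277568 recv bytes 47277568",
    "[api.cc:212] - msb_a2b, executed 41 times, duration 0.110412673s, "
    "send bytes 4741632 recv bytes 4741632",
]
original_total_send = 59514656  # "Link details: total send bytes"

stats = parse_profile(original_lines)
for op, (count, sent) in stats.items():
    print(f"{op:10s} x{count:<4d} {sent:>10d} bytes  {sent / original_total_send:6.1%} of total")
```

On these numbers, mmul_aa (21 calls, matching the 20 f_tensordot convolutions plus the single f_mmul of the dense layer) is about 79% of the bytes sent, while msb_a2b, the comparison behind f_less, is about 8%. The total link traffic only drops from 59514656 to 57306656 bytes (about 3.7%) in the optimized model, which is consistent with ReLU comparisons being a small share of this workload's communication.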
anakinxc commented 4 months ago

@llCurious, mind taking a look?

llCurious commented 4 months ago

hi, @warpoons

The profiling numbers seem a little weird. Basically, the number of ReLUs directly affects the number of corresponding PPHLO ops (greater and select). As shown in the figure below, the nn.relu() function is translated into a combination of greater and select. [image omitted]
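
The element-wise semantics of that lowering can be reproduced in plain JAX: ReLU is one comparison with zero followed by one select. A minimal sketch (relu_via_select is a hypothetical name); it mirrors the greater + select pattern described above and is not SPU's compiler pass itself.

```python
import jax
import jax.numpy as jnp


def relu_via_select(x):
    # One comparison with zero (greater) feeding one element-wise select.
    return jnp.where(x > 0, x, jnp.zeros_like(x))


x = jnp.array([-2.0, -0.5, 0.0, 0.5, 3.0])
print(relu_via_select(x))
# The jaxpr shows the comparison and the select as separate primitives.
print(jax.make_jaxpr(relu_via_select)(x))
```

Because both steps act on the whole tensor at once, a ReLU layer would normally appear as a few greater/select calls on large tensors, as in the original log (41 pphlo.greater), rather than thousands of small ones.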

However, as indicated in your profiling outputs:

OPTIMIZED: [2024-05-23 14:30:05.367] [info] [api.cc:212] - pphlo.greater, executed 11629 times, duration 0.000395687s, send bytes 0 recv bytes 0

ORIGINAL: [2024-05-23 14:18:09.433] [info] [api.cc:212] - pphlo.greater, executed 41 times, duration 1.797e-06s, send bytes 0 recv bytes 0

The number of greater invocations in the optimized model increases a lot compared to the original one, which may explain the slower performance.

As for the communication size, the f_less traffic decreases from roughly 474 × 10^4 to 319 × 10^4 bytes, a ~30% reduction. It seems that the vectorization is broken somewhere, leading to more invocations.

OPTIMIZED: [2024-05-23 14:30:05.368] [info] [api.cc:212] - f_less, executed 41 times, duration 0.090282452s, send bytes 3196032 recv bytes 3196032

ORIGINAL: [2024-05-23 14:18:09.433] [info] [api.cc:212] - f_less, executed 41 times, duration 0.113789016s, send bytes 4741632 recv bytes 4741632
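
The mismatch can be quantified from numbers already in this thread. If f_less traffic scaled linearly with the number of compared elements (an assumption, not something the logs establish), a ~9.8x cut in ReLUs would shrink f_less bytes by a similar factor:

```python
# Numbers copied from this thread: ReLU counts from the issue text, bytes from the logs.
relu_original = 491520
relu_optimized = 50000
f_less_bytes_original = 4741632   # HAL f_less send bytes, original model
f_less_bytes_optimized = 3196032  # HAL f_less send bytes, optimized model

relu_ratio = relu_original / relu_optimized
bytes_ratio = f_less_bytes_original / f_less_bytes_optimized
reduction = 1 - f_less_bytes_optimized / f_less_bytes_original

print(f"ReLU count:   {relu_ratio:.2f}x fewer")
print(f"f_less bytes: {bytes_ratio:.2f}x fewer ({reduction:.1%} reduction)")
```

This prints about 9.83x fewer ReLUs against only about 1.48x fewer f_less bytes (a 32.6% reduction). One thing worth checking in the model code posted later in this thread: the stem activation (x = nn.relu(x) right after bn_init) is not routed through ChooseActivation, so its comparisons do not shrink with ratio.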

Could you briefly describe your method for reducing the number of ReLUs (unstructured, maybe?)?

warpoons commented 4 months ago

Hi @llCurious ! Thanks for your answer. Below is the main code for ReLU optimization:

import math

import flax.linen as nn
import jax.numpy as jnp

# get_alpha and custom_act are defined later in this thread.


class ChooseActivation(nn.Module):
    ratio: float  # Fraction of this layer's ReLUs to retain
    act_type: int  # Replacement for dropped ReLUs; act_type=1 is the identity, y = x
    dtype: jnp.dtype = jnp.float32

    def __call__(self, hidden_input):
        act_count = hidden_input.size  # Number of activations in this non-linear layer
        keep_num = math.floor(act_count * self.ratio)
        # get_alpha returns a random 0/1 array with the same shape as hidden_input and
        # exactly keep_num ones: positions with alpha=1 keep ReLU, positions with
        # alpha=0 use the replacement function (the identity for act_type=1).
        alpha = get_alpha(hidden_input, keep_num)
        relu_indices = jnp.where(alpha == 1, size=keep_num)
        custom_act_indices = jnp.where(alpha == 0, size=alpha.size - keep_num)
        relu_weights = hidden_input[relu_indices[0], relu_indices[1], relu_indices[2], relu_indices[3]]
        custom_act_weights = hidden_input[custom_act_indices[0], custom_act_indices[1], custom_act_indices[2], custom_act_indices[3]]
        relu_out = nn.relu(relu_weights)
        custom_act_out = custom_act(custom_act_weights, act_type=self.act_type)
        out = hidden_input.at[relu_indices[0], relu_indices[1], relu_indices[2], relu_indices[3]].set(relu_out)
        out = out.at[custom_act_indices[0], custom_act_indices[1], custom_act_indices[2], custom_act_indices[3]].set(custom_act_out)
        return out
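
For checking ChooseActivation's outputs in plaintext, the same per-element semantics (with act_type=1) can be written as a single masked select, which makes a convenient reference. A minimal sketch; both function names are hypothetical. Note the trade-off: the masked select is fully vectorized but evaluates ReLU on every element, so under MPC it would not reduce the number of comparisons, while the gather/scatter form only touches the kept elements but, judging by the profiling logs above, comes with many more small ops.

```python
import jax
import jax.numpy as jnp


def mixed_activation_reference(x, alpha):
    """ReLU where alpha == 1, identity (act_type=1) where alpha == 0, as one masked select."""
    return jnp.where(alpha == 1, jax.nn.relu(x), x)


def mixed_activation_gather_scatter(x, alpha):
    """Same semantics via index gather/scatter, mirroring ChooseActivation above."""
    keep_num = int(alpha.sum())  # alpha is concrete here, so this is a Python int
    relu_idx = jnp.where(alpha == 1, size=keep_num)
    id_idx = jnp.where(alpha == 0, size=alpha.size - keep_num)
    out = x.at[relu_idx].set(jax.nn.relu(x[relu_idx]))
    out = out.at[id_idx].set(x[id_idx])  # identity: values written back unchanged
    return out


key_x, key_a = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 4, 4, 3))
alpha = (jax.random.uniform(key_a, x.shape) < 0.1).astype(jnp.int32)
print(jnp.array_equal(mixed_activation_reference(x, alpha),
                      mixed_activation_gather_scatter(x, alpha)))
```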

In the model definition (models.py), I modified ResNet18 as follows (modifications are marked with ★★★):

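from functools import partial
from typing import Any, Sequence

import flax.linen as nn
import jax.numpy as jnp

ModuleDef = Any


class ResNet(nn.Module):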
    """ResNetV1."""

    stage_sizes: Sequence[int]
    block_cls: ModuleDef
    num_classes: int
    ratio: float  # ★★★
    act_type: int  # ★★★
    num_filters: int = 64
    dtype: Any = jnp.float32
    conv: ModuleDef = nn.Conv

    @nn.compact
    def __call__(self, x, train: bool = True):
        conv = partial(self.conv, use_bias=False, dtype=self.dtype)
        norm = partial(
            nn.BatchNorm,
            use_running_average=not train,
            momentum=0.9,
            epsilon=1e-5,
            dtype=self.dtype,
        )

        costom_act = ChooseActivation(ratio=self.ratio, act_type=self.act_type)  # ★★★

        x = conv(
            self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name="conv_init"
        )(x)
        x = norm(name="bn_init")(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
        for i, block_size in enumerate(self.stage_sizes):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = self.block_cls(
                    self.num_filters * 2**i,
                    strides=strides,
                    conv=conv,
                    norm=norm,
                    act=costom_act,  # ★★★
                )(x)
        x = jnp.mean(x, axis=(1, 2))
        x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
        x = jnp.asarray(x, self.dtype)
        return x

Note that in this code I did not load specific pre-trained model weights or optimized alpha values, so alpha is completely random and this is only a rough estimate of the communication and time costs under a given ReLU reduction ratio. I think this does NOT significantly affect the precision of this test (please correct me if this is wrong).

Thanks.

warpoons commented 4 months ago

Hi @llCurious! Are there any errors or mistakes in the code above? Please point them out directly! Thank you all!

It seems that pphlo.greater is invoked many more times, and a nearly 10x ReLU reduction only yields a ~30% reduction in f_less communication.

llCurious commented 4 months ago

Yep, this optimization looks like unstructured ReLU reduction: random positions are selected to use a cheaper activation function, while the rest still use ReLU.

  1. What is the implementation of get_alpha? The PRG in PPU seems to have weird behaviors (I may need to check this):

    get_alpha(hidden_input, keep_num)

  2. The following code snippet should not cause the excessive greater invocations:

    relu_weights = hidden_input[relu_indices[0], relu_indices[1], relu_indices[2], relu_indices[3]]
    relu_out = nn.relu(relu_weights)

I tested the following code, which should behave similarly to yours.

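import flax.linen as nn
import jax.numpy as jnp
import numpy as np

# ppsim (SPU's simulation utilities), sim (a simulator) and copts (compiler options)
# are set up elsewhere (not shown).
x = np.random.randn(4, 5, 4, 5).astype(np.float16)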
mask = np.random.randint(0, 2, size=(4, 5, 4, 5))
keep_num = mask.sum()
relu_indices = jnp.where(mask == 1, size=keep_num)
x_m = x[relu_indices[0], relu_indices[1], relu_indices[2], relu_indices[3]]
spu_fn = ppsim.sim_jax(sim, lambda x: nn.relu(x), copts=copts)
z = spu_fn(x_m)

And the corresponding DAG is: [image omitted]

  3. Are you testing training or inference of ResNet18?
warpoons commented 4 months ago

Hi @llCurious ! Thanks for your patient answer.

  1. The implementation of get_alpha:

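    from functools import partial

    import jax
    import jax.numpy as jnp

    # Random 0/1 mask shaped like fea with exactly x ones; x must be a static Python int.
    @partial(jax.jit, static_argnums=(1,))
    def get_alpha(fea, x):
        shape = fea.shape
        out = jnp.zeros(shape)
        key = jax.random.PRNGKey(42)
        idx = jax.random.choice(key, out.size, shape=(x,), replace=False)
        out = out.at[jnp.unravel_index(idx, out.shape)].set(1)
        return out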

    I am not sure whether the PRG could affect the test performance. I added the @partial(jax.jit, static_argnums=(1,)) decorator to avoid the error jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]. Without it, this error occurs.

  2. The implementation of custom_act:

    def custom_act(x, act_type):
        if act_type == 1:  # Linear (identity)
            return x
        elif act_type == 2:  # Square
            return x ** 2
        elif act_type == 3:  # 2nd-order polynomial
            return x ** 2 + x
        elif act_type == 4:  # 3rd-order polynomial
            return x ** 3 + x ** 2 + x
        # Fail loudly instead of silently returning None for an unknown type.
        raise ValueError(f"unsupported act_type: {act_type}")
  3. I am not sure whether there are any contradictory observations between your code and mine.

  4. I have tested inference of ResNet18 on cifar10. The log is exactly the one I reported in my original post.
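
Because get_alpha uses a fixed key (PRNGKey(42)), alpha depends only on the shape of its input, not on its values. One option is to generate the mask eagerly with NumPy, outside the traced function, so that the random sampling never becomes part of the compiled program and no traced value is needed for keep_num. A sketch under the assumption that a NumPy Generator is an acceptable substitute for jax.random here (the masks will differ from the jax.random ones); get_alpha_static is a hypothetical helper. Inside ChooseActivation the shape is static at trace time, so it could be called as get_alpha_static(hidden_input.shape, keep_num); whether this changes SPU's op counts would need to be measured.

```python
import numpy as np


def get_alpha_static(shape, keep_num, seed=42):
    """Hypothetical helper: 0/1 mask of the given shape with exactly keep_num ones,
    generated eagerly with NumPy so it enters the model as a constant."""
    size = int(np.prod(shape))
    rng = np.random.default_rng(seed)
    keep_idx = rng.choice(size, size=keep_num, replace=False)  # distinct positions
    alpha = np.zeros(size, dtype=np.int32)
    alpha[keep_idx] = 1
    return alpha.reshape(shape)


alpha = get_alpha_static((1, 8, 8, 64), keep_num=400)
print(alpha.shape, alpha.sum())
```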

By the way, the test was run with the following 2PC.json configuration:

{
    "id": "colocated.2pc",
    "nodes": {
        "node:0": "127.0.0.1:61320",
        "node:1": "127.0.0.1:61321"
    },
    "devices": {
        "SPU": {
            "kind": "SPU",
            "config": {
                "node_ids": [
                    "node:0",
                    "node:1"
                ],
                "experimental_data_folder": [
                    "/tmp/spu_data_0/",
                    "/tmp/spu_data_1/"
                ],
                "spu_internal_addrs": [
                    "127.0.0.1:61330",
                    "127.0.0.1:61331"
                ],
                "runtime_config": {
                    "protocol": "SEMI2K",
                    "field": "FM64",
                    "enable_pphlo_profile": true,
                    "enable_hal_profile": true
                }
            }
        },
        "P1": {
            "kind": "PYU",
            "config": {
                "node_id": "node:0"
            }
        },
        "P2": {
            "kind": "PYU",
            "config": {
                "node_id": "node:1"
            }
        }
    }
}
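
When comparing runs, it helps to confirm that both used the same protocol, field and profiling flags. A minimal sketch in plain Python that only reads the keys shown in the JSON above; check_spu_config is a hypothetical helper, not part of SPU, and the embedded JSON is a trimmed copy of the config.

```python
import json


def check_spu_config(conf):
    """Hypothetical sanity check for a cluster config shaped like 2PC.json.

    Raises ValueError on an inconsistency, otherwise returns a one-line summary."""
    nodes = conf["nodes"]
    spu_conf = conf["devices"]["SPU"]["config"]
    node_ids = spu_conf["node_ids"]
    unknown = [nid for nid in node_ids if nid not in nodes]
    if unknown:
        raise ValueError(f"SPU node_ids not declared in 'nodes': {unknown}")
    if len(spu_conf["spu_internal_addrs"]) != len(node_ids):
        raise ValueError("need exactly one spu_internal_addrs entry per SPU party")
    rt = spu_conf["runtime_config"]
    profiling = rt.get("enable_pphlo_profile", False) and rt.get("enable_hal_profile", False)
    return f"{len(node_ids)} parties, {rt['protocol']}/{rt['field']}, profiling={'on' if profiling else 'off'}"


conf = json.loads("""
{
  "nodes": {"node:0": "127.0.0.1:61320", "node:1": "127.0.0.1:61321"},
  "devices": {"SPU": {"kind": "SPU", "config": {
    "node_ids": ["node:0", "node:1"],
    "spu_internal_addrs": ["127.0.0.1:61330", "127.0.0.1:61331"],
    "runtime_config": {"protocol": "SEMI2K", "field": "FM64",
                       "enable_pphlo_profile": true, "enable_hal_profile": true}}}}
}
""")
print(check_spu_config(conf))
```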
warpoons commented 4 months ago

I have further tested a get_alpha that uses no PRG (it simply sets the first keep_num entries of alpha to 1):

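import jax.numpy as jnp

def get_alpha(fea, x):
    # Mask with the same shape as the feature map `fea`:
    # only the first `x` elements (in flattened order) are set to 1.
    shape = fea.shape
    out = jnp.zeros(shape)
    flat_out = out.reshape(-1)
    flat_out = flat_out.at[:x].set(1)
    alpha = flat_out.reshape(shape)
    return alpha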

The communication cost did not change. The log is:

[2024-05-24 11:31:59.378] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 1.6257e-05s, execution took 5.175019927s, output processing took 4.1312e-05s, total time 5.175077496s.
[2024-05-24 11:31:59.534] [info] [api.cc:209] HLO profiling: total time 0.06917814300000003
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.free, executed 869188 times, duration 0.030531211s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.reshape, executed 226383 times, duration 0.00791738s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.slice, executed 149800 times, duration 0.005204786s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.and, executed 138298 times, duration 0.004942197s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.dynamic_slice, executed 138240 times, duration 0.004759656s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.less, executed 92260 times, duration 0.003303312s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.not, executed 92192 times, duration 0.003218248s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.add, executed 61927 times, duration 0.002260445s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.select, executed 46194 times, duration 0.001664652s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.greater, executed 46121 times, duration 0.001636872s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.dynamic_update_slice, executed 46080 times, duration 0.001580741s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.reduce, executed 30750 times, duration 0.00106447s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.concatenate, executed 30792 times, duration 0.001056983s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.multiply, executed 222 times, duration 9.485e-06s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.constant, executed 144 times, duration 6.775e-06s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.pad, executed 143 times, duration 5.31e-06s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.negate, executed 58 times, duration 2.081e-06s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.broadcast, executed 45 times, duration 2.048e-06s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.while, executed 40 times, duration 1.807e-06s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.subtract, executed 40 times, duration 1.795e-06s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.custom_call: pphlo.gather, executed 32 times, duration 1.593e-06s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.equal, executed 32 times, duration 1.199e-06s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.convolution, executed 20 times, duration 1.052e-06s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.rsqrt, executed 20 times, duration 9.62e-07s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.transpose, executed 21 times, duration 8.85e-07s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.prefer_a, executed 20 times, duration 7.79e-07s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.shift_right_logical, executed 12 times, duration 5.16e-07s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.reverse, executed 11 times, duration 4.62e-07s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.sign, executed 6 times, duration 2.62e-07s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.convert, executed 2 times, duration 8.9e-08s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.reduce_window, executed 1 times, duration 4.8e-08s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - pphlo.dot, executed 1 times, duration 4.2e-08s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:209] HAL profiling: total time 2.4410561299999998
[2024-05-24 11:31:59.534] [info] [api.cc:212] - i_less, executed 138340 times, duration 0.792062239s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - logical_not, executed 92192 times, duration 0.490369544s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - _mux, executed 46200 times, duration 0.449080506s, send bytes 786432 recv bytes 786432
[2024-05-24 11:31:59.534] [info] [api.cc:212] - f_tensordot, executed 20 times, duration 0.188781209s, send bytes 47269376 recv bytes 47269376
[2024-05-24 11:31:59.534] [info] [api.cc:212] - _and, executed 138298 times, duration 0.185906299s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - f_rsqrt, executed 20 times, duration 0.102207312s, send bytes 2265600 recv bytes 2265600
[2024-05-24 11:31:59.534] [info] [api.cc:212] - f_less, executed 41 times, duration 0.087373609s, send bytes 3196032 recv bytes 3196032
[2024-05-24 11:31:59.534] [info] [api.cc:212] - i_add, executed 61704 times, duration 0.076843793s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - f_mul, executed 185 times, duration 0.04842032s, send bytes 2791424 recv bytes 2791424
[2024-05-24 11:31:59.534] [info] [api.cc:212] - mixed_mul, executed 37 times, duration 0.012125704s, send bytes 583296 recv bytes 583296
[2024-05-24 11:31:59.534] [info] [api.cc:212] - f_add, executed 243 times, duration 0.002615094s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - f_sub, executed 40 times, duration 0.002343434s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - f_mmul, executed 1 times, duration 0.001772994s, send bytes 414496 recv bytes 414496
[2024-05-24 11:31:59.534] [info] [api.cc:212] - i_negate, executed 58 times, duration 0.000661709s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - i_equal, executed 38 times, duration 0.000236268s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - _sign, executed 6 times, duration 0.000169445s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - _rshift, executed 12 times, duration 5.589e-05s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.534] [info] [api.cc:212] - seal, executed 2 times, duration 3.0761e-05s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:209] MPC profiling: total time 2.2105977940000003
[2024-05-24 11:31:59.535] [info] [api.cc:212] - add_pp, executed 569306 times, duration 0.388091665s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - update_slice, executed 46080 times, duration 0.291837406s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - extract_slice, executed 415888 times, duration 0.224614733s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - not_p, executed 246092 times, duration 0.162447823s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - mmul_aa, executed 21 times, duration 0.144766526s, send bytes 47277568 recv bytes 47277568
[2024-05-24 11:31:59.535] [info] [api.cc:212] - make_p, executed 369511 times, duration 0.13744379s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - and_pp, executed 138298 times, duration 0.121159806s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - reshape, executed 323450 times, duration 0.116780308s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - trunc_a, executed 346 times, duration 0.101283116s, send bytes 1731872 recv bytes 1731872
[2024-05-24 11:31:59.535] [info] [api.cc:212] - msb_a2b, executed 41 times, duration 0.084929278s, send bytes 3196032 recv bytes 3196032
[2024-05-24 11:31:59.535] [info] [api.cc:212] - concatenate, executed 30808 times, duration 0.075619358s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - msb_p, executed 138346 times, duration 0.067429922s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - add_aa, executed 61775 times, duration 0.051391762s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - not_a, executed 30805 times, duration 0.049552572s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - mul_aa, executed 236 times, duration 0.03643003s, send bytes 3108608 recv bytes 3108608
[2024-05-24 11:31:59.535] [info] [api.cc:212] - a2b, executed 20 times, duration 0.025412138s, send bytes 998400 recv bytes 998400
[2024-05-24 11:31:59.535] [info] [api.cc:212] - add_ap, executed 30922 times, duration 0.025202628s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - mul_ap, executed 30950 times, duration 0.021487218s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - b2a, executed 101 times, duration 0.021317613s, send bytes 571776 recv bytes 571776
[2024-05-24 11:31:59.535] [info] [api.cc:212] - and_bb, executed 120 times, duration 0.017673356s, send bytes 422400 recv bytes 422400
[2024-05-24 11:31:59.535] [info] [api.cc:212] - transpose, executed 30811 times, duration 0.012870936s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - mul_pp, executed 15502 times, duration 0.010822609s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - broadcast, executed 30930 times, duration 0.009768128s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - pad, executed 143 times, duration 0.008192407s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - xor_bb, executed 580 times, duration 0.00138406s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - bitrev_b, executed 40 times, duration 0.000826662s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - and_bp, executed 360 times, duration 0.000682004s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - rshift_b, executed 360 times, duration 0.000622894s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - equal_pp, executed 38 times, duration 0.000203896s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - lshift_b, executed 120 times, duration 0.000203057s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - xor_bp, executed 20 times, duration 6.2716e-05s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - rshift_p, executed 12 times, duration 5.0127e-05s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - p2a, executed 2 times, duration 2.7119e-05s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:212] - reverse, executed 11 times, duration 1.0131e-05s, send bytes 0 recv bytes 0
[2024-05-24 11:31:59.535] [info] [api.cc:222] Link details: total send bytes 57306656, recv bytes 57306656, send actions 1251
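A hedged aside on how such a mask is typically consumed. The thread does not show how alpha enters the model, so the sketch below assumes the common convention y = alpha * relu(x) + (1 - alpha) * x (ReLU where alpha is 1, identity elsewhere); `masked_relu` and `prefix_relu` are hypothetical helpers. The point it illustrates: the dense-mask form still evaluates ReLU, and therefore a secure comparison under MPC, on every element, whereas slicing out the first keep_num elements first should only compare those. A dense mask by itself may therefore not reduce communication.

```python
import jax
import jax.numpy as jnp


def masked_relu(x, alpha):
    # Dense form: ReLU is computed on *every* element, then the mask picks
    # ReLU (alpha == 1) or identity (alpha == 0) per element.
    return alpha * jax.nn.relu(x) + (1 - alpha) * x


def prefix_relu(x, keep_num):
    # Same result when alpha marks the first keep_num elements (as get_alpha
    # does), but ReLU is only applied to those keep_num elements.
    flat = x.reshape(-1)
    head = jax.nn.relu(flat[:keep_num])
    out = jnp.concatenate([head, flat[keep_num:]])
    return out.reshape(x.shape)


x = jnp.array([[-1.0, 2.0], [-3.0, 4.0]])
alpha = jnp.zeros(x.shape).reshape(-1).at[:2].set(1).reshape(x.shape)
# Both give [[0, 2], [-3, 4]].
print(masked_relu(x, alpha))
print(prefix_relu(x, 2))
```

Here keep_num must be a static Python int so the sliced shape is known at trace time.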
llCurious commented 4 months ago

Hey, @warpoons .

I think the slowdown may be due to two reasons:

  1. As shown in the log below, the unstructured ReLU introduces a lot of local computation, which does not require communication. Judging from your runtime, you are running the test on a single machine with negligible RTT and high bandwidth, so local computation makes up a large share of the total time.
HAL profiling: total time 2.0239852999999997
[2024-05-23 14:30:05.368] [info] [api.cc:212] - i_less, executed 220157 times, duration 1.274664774s, send bytes 0 recv bytes 0
  2. The communication size of f_less did decrease by about 30%, which falls well short of the expected 10x reduction. You may need to work out what fraction of ResNet18's total communication your modified ReLUs actually account for.
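One way to work out that fraction is to tally per-op send bytes from the profile. The sketch below is a hypothetical helper (not part of SPU) that parses lines in the log format printed above and reports a chosen set of ops' share of the link's total send bytes. Which ops to attribute to ReLU is a judgment call, since other comparisons in the network go through the same kernels.

```python
import re

# Matches profile lines such as:
#   ... - msb_a2b, executed 41 times, duration 0.08s, send bytes 3196032 recv bytes 3196032
LINE_RE = re.compile(r"- (\S+), executed (\d+) times, .*send bytes (\d+) recv bytes (\d+)")


def send_bytes_by_op(log_text):
    """Return {op_name: send_bytes} for every profile line in log_text.

    Pass a single section (e.g. only the MPC profiling block) so that HAL-level
    and MPC-level entries for the same work are not mixed together.
    """
    totals = {}
    for m in LINE_RE.finditer(log_text):
        op, send = m.group(1), int(m.group(3))
        totals[op] = totals.get(op, 0) + send
    return totals


def share_of(totals, ops, link_total):
    """Fraction of the link's total send bytes attributed to the given ops."""
    return sum(totals.get(op, 0) for op in ops) / link_total
```

For example, in the MPC section of the get_alpha SEMI2K log, msb_a2b (all secure comparisons, not only ReLU) accounts for 3196032 of 57306656 send bytes, roughly 5.6%.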
warpoons commented 4 months ago

Hi, @llCurious. I have just tested ResNet18 WITHOUT ReLUs (by commenting out all the activation layer definitions in models.py). The communication cost is 55746336 bytes and the time cost is 0.457607488 s, compared with 59514656 bytes and 0.52443393 s for the original version, i.e. only a ~6.33% decrease in communication. Does that mean the ReLU computations make up only a very small fraction of the total communication cost of ResNet18? The ReLUs' time cost also seems to be only a tiny part of the total latency.
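The 6.33% figure can be checked directly from the two measurements; the same calculation puts the latency reduction at about 12.7%. A minimal sketch:

```python
def reduction(before, after):
    """Relative reduction from `before` to `after`, as a percentage."""
    return 100.0 * (before - after) / before


comm = reduction(59514656, 55746336)          # bytes: original vs. no-ReLU (SEMI2K)
latency = reduction(0.52443393, 0.457607488)  # seconds
print(f"comm: {comm:.2f}%  latency: {latency:.2f}%")  # comm: 6.33%  latency: 12.74%
```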

However, previous papers such as DeepReDuce, SNL, SENet and AutoReP argue that ReLUs are the bottleneck of PPML and may account for more than 90% of the time cost. Most of them use Delphi or CrypTen as their baseline PPML protocol.

Is there any highly optimized ReLU implementation in SPU that makes ReLU extremely efficient?

One additional question: does SPU follow an offline-online two-stage framework like Delphi, where the offline stage pre-generates input-independent data such as Beaver's multiplicative triples for later use in the online stage? If NOT, is the communication size reported in SPU's logs comparable to Delphi's online stage, or to the sum of its online and offline stages?

Sorry for taking up your time. Thanks for your attention.

llCurious commented 4 months ago

Is there any highly optimized ReLU implementation in SPU that makes ReLU extremely efficient?

Nope. ReLU decomposes into a standard comparison op and has communication complexity similar to existing works.
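In plaintext terms, ReLU is one comparison followed by a select. The JAX sketch below writes it that way; it is an illustration, not SPU's internal lowering. Under MPC the comparison is the expensive part: in the profiles in this thread it appears to show up as pphlo.greater / f_less / msb_a2b, while the select appears as pphlo.select / _mux.

```python
import jax
import jax.numpy as jnp


def relu_as_compare_select(x):
    # 1) comparison: is each element positive? (the costly step under MPC)
    is_pos = x > 0
    # 2) select: keep x where positive, otherwise 0 (a multiplexer)
    return jnp.where(is_pos, x, jnp.zeros_like(x))


x = jnp.array([-2.0, -0.5, 0.0, 1.5, 3.0])
print(relu_as_compare_select(x))  # matches jax.nn.relu(x)
```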

Does that mean the ReLU computations make up only a very small fraction of the total communication cost of ResNet18?

According to your profiling, yes. The matrix multiplications consume most of the communication:

mmul_aa, executed 21 times, duration 0.144766526s, send bytes 47277568 recv bytes 47277568
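To put that line in proportion, divide it by the total send bytes on the "Link details" line of the same SEMI2K get_alpha log: the matmuls alone account for roughly 82.5% of the traffic.

```python
mmul_send = 47277568   # mmul_aa send bytes (SEMI2K, get_alpha run)
total_send = 57306656  # "Link details: total send bytes" of the same run
print(f"mmul_aa share of send bytes: {mmul_send / total_send:.1%}")  # 82.5%
```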

does SPU follow an offline-online two-stage framework?

No. Beaver triples are also generated online.

BTW, could you switch to the ABY3 protocol and run a test?
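ABY3 is a three-party protocol, so the 2PC.json layout shown earlier needs a third node as well as a protocol change. The sketch below derives such a config from the 2PC dict. It is an assumption-laden illustration: the third node's addresses, the data folder and the `P3` device name are hypothetical placeholders, and whether the extra PYU device is needed depends on your driver script.

```python
import copy


def to_aby3_config(cfg_2pc, node_addr="127.0.0.1:61322", spu_addr="127.0.0.1:61332"):
    """Derive a 3-party ABY3 config from a dict laid out like 2PC.json.

    The third node's addresses and the P3 device are hypothetical placeholders.
    """
    cfg = copy.deepcopy(cfg_2pc)  # leave the 2PC config untouched
    cfg["id"] = "colocated.3pc"
    cfg["nodes"]["node:2"] = node_addr
    spu = cfg["devices"]["SPU"]["config"]
    spu["node_ids"].append("node:2")
    spu["spu_internal_addrs"].append(spu_addr)
    spu["experimental_data_folder"].append("/tmp/spu_data_2/")
    spu["runtime_config"]["protocol"] = "ABY3"  # ABY3 runs with three parties
    cfg["devices"]["P3"] = {"kind": "PYU", "config": {"node_id": "node:2"}}
    return cfg
```

Typical use would be `json.load` on 2PC.json, then `json.dump` of `to_aby3_config(...)` to a new file.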

warpoons commented 4 months ago

Thanks @llCurious

I have run the ORIGINAL and NO-ReLU ResNet18 on cifar10 using the ABY3 protocol. The logs are below. For the ORIGINAL ResNet18:

[2024-05-24 16:00:05.475] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 1.7582e-05s, execution took 0.66206184s, output processing took 3.2597e-05s, total time 0.662112019s.
[2024-05-24 16:00:05.476] [info] [api.cc:209] HLO profiling: total time 6.628e-05
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.free, executed 574 times, duration 2.7306e-05s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.add, executed 243 times, duration 1.1725e-05s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.multiply, executed 222 times, duration 1.0262e-05s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.broadcast, executed 45 times, duration 2.419e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.constant, executed 30 times, duration 2.343e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.subtract, executed 40 times, duration 2.251e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.greater, executed 41 times, duration 1.974e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.reduce, executed 30 times, duration 1.58e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.reshape, executed 27 times, duration 1.281e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.convolution, executed 20 times, duration 1.171e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.rsqrt, executed 20 times, duration 1.122e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.transpose, executed 21 times, duration 1.044e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.pad, executed 15 times, duration 7.82e-07s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.reverse, executed 11 times, duration 5.15e-07s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.select, executed 4 times, duration 3.16e-07s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.convert, executed 2 times, duration 9.8e-08s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.reduce_window, executed 1 times, duration 5.2e-08s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pphlo.dot, executed 1 times, duration 3.9e-08s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:209] HAL profiling: total time 0.547479285
[2024-05-24 16:00:05.476] [info] [api.cc:212] - f_rsqrt, executed 20 times, duration 0.185890175s, send bytes 6399488 recv bytes 6370816
[2024-05-24 16:00:05.476] [info] [api.cc:212] - f_less, executed 41 times, duration 0.168771585s, send bytes 5757696 recv bytes 5757696
[2024-05-24 16:00:05.476] [info] [api.cc:212] - f_mul, executed 185 times, duration 0.076255034s, send bytes 7446528 recv bytes 7219200
[2024-05-24 16:00:05.476] [info] [api.cc:212] - f_tensordot, executed 20 times, duration 0.070519074s, send bytes 2842624 recv bytes 3235840
[2024-05-24 16:00:05.476] [info] [api.cc:212] - mixed_mul, executed 37 times, duration 0.018771326s, send bytes 2491392 recv bytes 830464
[2024-05-24 16:00:05.476] [info] [api.cc:212] - _mux, executed 4 times, duration 0.012148933s, send bytes 1572864 recv bytes 524288
[2024-05-24 16:00:05.476] [info] [api.cc:212] - f_add, executed 243 times, duration 0.009367549s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - f_sub, executed 40 times, duration 0.005110175s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - f_mmul, executed 1 times, duration 0.000636081s, send bytes 4800 recv bytes 8000
[2024-05-24 16:00:05.476] [info] [api.cc:212] - seal, executed 2 times, duration 9.353e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:209] MPC profiling: total time 0.639319701
[2024-05-24 16:00:05.476] [info] [api.cc:212] - msb_a2b, executed 41 times, duration 0.16308444s, send bytes 5757696 recv bytes 5757696
[2024-05-24 16:00:05.476] [info] [api.cc:212] - trunc_a, executed 346 times, duration 0.093467571s, send bytes 9256064 recv bytes 9197824
[2024-05-24 16:00:05.476] [info] [api.cc:212] - concatenate, executed 16 times, duration 0.074204599s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - b2a, executed 60 times, duration 0.0719904s, send bytes 2701312 recv bytes 2899968
[2024-05-24 16:00:05.476] [info] [api.cc:212] - mmul_aa, executed 21 times, duration 0.062236173s, send bytes 812608 recv bytes 812608
[2024-05-24 16:00:05.476] [info] [api.cc:212] - mul_aa, executed 195 times, duration 0.043657938s, send bytes 2195456 recv bytes 2195456
[2024-05-24 16:00:05.476] [info] [api.cc:212] - a2b, executed 20 times, duration 0.026882618s, send bytes 1228800 recv bytes 1228800
[2024-05-24 16:00:05.476] [info] [api.cc:212] - mul_a1b, executed 41 times, duration 0.024214s, send bytes 4064256 recv bytes 1354752
[2024-05-24 16:00:05.476] [info] [api.cc:212] - and_bb, executed 140 times, duration 0.017942575s, send bytes 499200 recv bytes 499200
[2024-05-24 16:00:05.476] [info] [api.cc:212] - add_aa, executed 335 times, duration 0.01515361s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - pad, executed 15 times, duration 0.014491606s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - not_a, executed 85 times, duration 0.007386165s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - extract_slice, executed 4968 times, duration 0.006379016s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - bitrev_b, executed 40 times, duration 0.00403507s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - reshape, executed 4934 times, duration 0.003631834s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - add_ap, executed 202 times, duration 0.003259126s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - xor_bb, executed 680 times, duration 0.001998487s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - and_bp, executed 420 times, duration 0.001309539s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - rshift_b, executed 420 times, duration 0.001228294s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - mul_ap, executed 230 times, duration 0.000951528s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - make_p, executed 555 times, duration 0.000472903s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - lshift_b, executed 140 times, duration 0.00047071s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - transpose, executed 91 times, duration 0.000322811s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - broadcast, executed 103 times, duration 0.000180465s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - add_pp, executed 40 times, duration 0.000118248s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - xor_bp, executed 20 times, duration 8.8251e-05s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - mul_pp, executed 20 times, duration 8.2109e-05s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - not_p, executed 20 times, duration 5.8151e-05s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - reverse, executed 11 times, duration 1.5938e-05s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:212] - p2a, executed 2 times, duration 5.526e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:00:05.476] [info] [api.cc:222] Link details: total send bytes 26515392, recv bytes 23946304, send actions 2143

For the NO-ReLU ResNet18:

[2024-05-24 16:01:25.133] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 4.3236e-05s, execution took 0.61860065s, output processing took 3.3559e-05s, total time 0.618677445s.
[2024-05-24 16:01:25.133] [info] [api.cc:209] HLO profiling: total time 7.085700000000002e-05
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.free, executed 536 times, duration 2.6704e-05s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.add, executed 243 times, duration 1.4871e-05s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.multiply, executed 205 times, duration 9.766e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.constant, executed 26 times, duration 3.565e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.reduce, executed 30 times, duration 3.145e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.broadcast, executed 45 times, duration 2.712e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.subtract, executed 40 times, duration 2.447e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.reshape, executed 27 times, duration 1.297e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.greater, executed 24 times, duration 1.252e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.transpose, executed 21 times, duration 1.068e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.convolution, executed 20 times, duration 1.057e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.rsqrt, executed 20 times, duration 1.044e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.pad, executed 15 times, duration 7.74e-07s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.reverse, executed 11 times, duration 5.16e-07s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.select, executed 4 times, duration 3.24e-07s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.dot, executed 1 times, duration 1.68e-07s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.convert, executed 2 times, duration 9.5e-08s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - pphlo.reduce_window, executed 1 times, duration 5.2e-08s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:209] HAL profiling: total time 0.501292139
[2024-05-24 16:01:25.133] [info] [api.cc:212] - f_rsqrt, executed 20 times, duration 0.192190658s, send bytes 6291968 recv bytes 6588928
[2024-05-24 16:01:25.133] [info] [api.cc:212] - f_tensordot, executed 20 times, duration 0.106236035s, send bytes 2646016 recv bytes 3629056
[2024-05-24 16:01:25.133] [info] [api.cc:212] - f_less, executed 24 times, duration 0.083227637s, send bytes 2554624 recv bytes 2554624
[2024-05-24 16:01:25.133] [info] [api.cc:212] - f_mul, executed 185 times, duration 0.080034967s, send bytes 7389184 recv bytes 7333888
[2024-05-24 16:01:25.133] [info] [api.cc:212] - _mux, executed 4 times, duration 0.01392067s, send bytes 1572864 recv bytes 524288
[2024-05-24 16:01:25.133] [info] [api.cc:212] - f_add, executed 243 times, duration 0.010114796s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - mixed_mul, executed 20 times, duration 0.008576544s, send bytes 230400 recv bytes 76800
[2024-05-24 16:01:25.133] [info] [api.cc:212] - f_sub, executed 40 times, duration 0.006171496s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - f_mmul, executed 1 times, duration 0.000809162s, send bytes 4800 recv bytes 8000
[2024-05-24 16:01:25.133] [info] [api.cc:212] - seal, executed 2 times, duration 1.0174e-05s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:209] MPC profiling: total time 0.5924833679999999
[2024-05-24 16:01:25.133] [info] [api.cc:212] - trunc_a, executed 346 times, duration 0.099327383s, send bytes 8883328 recv bytes 9943296
[2024-05-24 16:01:25.133] [info] [api.cc:212] - mmul_aa, executed 21 times, duration 0.096493934s, send bytes 812608 recv bytes 812608
[2024-05-24 16:01:25.133] [info] [api.cc:212] - msb_a2b, executed 24 times, duration 0.077597637s, send bytes 2554624 recv bytes 2554624
[2024-05-24 16:01:25.133] [info] [api.cc:212] - b2a, executed 60 times, duration 0.074110507s, send bytes 2712576 recv bytes 2880512
[2024-05-24 16:01:25.133] [info] [api.cc:212] - concatenate, executed 16 times, duration 0.069110716s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.133] [info] [api.cc:212] - mul_aa, executed 195 times, duration 0.042210682s, send bytes 2195456 recv bytes 2195456
[2024-05-24 16:01:25.133] [info] [api.cc:212] - a2b, executed 20 times, duration 0.029781377s, send bytes 1228800 recv bytes 1228800
[2024-05-24 16:01:25.134] [info] [api.cc:212] - and_bb, executed 140 times, duration 0.019114465s, send bytes 499200 recv bytes 499200
[2024-05-24 16:01:25.134] [info] [api.cc:212] - add_aa, executed 335 times, duration 0.017209482s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - pad, executed 15 times, duration 0.016937357s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - mul_a1b, executed 24 times, duration 0.015316059s, send bytes 1803264 recv bytes 601088
[2024-05-24 16:01:25.134] [info] [api.cc:212] - extract_slice, executed 4968 times, duration 0.008332834s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - not_a, executed 68 times, duration 0.008101905s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - reshape, executed 4934 times, duration 0.004648889s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - bitrev_b, executed 40 times, duration 0.00410769s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - add_ap, executed 168 times, duration 0.00240817s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - xor_bb, executed 680 times, duration 0.002083383s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - and_bp, executed 420 times, duration 0.001318771s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - rshift_b, executed 420 times, duration 0.001076695s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - mul_ap, executed 230 times, duration 0.001055957s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - make_p, executed 538 times, duration 0.000702579s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - lshift_b, executed 140 times, duration 0.000481926s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - transpose, executed 91 times, duration 0.000385667s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - broadcast, executed 99 times, duration 0.000177755s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - add_pp, executed 40 times, duration 0.000112732s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - mul_pp, executed 20 times, duration 0.000102718s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - xor_bp, executed 20 times, duration 9.1196e-05s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - not_p, executed 20 times, duration 5.9467e-05s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - reverse, executed 11 times, duration 1.9377e-05s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:212] - p2a, executed 2 times, duration 6.058e-06s, send bytes 0 recv bytes 0
[2024-05-24 16:01:25.134] [info] [api.cc:222] Link details: total send bytes 20689856, recv bytes 20715584, send actions 1942

ABY3 (22.8 MiB received) seems more communication-efficient than SEMI2K (56.8 MiB).

But, as with SEMI2K (~6.33%), the ReLUs still account for only a small fraction (~13.49%) of the total communication cost.
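The figures above appear to be byte totals expressed in MiB (2^20 bytes), and the ABY3 numbers match the received-byte totals. Because ABY3's per-party send and recv totals differ, measuring by send bytes instead gives a larger ReLU share (about 22%). A small sketch reproducing both:

```python
MIB = 1 << 20


def pct_drop(before, after):
    """Relative reduction from `before` to `after`, as a percentage."""
    return 100.0 * (before - after) / before


# Byte totals reported in this thread.
semi2k_original = 59514656
aby3_recv = {"original": 23946304, "no_relu": 20715584}
aby3_send = {"original": 26515392, "no_relu": 20689856}

print(f"SEMI2K original: {semi2k_original / MIB:.1f} MiB")              # 56.8
print(f"ABY3 original (recv): {aby3_recv['original'] / MIB:.1f} MiB")   # 22.8
print(f"ABY3 drop by recv: {pct_drop(aby3_recv['original'], aby3_recv['no_relu']):.2f}%")  # 13.49
print(f"ABY3 drop by send: {pct_drop(aby3_send['original'], aby3_send['no_relu']):.2f}%")  # 21.97
```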

Since, as you pointed out, Beaver triples are also generated online in SPU, it seems reasonable that ReLUs account for only a small share of the total communication and time, because generating Beaver triples is far more communication-intensive and time-consuming than ReLUs. Perhaps ReLU matters more when we consider only the online inference stage.

Motivated by this, I have some additional questions:

  1. Since Beaver triples are also generated online in SPU, can I treat SPU's communication size as the sum of the offline and online communication in Delphi or CrypTFlow2?

  2. By comparison, the total offline+online communication of ResNet18 on cifar10/100 in Delphi or CrypTFlow2 can reach tens of GB, while SPU's is only 20 to 50 MB. Is this really true and comparable?

  3. Again, since Beaver triples are also generated online in SPU, is there any way to estimate the communication cost of Beaver triple generation from the profiling, so that the remaining communication can be considered SPU's "online" cost?

Anyway, thank you soooo much.

llCurious commented 4 months ago

Sorry for the delay.

According to the profiling of both SEMI2k and ABY3, I think the take-away messages are:

  1. First, ReLU does not consume that much overhead in ResNet, compared to matmuls.
  2. Second, I would recommend using Cheetah for 2PC and ABY3 for 3PC.

The answers to your questions are as follows.

can I consider the comm size in SPU as the sum of offline+online comm size in Delphi or CrypTFlow2?

No. Semi2k's protocol construction differs from these two works in how correlated randomness is generated (in Semi2k, expensive randomness such as Beaver triples is generated by a trusted party). Ref: https://www.secretflow.org.cn/en/docs/spu/0.9.0b2/reference/mpc_status#supported-mpc-protocol

the sum of offline+online comm size of ResNet18-cifar10/100 in Delphi or CrypTFlow2 can reach up to tens of GB, while that of SPU is only 20~50MB. Is this really true and comparable?

As said before, Semi2k is not comparable to these two works, which use techniques such as OT to generate the randomness; that is far more expensive than relying on a trusted party.
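To see why a trusted dealer makes this cheap, here is a textbook two-party Beaver multiplication over Z_{2^64}, sketched in Python under the assumption of a single trusted dealer. It is not SPU's Semi2k code. It only shows that once the dealer has handed out a triple (a, b, c = a*b), each secure multiplication costs the parties one round in which they open two masked values e and f. OT- or HE-based protocols must instead produce such correlated randomness interactively, which is where the large offline cost comes from.

```python
import secrets

MOD = 1 << 64  # ring Z_{2^64}; assumed here as a typical additive-sharing ring


def share(x):
    """Split x into two additive shares mod 2^64."""
    r = secrets.randbelow(MOD)
    return r, (x - r) % MOD


def reconstruct(s0, s1):
    return (s0 + s1) % MOD


def dealer_triple():
    """Trusted dealer samples a, b, sets c = a*b and hands out shares.
    The computing parties do not talk to each other for this step."""
    a, b = secrets.randbelow(MOD), secrets.randbelow(MOD)
    return share(a), share(b), share(a * b % MOD)


def beaver_mul(x_sh, y_sh):
    """Multiply secret-shared x and y using one dealer triple.
    Online, the parties only open e = x - a and f = y - b."""
    (a0, a1), (b0, b1), (c0, c1) = dealer_triple()
    e = reconstruct((x_sh[0] - a0) % MOD, (x_sh[1] - a1) % MOD)
    f = reconstruct((y_sh[0] - b0) % MOD, (y_sh[1] - b1) % MOD)
    # z0 + z1 = c + e*b + f*a + e*f = x*y (mod 2^64); party 0 adds the public e*f.
    z0 = (c0 + e * b0 + f * a0 + e * f) % MOD
    z1 = (c1 + e * b1 + f * a1) % MOD
    return z0, z1
```

For example, `reconstruct(*beaver_mul(share(6), share(7)))` yields 42, while neither party alone ever sees 6 or 7.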

is there any way to estimate the comm cost of beaver triples' generation from the profiling?

I think there is currently no way to do this directly.

warpoons commented 3 months ago

Hi @llCurious. Thanks for your response!

I have also benchmarked the ResNet18 on CIFAR100 using Cheetah protocol for 2PC in SPU. Here is the log:

[2024-06-03 13:33:37.612] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 1.7913e-05s, execution took 9.814508076s, output processing took 4.9612e-05s, total time 9.814575601s.
[2024-06-03 13:33:37.612] [info] [api.cc:209] HLO profiling: total time 5.797200000000001e-05
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.free, executed 536 times, duration 2.2729e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.add, executed 243 times, duration 1.044e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.multiply, executed 205 times, duration 8.808e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.constant, executed 26 times, duration 2.522e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.broadcast, executed 45 times, duration 2.355e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.subtract, executed 40 times, duration 1.945e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.reduce, executed 30 times, duration 1.593e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.reshape, executed 27 times, duration 1.16e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.greater, executed 24 times, duration 1.135e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.rsqrt, executed 20 times, duration 1.013e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.convolution, executed 20 times, duration 9.58e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.transpose, executed 21 times, duration 9.42e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.convert, executed 2 times, duration 8.91e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.pad, executed 15 times, duration 6.79e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.reverse, executed 11 times, duration 4.71e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.select, executed 4 times, duration 2.35e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.dot, executed 1 times, duration 4.9e-08s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.reduce_window, executed 1 times, duration 4.7e-08s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:209] HAL profiling: total time 9.767303477
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_tensordot, executed 20 times, duration 6.551057405s, send bytes 29423174 recv bytes 27147221
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_less, executed 24 times, duration 2.085487991s, send bytes 2596888 recv bytes 2356893
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_rsqrt, executed 20 times, duration 0.687243677s, send bytes 6089600 recv bytes 4757778
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_mul, executed 185 times, duration 0.374008647s, send bytes 17346411 recv bytes 20129619
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_mmul, executed 1 times, duration 0.048128501s, send bytes 439609 recv bytes 439295
[2024-06-03 13:33:37.612] [info] [api.cc:212] - mixed_mul, executed 20 times, duration 0.008901305s, send bytes 39040 recv bytes 39040
[2024-06-03 13:33:37.612] [info] [api.cc:212] - _mux, executed 4 times, duration 0.006759259s, send bytes 266240 recv bytes 266240
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_add, executed 243 times, duration 0.003040434s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_sub, executed 40 times, duration 0.002617991s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - seal, executed 2 times, duration 5.8267e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:209] MPC profiling: total time 9.79187834
[2024-06-03 13:33:37.612] [info] [api.cc:212] - mmul_aa, executed 21 times, duration 4.14475303s, send bytes 23819039 recv bytes 23818860
[2024-06-03 13:33:37.612] [info] [api.cc:212] - trunc_a, executed 346 times, duration 2.509568303s, send bytes 6540832 recv bytes 3788968
[2024-06-03 13:33:37.612] [info] [api.cc:212] - msb_a2b, executed 24 times, duration 2.083226061s, send bytes 2596888 recv bytes 2356893
[2024-06-03 13:33:37.612] [info] [api.cc:212] - a2b, executed 20 times, duration 0.50009818s, send bytes 998400 recv bytes 1264659
[2024-06-03 13:33:37.612] [info] [api.cc:212] - mul_aa, executed 195 times, duration 0.39771609s, send bytes 19410923 recv bytes 23150786
[2024-06-03 13:33:37.612] [info] [api.cc:212] - and_bb, executed 120 times, duration 0.063048469s, send bytes 385200 recv bytes 385200
[2024-06-03 13:33:37.612] [info] [api.cc:212] - b2a, executed 60 times, duration 0.032670144s, send bytes 2144400 recv bytes 65440
[2024-06-03 13:33:37.612] [info] [api.cc:212] - concatenate, executed 16 times, duration 0.018114326s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - mul_a1b, executed 24 times, duration 0.013046298s, send bytes 305280 recv bytes 305280
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pad, executed 15 times, duration 0.006005943s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - extract_slice, executed 4968 times, duration 0.005346207s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - add_aa, executed 335 times, duration 0.004904623s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - not_a, executed 68 times, duration 0.003596807s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - reshape, executed 4934 times, duration 0.002599442s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - xor_bb, executed 580 times, duration 0.001542413s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - add_ap, executed 168 times, duration 0.000897028s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - bitrev_b, executed 40 times, duration 0.00088169s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - and_bp, executed 360 times, duration 0.000797368s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - rshift_b, executed 360 times, duration 0.000698025s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - mul_ap, executed 230 times, duration 0.000661371s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - transpose, executed 91 times, duration 0.000501438s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - make_p, executed 498 times, duration 0.000464443s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - lshift_b, executed 120 times, duration 0.000223616s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - broadcast, executed 99 times, duration 0.000152052s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - add_pp, executed 40 times, duration 9.26e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - mul_pp, executed 20 times, duration 8.0166e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - xor_bp, executed 20 times, duration 7.0591e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - not_p, executed 20 times, duration 5.6533e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - p2a, executed 2 times, duration 5.4546e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - reverse, executed 11 times, duration 1.0537e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:222] Link details: total send bytes 56200962, recv bytes 55136086, send actions 19799

We can see that Cheetah takes about 9.79 s for inference and 56200962 bytes of communication, whereas SEMI2K takes 0.51 s and 59514656 bytes. They have similar comm. sizes but very different latencies.
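The first take-away above (ReLU costs little compared with matmuls) can be checked against this Cheetah run by dividing a few MPC-layer byte counts, copied from the log, by the link total. Treating `msb_a2b` (the comparison behind `f_less`) as the ReLU cost is an assumption. `pphlo.greater` runs 24 times, so it may also cover comparisons other than ReLUs. The helper `traffic_share` is hypothetical.

```python
# (send, recv) bytes for a few MPC-layer ops, copied from the first Cheetah run above.
first_run = {
    "mmul_aa": (23819039, 23818860),
    "mul_aa": (19410923, 23150786),
    "trunc_a": (6540832, 3788968),
    "msb_a2b": (2596888, 2356893),
}
# "Link details" line: total send bytes + total recv bytes.
link_total = 56200962 + 55136086


def traffic_share(run, total):
    """Fraction of all link traffic (both directions) attributed to each op."""
    return {op: (send + recv) / total for op, (send, recv) in run.items()}


shares = traffic_share(first_run, link_total)
for op, frac in sorted(shares.items(), key=lambda kv: -kv[1]):
    print(f"{op:8s} {100 * frac:5.1f}%")
```

On these numbers `mmul_aa` dominates and `msb_a2b` stays under 5% of the traffic, which is consistent with the take-away.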

As noted here, SEMI2K is much faster than Cheetah because it relies on an additional trusted third party.

CrypTen may need 1.242 to run ResNet18 inference on CIFAR100. Is Cheetah slower than CrypTen because Cheetah is a fully 2PC protocol, whereas CrypTen relies on a trusted third party, more like SEMI2K?

According to the Cheetah results reported in a recent paper, Cheetah needs 362 MB of communication on ResNet18/CIFAR100, which is 6.75x more than what I measured here. I noticed that this paper uses OpenCheetah to evaluate Cheetah. Does the difference come from SPU's optimized implementation compared to the original version of Cheetah?
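The 6.75x figure depends on the unit convention. Here is a quick check using only numbers quoted in this thread. It reproduces 6.75x if "MB" means MiB (2^20 bytes) and gives about 6.44x with decimal MB. It also counts only the send direction. The thread does not state whether the paper's 362 MB counts one or both directions, and counting both would roughly halve the gap.

```python
measured_send_bytes = 56200962  # "total send bytes" from the first Cheetah run above
reported_mb = 362               # OpenCheetah figure quoted from the paper

ratio_mib = reported_mb / (measured_send_bytes / 2**20)  # "MB" read as MiB
ratio_mb = reported_mb / (measured_send_bytes / 10**6)   # "MB" read as 10^6 bytes
print(f"{ratio_mib:.2f}x (MiB) vs {ratio_mb:.2f}x (MB)")
```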

Additionally, I find that the results of Cheetah in SPU can be different when I re-run the inference without restarting the SPU backend runtime. The log is:

[2024-06-03 13:57:54.744] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 1.717e-05s, execution took 4.265889505s, output processing took 1.2899e-05s, total time 4.265919574s.
[2024-06-03 13:57:54.744] [info] [api.cc:209] HLO profiling: total time 5.542699999999999e-05
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.free, executed 536 times, duration 2.2444e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.add, executed 243 times, duration 1.0638e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.multiply, executed 205 times, duration 8.691e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.broadcast, executed 45 times, duration 2.366e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.subtract, executed 40 times, duration 1.862e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.reduce, executed 30 times, duration 1.481e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.constant, executed 26 times, duration 1.223e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.reshape, executed 27 times, duration 1.17e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.greater, executed 24 times, duration 1.133e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.rsqrt, executed 20 times, duration 9.75e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.convolution, executed 20 times, duration 9.66e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.transpose, executed 21 times, duration 9.2e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.pad, executed 15 times, duration 6.86e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.reverse, executed 11 times, duration 4.67e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.select, executed 4 times, duration 2.16e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.convert, executed 2 times, duration 9.4e-08s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.reduce_window, executed 1 times, duration 5.3e-08s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.dot, executed 1 times, duration 4.2e-08s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:209] HAL profiling: total time 4.218076975
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_tensordot, executed 20 times, duration 2.987290693s, send bytes 13817211 recv bytes 13671119
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_rsqrt, executed 20 times, duration 0.778841864s, send bytes 9854935 recv bytes 9263784
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_mul, executed 185 times, duration 0.230024506s, send bytes 13183362 recv bytes 14817404
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_less, executed 24 times, duration 0.155849388s, send bytes 2596888 recv bytes 493080
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_mmul, executed 1 times, duration 0.046947478s, send bytes 439606 recv bytes 439388
[2024-06-03 13:57:54.744] [info] [api.cc:212] - mixed_mul, executed 20 times, duration 0.007028804s, send bytes 39040 recv bytes 39040
[2024-06-03 13:57:54.744] [info] [api.cc:212] - _mux, executed 4 times, duration 0.006409118s, send bytes 266240 recv bytes 266240
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_add, executed 243 times, duration 0.00310129s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_sub, executed 40 times, duration 0.002538692s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - seal, executed 2 times, duration 4.5142e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:209] MPC profiling: total time 4.247430992
[2024-06-03 13:57:54.744] [info] [api.cc:212] - mmul_aa, executed 21 times, duration 3.018374953s, send bytes 14104449 recv bytes 14104155
[2024-06-03 13:57:54.744] [info] [api.cc:212] - a2b, executed 20 times, duration 0.536903278s, send bytes 1264659 recv bytes 1264659
[2024-06-03 13:57:54.744] [info] [api.cc:212] - mul_aa, executed 195 times, duration 0.313157982s, send bytes 18746950 recv bytes 22344577
[2024-06-03 13:57:54.744] [info] [api.cc:212] - msb_a2b, executed 24 times, duration 0.153236289s, send bytes 2596888 recv bytes 493080
[2024-06-03 13:57:54.744] [info] [api.cc:212] - trunc_a, executed 346 times, duration 0.069195494s, send bytes 649456 recv bytes 27664
[2024-06-03 13:57:54.744] [info] [api.cc:212] - and_bb, executed 120 times, duration 0.064483911s, send bytes 385200 recv bytes 385200
[2024-06-03 13:57:54.744] [info] [api.cc:212] - b2a, executed 60 times, duration 0.030637183s, send bytes 2144400 recv bytes 65440
[2024-06-03 13:57:54.744] [info] [api.cc:212] - concatenate, executed 16 times, duration 0.018766642s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - mul_a1b, executed 24 times, duration 0.010929053s, send bytes 305280 recv bytes 305280
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pad, executed 15 times, duration 0.008190051s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - extract_slice, executed 4968 times, duration 0.005355921s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - add_aa, executed 335 times, duration 0.005247071s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - not_a, executed 68 times, duration 0.003782269s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - reshape, executed 4934 times, duration 0.002618001s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - xor_bb, executed 580 times, duration 0.001383012s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - add_ap, executed 168 times, duration 0.000878804s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - bitrev_b, executed 40 times, duration 0.000835866s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - and_bp, executed 360 times, duration 0.000669757s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - rshift_b, executed 360 times, duration 0.00063207s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - mul_ap, executed 230 times, duration 0.000623071s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - transpose, executed 91 times, duration 0.000479451s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - make_p, executed 498 times, duration 0.000382183s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - lshift_b, executed 120 times, duration 0.000201614s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - broadcast, executed 99 times, duration 0.00014712s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - add_pp, executed 40 times, duration 8.4434e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - mul_pp, executed 20 times, duration 7.0332e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - xor_bp, executed 20 times, duration 6.7603e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - not_p, executed 20 times, duration 4.5766e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - p2a, executed 2 times, duration 4.0934e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - reverse, executed 11 times, duration 1.0877e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:222] Link details: total send bytes 40197282, recv bytes 38990055, send actions 3393

Note that this time the profile shows only 40197282 bytes and 4.24 s, both much lower than in the first run. Can you tell me why this changed and which result is correct?
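The logs alone do not say why the second run is cheaper, but diffing the per-op counters of the two runs shows where the savings appear. The sketch below does this for a few MPC-layer ops copied from the two logs. It only localizes the change, mostly to `mmul_aa` and `trunc_a`. Any explanation, such as one-time setup being reused across runs, is a hypothesis to confirm with the SPU developers. `byte_drop` is a hypothetical helper.

```python
# (send, recv) bytes per MPC-layer op, copied from the two Cheetah runs above.
run1 = {"mmul_aa": (23819039, 23818860), "trunc_a": (6540832, 3788968),
        "msb_a2b": (2596888, 2356893), "a2b": (998400, 1264659),
        "mul_aa": (19410923, 23150786)}
run2 = {"mmul_aa": (14104449, 14104155), "trunc_a": (649456, 27664),
        "msb_a2b": (2596888, 493080), "a2b": (1264659, 1264659),
        "mul_aa": (18746950, 22344577)}


def byte_drop(before, after):
    """Bytes saved per op (send + recv); negative means the op grew."""
    return {op: sum(before[op]) - sum(after[op]) for op in before}


drops = byte_drop(run1, run2)
for op, saved in sorted(drops.items(), key=lambda kv: -kv[1]):
    print(f"{op:8s} {saved:>10d} bytes")
```

Interestingly, `a2b` traffic grows slightly in the second run while `trunc_a` nearly vanishes. Comparing durations the same way, `trunc_a` falls from about 2.5 s to 0.07 s and `msb_a2b` from about 2.1 s to 0.15 s.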

Sorry for taking your time. Thanks.

anakinxc commented 3 months ago

Hi @llCurious. Thanks for your response!

I have also benchmarked the ResNet18 on CIFAR100 using Cheetah protocol for 2PC in SPU. Here is the log:

[2024-06-03 13:33:37.612] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 1.7913e-05s, execution took 9.814508076s, output processing took 4.9612e-05s, total time 9.814575601s.
[2024-06-03 13:33:37.612] [info] [api.cc:209] HLO profiling: total time 5.797200000000001e-05
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.free, executed 536 times, duration 2.2729e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.add, executed 243 times, duration 1.044e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.multiply, executed 205 times, duration 8.808e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.constant, executed 26 times, duration 2.522e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.broadcast, executed 45 times, duration 2.355e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.subtract, executed 40 times, duration 1.945e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.reduce, executed 30 times, duration 1.593e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.reshape, executed 27 times, duration 1.16e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.greater, executed 24 times, duration 1.135e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.rsqrt, executed 20 times, duration 1.013e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.convolution, executed 20 times, duration 9.58e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.transpose, executed 21 times, duration 9.42e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.convert, executed 2 times, duration 8.91e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.pad, executed 15 times, duration 6.79e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.reverse, executed 11 times, duration 4.71e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.select, executed 4 times, duration 2.35e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.dot, executed 1 times, duration 4.9e-08s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pphlo.reduce_window, executed 1 times, duration 4.7e-08s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:209] HAL profiling: total time 9.767303477
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_tensordot, executed 20 times, duration 6.551057405s, send bytes 29423174 recv bytes 27147221
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_less, executed 24 times, duration 2.085487991s, send bytes 2596888 recv bytes 2356893
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_rsqrt, executed 20 times, duration 0.687243677s, send bytes 6089600 recv bytes 4757778
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_mul, executed 185 times, duration 0.374008647s, send bytes 17346411 recv bytes 20129619
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_mmul, executed 1 times, duration 0.048128501s, send bytes 439609 recv bytes 439295
[2024-06-03 13:33:37.612] [info] [api.cc:212] - mixed_mul, executed 20 times, duration 0.008901305s, send bytes 39040 recv bytes 39040
[2024-06-03 13:33:37.612] [info] [api.cc:212] - _mux, executed 4 times, duration 0.006759259s, send bytes 266240 recv bytes 266240
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_add, executed 243 times, duration 0.003040434s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - f_sub, executed 40 times, duration 0.002617991s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - seal, executed 2 times, duration 5.8267e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:209] MPC profiling: total time 9.79187834
[2024-06-03 13:33:37.612] [info] [api.cc:212] - mmul_aa, executed 21 times, duration 4.14475303s, send bytes 23819039 recv bytes 23818860
[2024-06-03 13:33:37.612] [info] [api.cc:212] - trunc_a, executed 346 times, duration 2.509568303s, send bytes 6540832 recv bytes 3788968
[2024-06-03 13:33:37.612] [info] [api.cc:212] - msb_a2b, executed 24 times, duration 2.083226061s, send bytes 2596888 recv bytes 2356893
[2024-06-03 13:33:37.612] [info] [api.cc:212] - a2b, executed 20 times, duration 0.50009818s, send bytes 998400 recv bytes 1264659
[2024-06-03 13:33:37.612] [info] [api.cc:212] - mul_aa, executed 195 times, duration 0.39771609s, send bytes 19410923 recv bytes 23150786
[2024-06-03 13:33:37.612] [info] [api.cc:212] - and_bb, executed 120 times, duration 0.063048469s, send bytes 385200 recv bytes 385200
[2024-06-03 13:33:37.612] [info] [api.cc:212] - b2a, executed 60 times, duration 0.032670144s, send bytes 2144400 recv bytes 65440
[2024-06-03 13:33:37.612] [info] [api.cc:212] - concatenate, executed 16 times, duration 0.018114326s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - mul_a1b, executed 24 times, duration 0.013046298s, send bytes 305280 recv bytes 305280
[2024-06-03 13:33:37.612] [info] [api.cc:212] - pad, executed 15 times, duration 0.006005943s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - extract_slice, executed 4968 times, duration 0.005346207s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - add_aa, executed 335 times, duration 0.004904623s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - not_a, executed 68 times, duration 0.003596807s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - reshape, executed 4934 times, duration 0.002599442s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - xor_bb, executed 580 times, duration 0.001542413s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - add_ap, executed 168 times, duration 0.000897028s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - bitrev_b, executed 40 times, duration 0.00088169s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - and_bp, executed 360 times, duration 0.000797368s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - rshift_b, executed 360 times, duration 0.000698025s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - mul_ap, executed 230 times, duration 0.000661371s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - transpose, executed 91 times, duration 0.000501438s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - make_p, executed 498 times, duration 0.000464443s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - lshift_b, executed 120 times, duration 0.000223616s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - broadcast, executed 99 times, duration 0.000152052s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - add_pp, executed 40 times, duration 9.26e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - mul_pp, executed 20 times, duration 8.0166e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - xor_bp, executed 20 times, duration 7.0591e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - not_p, executed 20 times, duration 5.6533e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - p2a, executed 2 times, duration 5.4546e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:212] - reverse, executed 11 times, duration 1.0537e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:33:37.612] [info] [api.cc:222] Link details: total send bytes 56200962, recv bytes 55136086, send actions 19799

We can see that Cheetah takes about 9.79 s for inference and 56200962 bytes of communication, whereas SEMI2K takes 0.51 s and 59514656 bytes. They have similar comm. sizes but very different latencies.

As noted here, SEMI2K is much faster than Cheetah because it relies on an additional trusted third party.

CrypTen may need 1.242 to run ResNet18 inference on CIFAR100. Is Cheetah slower than CrypTen because Cheetah is a fully 2PC protocol, whereas CrypTen relies on a trusted third party, more like SEMI2K?

According to the Cheetah results reported in a recent paper, Cheetah needs 362 MB of communication on ResNet18/CIFAR100, which is 6.75x more than what I measured here. I noticed that this paper uses OpenCheetah to evaluate Cheetah. Does the difference come from SPU's optimized implementation compared to the original version of Cheetah?

Additionally, I find that the results of Cheetah in SPU can be different when I re-run the inference without restarting the SPU backend runtime. The log is:

[2024-06-03 13:57:54.744] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 1.717e-05s, execution took 4.265889505s, output processing took 1.2899e-05s, total time 4.265919574s.
[2024-06-03 13:57:54.744] [info] [api.cc:209] HLO profiling: total time 5.542699999999999e-05
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.free, executed 536 times, duration 2.2444e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.add, executed 243 times, duration 1.0638e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.multiply, executed 205 times, duration 8.691e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.broadcast, executed 45 times, duration 2.366e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.subtract, executed 40 times, duration 1.862e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.reduce, executed 30 times, duration 1.481e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.constant, executed 26 times, duration 1.223e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.reshape, executed 27 times, duration 1.17e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.greater, executed 24 times, duration 1.133e-06s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.rsqrt, executed 20 times, duration 9.75e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.convolution, executed 20 times, duration 9.66e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.transpose, executed 21 times, duration 9.2e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.pad, executed 15 times, duration 6.86e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.reverse, executed 11 times, duration 4.67e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.select, executed 4 times, duration 2.16e-07s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.convert, executed 2 times, duration 9.4e-08s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.reduce_window, executed 1 times, duration 5.3e-08s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pphlo.dot, executed 1 times, duration 4.2e-08s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:209] HAL profiling: total time 4.218076975
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_tensordot, executed 20 times, duration 2.987290693s, send bytes 13817211 recv bytes 13671119
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_rsqrt, executed 20 times, duration 0.778841864s, send bytes 9854935 recv bytes 9263784
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_mul, executed 185 times, duration 0.230024506s, send bytes 13183362 recv bytes 14817404
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_less, executed 24 times, duration 0.155849388s, send bytes 2596888 recv bytes 493080
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_mmul, executed 1 times, duration 0.046947478s, send bytes 439606 recv bytes 439388
[2024-06-03 13:57:54.744] [info] [api.cc:212] - mixed_mul, executed 20 times, duration 0.007028804s, send bytes 39040 recv bytes 39040
[2024-06-03 13:57:54.744] [info] [api.cc:212] - _mux, executed 4 times, duration 0.006409118s, send bytes 266240 recv bytes 266240
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_add, executed 243 times, duration 0.00310129s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - f_sub, executed 40 times, duration 0.002538692s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - seal, executed 2 times, duration 4.5142e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:209] MPC profiling: total time 4.247430992
[2024-06-03 13:57:54.744] [info] [api.cc:212] - mmul_aa, executed 21 times, duration 3.018374953s, send bytes 14104449 recv bytes 14104155
[2024-06-03 13:57:54.744] [info] [api.cc:212] - a2b, executed 20 times, duration 0.536903278s, send bytes 1264659 recv bytes 1264659
[2024-06-03 13:57:54.744] [info] [api.cc:212] - mul_aa, executed 195 times, duration 0.313157982s, send bytes 18746950 recv bytes 22344577
[2024-06-03 13:57:54.744] [info] [api.cc:212] - msb_a2b, executed 24 times, duration 0.153236289s, send bytes 2596888 recv bytes 493080
[2024-06-03 13:57:54.744] [info] [api.cc:212] - trunc_a, executed 346 times, duration 0.069195494s, send bytes 649456 recv bytes 27664
[2024-06-03 13:57:54.744] [info] [api.cc:212] - and_bb, executed 120 times, duration 0.064483911s, send bytes 385200 recv bytes 385200
[2024-06-03 13:57:54.744] [info] [api.cc:212] - b2a, executed 60 times, duration 0.030637183s, send bytes 2144400 recv bytes 65440
[2024-06-03 13:57:54.744] [info] [api.cc:212] - concatenate, executed 16 times, duration 0.018766642s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - mul_a1b, executed 24 times, duration 0.010929053s, send bytes 305280 recv bytes 305280
[2024-06-03 13:57:54.744] [info] [api.cc:212] - pad, executed 15 times, duration 0.008190051s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - extract_slice, executed 4968 times, duration 0.005355921s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - add_aa, executed 335 times, duration 0.005247071s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - not_a, executed 68 times, duration 0.003782269s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - reshape, executed 4934 times, duration 0.002618001s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - xor_bb, executed 580 times, duration 0.001383012s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - add_ap, executed 168 times, duration 0.000878804s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - bitrev_b, executed 40 times, duration 0.000835866s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - and_bp, executed 360 times, duration 0.000669757s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - rshift_b, executed 360 times, duration 0.00063207s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - mul_ap, executed 230 times, duration 0.000623071s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - transpose, executed 91 times, duration 0.000479451s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - make_p, executed 498 times, duration 0.000382183s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - lshift_b, executed 120 times, duration 0.000201614s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - broadcast, executed 99 times, duration 0.00014712s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - add_pp, executed 40 times, duration 8.4434e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - mul_pp, executed 20 times, duration 7.0332e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - xor_bp, executed 20 times, duration 6.7603e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - not_p, executed 20 times, duration 4.5766e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - p2a, executed 2 times, duration 4.0934e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:212] - reverse, executed 11 times, duration 1.0877e-05s, send bytes 0 recv bytes 0
[2024-06-03 13:57:54.744] [info] [api.cc:222] Link details: total send bytes 40197282, recv bytes 38990055, send actions 3393
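A quick way to see where the bytes go is to aggregate the per-op lines. In the MPC section of the log above, the per-op send bytes add up exactly to the link total (40197282 bytes). mul_aa (18746950) and mmul_aa (14104449) together account for about 82% of that total. msb_a2b, which sends the same 2596888 bytes as the 24 f_less comparisons, is about 6.5%. The parser below is a minimal sketch that assumes the log line format shown above. The regexes and the helpers `parse_profile`, `rank_by_send` and `report` are hypothetical and not part of SPU.

```python
import re

# Per-op line, e.g.
# "- mmul_aa, executed 21 times, duration 3.018374953s, send bytes 14104449 recv bytes 14104155"
OP_RE = re.compile(
    r"- (?P<op>[\w.]+), executed \d+ times, duration [\d.e+-]+s, "
    r"send bytes (?P<send>\d+) recv bytes \d+"
)
# Section header, e.g. "MPC profiling: total time 4.247430992"
SECTION_RE = re.compile(r"(?P<name>\w+) profiling: total time")
# Footer, e.g. "Link details: total send bytes 40197282, recv bytes 38990055, ..."
LINK_RE = re.compile(r"Link details: total send bytes (?P<send>\d+)")


def parse_profile(lines, section="MPC"):
    """Return ({op: send_bytes}, link_total_send) for one profiling section."""
    current, ops, link_send = None, {}, None
    for line in lines:
        if (m := SECTION_RE.search(line)):
            current = m.group("name")  # entering a new section (HLO/HAL/MPC)
        elif (m := LINK_RE.search(line)):
            link_send = int(m.group("send"))
        elif (m := OP_RE.search(line)) and current == section:
            ops[m.group("op")] = int(m.group("send"))
    return ops, link_send


def rank_by_send(ops):
    """Ops that actually communicate, largest first, with their share of the total."""
    total = sum(ops.values())
    ranked = sorted(((op, b) for op, b in ops.items() if b), key=lambda kv: kv[1], reverse=True)
    return [(op, b, b / total) for op, b in ranked]


def report(lines, section="MPC"):
    """Print a communication breakdown for one section of an SPU profiling log."""
    ops, link_send = parse_profile(lines, section)
    for op, sent, frac in rank_by_send(ops):
        print(f"{op:10s} {sent:>10d} B {frac:6.1%}")
    print("sum of per-op send bytes:", sum(ops.values()), "| link total send bytes:", link_send)
```

Usage: `report(open("spu.log"))`, where `spu.log` is a hypothetical file holding the captured log. Only the MPC section is summed because the HAL and HLO lines report the same traffic at a coarser level, and adding them would double-count it.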

Note that this time the profiling reports only 40197282 bytes sent and 4.24 s, which is much lower than in the first run. Can you tell me why this changed and which result is correct?

Sorry for taking up your time. Thanks.

SPU's Cheetah is an optimized implementation compared with OpenCheetah.

Cheetah has a fairly expensive initialization cost, so a slower first run is expected behavior.
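When comparing models, it helps to separate this one-off setup from steady-state latency. Below is a minimal, generic timing sketch. It assumes the setup is paid once per session and reused by later runs in the same session, as the slower first run suggests. `benchmark` and `fake_infer` are hypothetical helpers, not SPU APIs. In practice `run` would wrap one private inference call.

```python
import statistics
import time
from typing import Callable, List


def benchmark(run: Callable[[], object], warmup: int = 1, repeats: int = 5) -> List[float]:
    """Return wall-clock timings of `run`, excluding the first `warmup` calls.

    `run` is a placeholder for one private inference. Warm-up calls absorb
    one-off setup costs (assumed to be reused by later calls in the same
    session) and are not timed.
    """
    for _ in range(warmup):
        run()  # pays the initialization cost; deliberately not measured
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        run()
        timings.append(time.perf_counter() - start)
    return timings


if __name__ == "__main__":
    state = {"calls": 0}

    def fake_infer():
        # Hypothetical stand-in: the first call is slow, later calls are fast.
        state["calls"] += 1
        time.sleep(0.2 if state["calls"] == 1 else 0.01)

    t = benchmark(fake_infer, warmup=1, repeats=3)
    print(f"steady-state median latency: {statistics.median(t):.3f}s")
```

Reporting the median of several warm runs, rather than a single first run, keeps setup cost from being attributed to the model itself.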

@fionser feel free to add more comments here :D

anakinxc commented 3 months ago

Compared with CrypTen, which may need 1.242 to infer ResNet18 on CIFAR100, is Cheetah slower because it is a fully 2PC protocol, whereas CrypTen relies on a trusted third party, more like SEMI2K?

Yes
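Part of the gap can be seen in a textbook Beaver-triple multiplication on two-party additive shares. With a trusted dealer, as in the setting described in the question, the multiplication triples come almost for free. A pure 2PC protocol has to generate equivalent correlations between the two parties, for example with HE or OT. This is a generic sketch, not SPU's or CrypTen's actual code. The ring Z_{2^64} and all function names are assumptions made for illustration.

```python
import secrets

MOD = 1 << 64  # assumed ring Z_{2^64} for additive secret sharing


def share(x):
    """Split x into two additive shares mod 2^64."""
    r = secrets.randbelow(MOD)
    return r, (x - r) % MOD


def reveal(s0, s1):
    """Reconstruct a value from its two shares."""
    return (s0 + s1) % MOD


def dealer_triple():
    """Trusted dealer: samples a, b and hands out shares of a, b and c = a*b.

    In a pure 2PC setting there is no dealer, so the parties must produce
    such correlated randomness themselves, which costs extra computation
    and communication.
    """
    a, b = secrets.randbelow(MOD), secrets.randbelow(MOD)
    return share(a), share(b), share(a * b % MOD)


def beaver_mul(x_sh, y_sh, triple):
    """Multiply two shared values using one precomputed triple."""
    (a0, a1), (b0, b1), (c0, c1) = triple
    # Online phase: both parties open the masked values e = x - a, f = y - b.
    e = reveal((x_sh[0] - a0) % MOD, (x_sh[1] - a1) % MOD)
    f = reveal((y_sh[0] - b0) % MOD, (y_sh[1] - b1) % MOD)
    # x*y = c + e*b + f*a + e*f; the public e*f term is added by party 0 only.
    z0 = (c0 + e * b0 + f * a0 + e * f) % MOD
    z1 = (c1 + e * b1 + f * a1) % MOD
    return z0, z1


print(reveal(*beaver_mul(share(7), share(6), dealer_triple())))  # prints 42
```

The online phase opens only two masked values per multiplication. The expensive part in a dealer-free protocol is producing `dealer_triple`'s output without a dealer.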

fionser commented 3 months ago

I think one issue should address just one problem.

warpoons commented 3 months ago

Thanks @anakinxc for the quick answers and for resolving my question!

Thanks @fionser for pointing out that this discussion has grown too long. I will close this issue as completed soon.

Sorry for taking up your time. Many thanks :D