secretflow / spu

SPU (Secure Processing Unit) aims to be a provable, measurable secure computation device, which provides computation ability while keeping your private data protected.
https://www.secretflow.org.cn/docs/spu/en/
Apache License 2.0
242 stars 104 forks

[Bug]: Error when trying to reproduce examples/python/ml/jax_lr #852

Closed fmshglm closed 2 months ago

fmshglm commented 2 months ago

Issue Type

Usability

Modules Involved

Documentation/Tutorial/Example

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

6.5.0

OS Platform and Distribution

Linux version 5.4.0-90-generic

Python Version

3.10.13

Compiler Version

11.4.0

Current Behavior?

I tried to reproduce examples/python/ml/jax_lr with the following steps.
(1) Create two containers and run the following in each:
pip install -r requirements.txt
pip install 'transformers[flax]'
bazel build //examples/python/... -c opt
(2) In the first container, run:
cd bazel-bin/examples/python/utils
./nodectl --config /home/admin/dev/examples/python/conf/2pc_semi2k.json up

This runs normally, with the following output:

root@f671cc3f0b9b:/home/admin/dev/bazel-bin/examples/python/utils# ./nodectl --config /home/admin/dev/examples/python/conf/2pc_semi2k.json up
[2024-09-12 07:25:14,432] [ForkServerProcess-1] Starting grpc server at 127.0.0.1:64320
[2024-09-12 07:25:14,436] [ForkServerProcess-2] Starting grpc server at 127.0.0.1:64321

(3) In the second container, run:
cd bazel-bin/examples/python/ml/jax_lr
./jax_lr --config /home/admin/dev//examples/python/conf/2pc_semi2k.json

This fails with the following error:

root@4a8ca561ad75:/home/admin/dev/bazel-bin/examples/python/ml/jax_lr# ./jax_lr --config /home/admin/dev/examples/python/conf/2pc_semi2k.json
Traceback (most recent call last):
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/./jax_lr.runfiles/spulib/examples/python/ml/jax_lr/jax_lr.py", line 193, in <module>
    ppd.init(conf["nodes"], conf["devices"])  # create HostContext
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 1183, in init
    _CONTEXT = HostContext(nodes_def, devices_def)
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 1103, in __init__
    self.devices[name] = SPU(  # call line 980
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 1016, in __init__
    results = [future.result() for future in futures]
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 1016, in <listcomp>
    results = [future.result() for future in futures]
  File "/root/miniconda3/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/root/miniconda3/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/root/miniconda3/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 250, in run
    return self._call(self._stub.Run, fn, args, **kwargs)
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 239, in _call
    rsp_data = rebuild_messages(rsp_itr.data for rsp_itr in rsp_gen)
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 217, in rebuild_messages
    return b''.join([msg for msg in msgs])
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 217, in <listcomp>
    return b''.join([msg for msg in msgs])
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 239, in <genexpr>
    rsp_data = rebuild_messages(rsp_itr.data for rsp_itr in rsp_gen)
  File "/root/miniconda3/lib/python3.10/site-packages/grpc/_channel.py", line 543, in __next__
    return self._next()
  File "/root/miniconda3/lib/python3.10/site-packages/grpc/_channel.py", line 969, in _next
    raise self
grpc._channel._MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
    status = StatusCode.UNAVAILABLE
    details = "failed to connect to all addresses; last error: UNKNOWN: ipv4:192.168.121.35:80: HTTP proxy returned response code 403"
    debug_error_string = "UNKNOWN:Error received from peer {created_time:"2024-09-12T07:37:42.473466354+00:00", grpc_status:14, grpc_message:"failed to connect to all addresses; last error: UNKNOWN: ipv4:192.168.121.35:80: HTTP proxy returned response code 403"}"
>

Standalone code to reproduce the issue

(1) Create two containers and run the following in each:
pip install -r requirements.txt
pip install 'transformers[flax]'
bazel build //examples/python/... -c opt
(2) In the first container, run:
cd bazel-bin/examples/python/utils
./nodectl --config /home/admin/dev/examples/python/conf/2pc_semi2k.json up
(3) In the second container, run:
cd bazel-bin/examples/python/ml/jax_lr
./jax_lr --config /home/admin/dev//examples/python/conf/2pc_semi2k.json

Relevant log output

No response

tpppppub commented 2 months ago

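nodectl up runs all nodes as multiple processes on a single machine, so your jax_lr also has to run in the first container. If you want each container to simulate one node, please refer to this document: change the ip/port entries in 2pc_semi2k.json to match your containers, run nodectl start node in each container, and then run jax_lr.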
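To see whether the node addresses are reachable from the container where jax_lr will run, a plain TCP check is often enough. The following is a minimal sketch, not part of SPU: `check_nodes` is a hypothetical helper, and the node ids `node:0` and `node:1` are assumed for illustration. Only the two addresses come from the nodectl log earlier in this thread.

```python
import socket


def check_nodes(nodes, timeout=2.0):
    """Return {node_id: bool}: whether a TCP connection to each address succeeds.

    `nodes` maps a node id to a "host:port" string. This shape is an
    assumption made for this sketch, not a statement about the SPU config.
    """
    reachable = {}
    for node_id, address in nodes.items():
        host, _, port = address.rpartition(":")
        try:
            # Open and immediately close a TCP connection to the address.
            with socket.create_connection((host, int(port)), timeout=timeout):
                reachable[node_id] = True
        except OSError:
            # Covers connection refused, unreachable hosts, and timeouts.
            reachable[node_id] = False
    return reachable


if __name__ == "__main__":
    # Addresses copied from the nodectl log above; the node ids are assumed.
    nodes = {"node:0": "127.0.0.1:64320", "node:1": "127.0.0.1:64321"}
    for node_id, ok in check_nodes(nodes).items():
        print(node_id, "reachable" if ok else "NOT reachable")
```

Run inside the second container, this would be expected to report both nodes as not reachable, because 127.0.0.1 there refers to that container itself rather than to the one where nodectl is listening.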

fmshglm commented 2 months ago

Do you mean running this command in the same container: bazel run -c opt //examples/python/utils:nodectl -- up && bazel run -c opt //examples/python/ml/jax_lr:jax_lr

If so, my program stops after bringing up the nodes. How do I get the jax_lr step that follows to run?

Target //examples/python/utils:nodectl up-to-date:
bazel-bin/examples/python/utils/nodectl
INFO: Elapsed time: 9.904s, Critical Path: 0.83s
INFO: 1 process: 1 internal.
INFO: Build completed successfully, 1 total action
INFO: Running command line: bazel-bin/examples/python/utils/nodectl up
[2024-09-13 02:27:59,510] [ForkServerProcess-1] Starting grpc server at 127.0.0.1:64320
[2024-09-13 02:27:59,521] [ForkServerProcess-2] Starting grpc server at 127.0.0.1:64321

tpppppub commented 2 months ago

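Either run the first program in the background, or open two terminals into the container and run the two commands separately. Both work.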
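The reason `&&` does not work here is that the shell only starts the second command after the first one exits, and nodectl up keeps running as a server. One way to express "start the nodes in the background, then run the client" is sketched below. `run_with_background` is a hypothetical helper, the fixed startup wait is a crude assumption, and the commands in the trailing comment simply reuse the paths quoted in this thread.

```python
import subprocess
import time


def run_with_background(server_cmd, client_cmd, startup_wait=5.0):
    """Start server_cmd in the background, run client_cmd, then stop the server.

    Returns the exit code of client_cmd.
    """
    server = subprocess.Popen(server_cmd)
    try:
        # Crude wait so the server has time to start listening.
        time.sleep(startup_wait)
        return subprocess.run(client_cmd).returncode
    finally:
        # Stops the launched process only; any worker processes it spawned
        # may need separate cleanup.
        server.terminate()
        server.wait()


# Possible usage with the paths from this thread (not executed here):
# config = "/home/admin/dev/examples/python/conf/2pc_semi2k.json"
# run_with_background(
#     ["bazel-bin/examples/python/utils/nodectl", "--config", config, "up"],
#     ["bazel-bin/examples/python/ml/jax_lr/jax_lr", "--config", config],
# )
```

Opening two terminals in the same container achieves the same thing by hand and makes it easier to read the two logs side by side.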

fmshglm commented 2 months ago

Thanks, w, b = run_on_spu(x, y) now runs, but w, b = run_on_spu(x, y, True) fails with an error:

Run on SPU with cache

Traceback (most recent call last):
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/examples/python/ml/jax_lr/jax_lr.py", line 213, in <module>
    w, b = run_on_spu(x, y, True)
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/examples/python/ml/jax_lr/jax_lr.py", line 177, in run_on_spu
    W, b = train(x1, x2, y)
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 670, in __call__
    executable, args_flat, out_tree = self._compile_jax_func(
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 760, in _compile_jax_func
    executable, output = spu_fe.compile(
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/frontend.py", line 236, in compile
    ir_text, output = _jax_compilation(
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/frontend.py", line 139, in _jax_compilation
    cfn, output = jax.xla_computation(
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/examples/python/ml/jax_lr/jax_lr.py", line 172, in train
    return lr.fit_manual_grad(x, y, use_cache)
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/examples/python/ml/jax_lr/jax_lr.py", line 99, in fit_manual_grad
    feature = spu.experimental.make_cached_var(feature)
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/experimental/make_cached_var_impl.py", line 30, in make_cached_var
    return _make_cached_var_prim.bind(input)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: custom_call() got an unexpected keyword argument 'result_types'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/examples/python/ml/jax_lr/jax_lr.py", line 213, in <module>
    w, b = run_on_spu(x, y, True)
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/examples/python/ml/jax_lr/jax_lr.py", line 177, in run_on_spu
    W, b = train(x1, x2, y)
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 670, in __call__
    executable, args_flat, out_tree = self._compile_jax_func(
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 760, in _compile_jax_func
    executable, output = spu_fe.compile(
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/frontend.py", line 236, in compile
    ir_text, output = _jax_compilation(
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/frontend.py", line 139, in _jax_compilation
    cfn, output = jax.xla_computation(
  File "/root/miniconda3/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/jax/_src/api.py", line 555, in computation_maker
    lowering_result = mlir.lower_jaxpr_to_module(
  File "/root/miniconda3/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 699, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/root/miniconda3/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1030, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
  File "/root/miniconda3/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1177, in jaxpr_subcomp
    ans = rule(rule_ctx, map(_unwrap_singleton_ir_values, in_nodes),
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/experimental/make_cached_var_impl.py", line 45, in _make_cached_var_lowering
    return custom_call(
jax._src.traceback_util.UnfilteredStackTrace: TypeError: custom_call() got an unexpected keyword argument 'result_types'

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/examples/python/ml/jax_lr/jax_lr.py", line 213, in <module>
    w, b = run_on_spu(x, y, True)
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/examples/python/ml/jax_lr/jax_lr.py", line 177, in run_on_spu
    W, b = train(x1, x2, y)
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 670, in __call__
    executable, args_flat, out_tree = self._compile_jax_func(
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/distributed_impl.py", line 760, in _compile_jax_func
    executable, output = spu_fe.compile(
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/frontend.py", line 236, in compile
    ir_text, output = _jax_compilation(
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/utils/frontend.py", line 139, in _jax_compilation
    cfn, output = jax.xla_computation(
  File "/root/.cache/bazel/_bazel_root/eceb46742416a02f6a0f8d92bc74468c/execroot/spulib/bazel-out/k8-opt/bin/examples/python/ml/jax_lr/jax_lr.runfiles/spulib/spu/experimental/make_cached_var_impl.py", line 45, in _make_cached_var_lowering
    return custom_call(
TypeError: custom_call() got an unexpected keyword argument 'result_types'

tpppppub commented 2 months ago

Running jax_lr does not require transformers[flax], and the jax/jaxlib versions that package depends on are too old. Update jax/jaxlib with pip install -U jax jaxlib and then try again.
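Before and after upgrading, it can help to confirm which jax and jaxlib versions the interpreter actually sees, since a second install step such as transformers[flax] may have changed them. The sketch below uses only the standard library. `installed_versions` is a hypothetical helper, and the thread does not state which versions SPU requires, so the output should be compared against the project's own requirements.txt.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package_name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(["jax", "jaxlib"]).items():
        print(f"{name}: {version if version is not None else 'not installed'}")
```

Running this with the same Python interpreter that executes jax_lr matters, because the traceback above shows jax being loaded from /root/miniconda3/lib/python3.10/site-packages.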

fmshglm commented 2 months ago

It runs now, thanks!