google-deepmind / jax_privacy

Algorithms for Privacy-Preserving Machine Learning in JAX
Apache License 2.0

failed to get PTX kernel "shift_right_logical_3" from module #8

Closed heilrahc closed 1 year ago

heilrahc commented 1 year ago

Hi, I'm trying to reproduce your fine-tuning results but have run into a lot of dependency issues.

I'm running the code on an NVIDIA A4000 with CUDA 11.0 and cuDNN 8.1.1 installed. I used JAX version 0.3.5, as you suggested.

Right now I'm getting this error:

`UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
I0221 16:06:05.997925 140161066599616 xla_bridge.py:260] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I0221 16:06:06.002030 140161066599616 xla_bridge.py:260] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
2023-02-21 16:06:06.202993: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:111] WARNING You are using ptxas 11.0.194, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
2023-02-21 16:06:06.205409: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:230] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.6
2023-02-21 16:06:06.205439: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:233] Used ptxas at ptxas
2023-02-21 16:06:06.208047: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:632] failed to get PTX kernel "shift_right_logical_3" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
2023-02-21 16:06:06.208115: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2141] Execution of replica 0 failed: INTERNAL: Could not find the corresponding function
2023-02-21 16:06:06.209365: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:632] failed to get PTX kernel "shift_right_logical_3" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
2023-02-21 16:06:06.209427: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2141] Execution of replica 0 failed: INTERNAL: Could not find the corresponding function
Traceback (most recent call last):
  File "experiments/image_classification/run_experiment.py", line 40, in <module>
    app.run(functools.partial(platform.main, experiment.Experiment))
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jaxline/utils.py", line 484, in inner_wrapper
    return f(*args, kwargs)
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jaxline/platform.py", line 148, in main
    train.evaluate(experiment_class, config,
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jaxline/utils.py", line 620, in inner_wrapper
    return fn(*args, *kwargs)
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jaxline/train.py", line 155, in evaluate
    eval_rng = jax.random.PRNGKey(config.random_seed)
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/random.py", line 125, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/prng.py", line 232, in seed_with_impl
    return PRNGKeyArray(impl, impl.seed(seed))
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/prng.py", line 272, in threefry_seed
    lax.shift_right_logical(seed_arr, lax_internal._const(seed_arr, 32)))
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 445, in shift_right_logical
    return shift_right_logical_p.bind(x, y)
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/core.py", line 288, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/core.py", line 291, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/core.py", line 613, in process_primitive
    return primitive.impl(tracers, params)
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/dispatch.py", line 95, in apply_primitive
    return compiled_fun(args)
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/dispatch.py", line 115, in <lambda>
    return lambda args, *kw: compiled(args, **kw)[0]
  File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/dispatch.py", line 439, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
RuntimeError: INTERNAL: Could not find the corresponding function`

Have you encountered this issue before? If so, how did you fix it? Thanks in advance!
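
The log itself warns that ptxas 11.0.194 is older than 11.1 and that it does not support CC 8.6, so one thing that may be worth checking is which `ptxas` is actually on the `PATH`. Below is a minimal sketch of such a check using only the Python standard library. The helper names (`parse_ptxas_version`, `installed_ptxas_version`, `is_too_old`) are hypothetical, the 11.1 threshold is taken from the warning text above, and the parser assumes that `ptxas --version` prints a three-part version such as `11.0.194` somewhere in its output. It is a diagnostic aid under those assumptions, not a confirmed fix for this error.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

# Threshold quoted in the XLA warning: "ptxas before 11.1 is known to miscompile XLA code".
MIN_PTXAS = (11, 1)


def parse_ptxas_version(text: str) -> Optional[Tuple[int, ...]]:
    """Return the first three-part version found in text, e.g. (11, 0, 194), or None."""
    match = re.search(r"(\d+)\.(\d+)\.(\d+)", text)
    return tuple(int(part) for part in match.groups()) if match else None


def installed_ptxas_version() -> Optional[Tuple[int, ...]]:
    """Ask the ptxas found on PATH for its version; None if it is missing or unparsable."""
    path = shutil.which("ptxas")
    if path is None:
        return None
    try:
        result = subprocess.run(
            [path, "--version"], capture_output=True, text=True, check=False
        )
    except OSError:
        return None
    return parse_ptxas_version(result.stdout)


def is_too_old(version: Tuple[int, ...]) -> bool:
    """True if the (major, minor) part of version is below the 11.1 threshold."""
    return version[:2] < MIN_PTXAS


if __name__ == "__main__":
    version = installed_ptxas_version()
    if version is None:
        print("No usable ptxas found on PATH.")
    elif is_too_old(version):
        print(f"ptxas {version} is older than {MIN_PTXAS}; XLA warns about this.")
    else:
        print(f"ptxas {version} meets the {MIN_PTXAS} threshold from the warning.")
```

The parsing is kept separate from the subprocess call so it can also be applied to a line copied straight from a log, for example `parse_ptxas_version("You are using ptxas 11.0.194")`.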