Hi, I'm trying to reproduce your fine-tuning results but have run into a lot of dependency issues.
I'm running the code on an NVIDIA A4000 with CUDA 11.0 and cuDNN 8.1.1 installed. I'm using JAX 0.3.5, as you suggested.
Right now I'm getting this error:
`UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
warnings.warn(
I0221 16:06:05.997925 140161066599616 xla_bridge.py:260] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I0221 16:06:06.002030 140161066599616 xla_bridge.py:260] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
2023-02-21 16:06:06.202993: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:111] WARNING You are using ptxas 11.0.194, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.
You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
2023-02-21 16:06:06.205409: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:230] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.6
2023-02-21 16:06:06.205439: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:233] Used ptxas at ptxas
2023-02-21 16:06:06.208047: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:632] failed to get PTX kernel "shift_right_logical_3" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
2023-02-21 16:06:06.208115: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2141] Execution of replica 0 failed: INTERNAL: Could not find the corresponding function
2023-02-21 16:06:06.209365: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:632] failed to get PTX kernel "shift_right_logical_3" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
2023-02-21 16:06:06.209427: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2141] Execution of replica 0 failed: INTERNAL: Could not find the corresponding function
Traceback (most recent call last):
File "experiments/image_classification/run_experiment.py", line 40, in <module>
app.run(functools.partial(platform.main, experiment.Experiment))
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jaxline/utils.py", line 484, in inner_wrapper
return f(*args, **kwargs)
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jaxline/platform.py", line 148, in main
train.evaluate(experiment_class, config,
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jaxline/utils.py", line 620, in inner_wrapper
return fn(*args, **kwargs)
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jaxline/train.py", line 155, in evaluate
eval_rng = jax.random.PRNGKey(config.random_seed)
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/random.py", line 125, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/prng.py", line 232, in seed_with_impl
return PRNGKeyArray(impl, impl.seed(seed))
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/prng.py", line 272, in threefry_seed
lax.shift_right_logical(seed_arr, lax_internal._const(seed_arr, 32)))
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 445, in shift_right_logical
return shift_right_logical_p.bind(x, y)
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/core.py", line 288, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/core.py", line 291, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/core.py", line 613, in process_primitive
return primitive.impl(tracers, params)
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/dispatch.py", line 95, in apply_primitive
return compiled_fun(args)
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/dispatch.py", line 115, in <lambda>
return lambda *args, **kw: compiled(*args, **kw)[0]
File "/home/mine01/Desktop/code/jax_ve/lib/python3.8/site-packages/jax/_src/dispatch.py", line 439, in _execute_compiled
out_bufs = compiled.execute(input_bufs)
RuntimeError: INTERNAL: Could not find the corresponding function`
I wonder if you've encountered this issue before. If so, how did you fix it?
Thanks in advance!!
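
In case it helps narrow this down: the warnings in the log say that ptxas older than 11.1 is known to miscompile XLA code and that my ptxas does not support CC 8.6. Below is a minimal sketch of how I'd check which ptxas release is on the PATH. The helper names (`parse_ptxas_release`, `is_new_enough`, `check_ptxas`) are my own, and the 11.1 threshold is taken only from the warning text. I'm also assuming `ptxas --version` prints a line containing `release X.Y`, as the CUDA toolkits I've used do.

```python
import re
import shutil
import subprocess

# Minimum ptxas release, taken from the XLA warning in the log above
# (an assumption based on that message, not on any official table).
MIN_PTXAS = (11, 1)


def parse_ptxas_release(version_text):
    """Return (major, minor) parsed from `ptxas --version` output, or None."""
    match = re.search(r"release (\d+)\.(\d+)", version_text)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def is_new_enough(release, minimum=MIN_PTXAS):
    """True if the parsed release is at least the minimum version."""
    return release is not None and release >= minimum


def check_ptxas():
    """Describe the ptxas found on PATH, if any."""
    path = shutil.which("ptxas")
    if path is None:
        return "ptxas not found on PATH"
    result = subprocess.run([path, "--version"], capture_output=True, text=True)
    release = parse_ptxas_release(result.stdout)
    if release is None:
        return f"could not parse the version reported by {path}"
    verdict = "ok" if is_new_enough(release) else "older than 11.1"
    return f"{path}: release {release[0]}.{release[1]} ({verdict})"


if __name__ == "__main__":
    print(check_ptxas())
```

On my setup I'd expect this to report a release older than 11.1, which matches the `ptxas 11.0.194` warning. It only confirms the version mismatch; it doesn't fix anything by itself.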