jmSNU opened this issue 1 month ago.
Hi @jmSNU !
I think your issue comes from a broken JAX installation. I recommend following these steps https://github.com/chandar-lab/Recall2Imagine/tree/main?tab=readme-ov-file#conda to create a fresh conda environment. If that does not work, try building a Docker image by following the steps here https://github.com/chandar-lab/Recall2Imagine/tree/main?tab=readme-ov-file#conda . If that also fails, I can provide the Docker image we used for our experiments.
Thank you for your reply.
However, I tried both approaches and neither worked: with conda, JAX was still the problem, and with Docker, the pip installation failed.
Hello,
I would like to report a serious issue I have encountered.
To provide context, my local environment is configured as follows:
- GPU: RTX 4090
- CUDA version: 12.5
- cuDNN version: 8.9
- OS: Ubuntu 22.04

When I tried to run the provided example in a Conda virtual environment, I encountered the following error:
Traceback (most recent call last):
  File "recall2imagine/train.py", line 241, in <module>
    main()
  File "recall2imagine/train.py", line 63, in main
    agent = agt.Agent(env.obs_space, env.act_space, step, config)
  File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/jaxagent.py", line 20, in __init__
    super().__init__(agent_cls, *args, **kwargs)
  File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/jaxagent.py", line 50, in __init__
    self.varibs = self._init_varibs(obs_space, act_space)
  File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/jaxagent.py", line 245, in _init_varibs
    state, varibs = self._init_train(varibs, rng, data['is_first'])
  File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/ninjax.py", line 199, in wrapper
    created = init(statics, rng, *args, **kw)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1120, in _pjit_call_impl_python
    compiled = _pjit_lower(
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.
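For an error like this one, the versions of the Python packages matter as much as the system CUDA/cuDNN versions, because pip-installed jaxlib CUDA wheels can pull in their own cuDNN wheel (for example `nvidia-cudnn-cu12`). The script below is a minimal sketch for collecting those versions in one place; it is not part of Recall2Imagine, the function name `collect_versions` is hypothetical, and it assumes `nvidia-smi` is on `PATH` when an NVIDIA driver is installed. Anything it cannot find is reported as `None` instead of raising.

```python
import importlib
import shutil
import subprocess
from importlib import metadata


def collect_versions():
    """Hypothetical helper: gather versions relevant to a cuDNN init failure.

    Every probe is optional; a missing package or a missing nvidia-smi
    is recorded as None rather than raising.
    """
    info = {}
    # jaxlib holds the compiled XLA/CUDA code, so record it next to jax.
    # Importing jax does not initialize the GPU backend, so this is safe.
    for name in ("jax", "jaxlib"):
        try:
            module = importlib.import_module(name)
            info[name] = getattr(module, "__version__", "unknown")
        except ImportError:
            info[name] = None
    # cuDNN wheels that pip may have installed alongside jaxlib.
    for dist in ("nvidia-cudnn-cu12", "nvidia-cudnn-cu11"):
        try:
            info[dist] = metadata.version(dist)
        except metadata.PackageNotFoundError:
            info[dist] = None
    # GPU name and driver version from nvidia-smi (assumed to be on PATH).
    info["gpu"] = None
    info["driver"] = None
    if shutil.which("nvidia-smi"):
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv,noheader"],
            capture_output=True, text=True, check=False,
        )
        if result.returncode == 0 and result.stdout.strip():
            first_gpu = result.stdout.strip().splitlines()[0]
            gpu_name, _, driver = first_gpu.partition(",")
            info["gpu"] = gpu_name.strip()
            info["driver"] = driver.strip()
    return info


if __name__ == "__main__":
    for key, value in collect_versions().items():
        print(f"{key}: {value}")
```

Pasting this output into an issue gives maintainers the jax/jaxlib/cuDNN combination directly, which is usually the first thing needed to diagnose a `DNN library initialization failed` error.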
Additionally, I observed that even a simple JAX operation (e.g., jnp.ones((3,))) fails with the same error in the Conda environment. Based on this, I believe the issue is likely a version incompatibility between JAX, CUDA, and cuDNN.
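The `jnp.ones((3,))` observation can be turned into a quick isolation test: run the same call once with JAX restricted to the CPU backend and once with the default backend. If only the default (GPU) run fails, the Python-side installation is probably fine and the fault lies in the CUDA/cuDNN libraries that jaxlib loads. This is a sketch under assumptions: it relies on the `JAX_PLATFORMS` environment variable, which JAX 0.4.x reads at import time, and the helper name `try_backend` is hypothetical.

```python
import os
import subprocess
import sys

# Same check as in the report: allocate a 3-element array and force
# execution so any backend/compilation error surfaces immediately.
SNIPPET = "import jax.numpy as jnp; jnp.ones((3,)).block_until_ready(); print('ok')"


def try_backend(platform=None):
    """Hypothetical helper: run SNIPPET in a fresh interpreter.

    A subprocess is used because JAX_PLATFORMS (assumed to be honored by
    the installed JAX version) must be set before jax is imported.
    Returns (succeeded, last line of stderr or "").
    """
    env = dict(os.environ)
    if platform is not None:
        env["JAX_PLATFORMS"] = platform
    result = subprocess.run(
        [sys.executable, "-c", SNIPPET],
        env=env, capture_output=True, text=True, check=False,
    )
    stderr = result.stderr.strip()
    last_err = stderr.splitlines()[-1] if stderr else ""
    return result.returncode == 0, last_err


if __name__ == "__main__":
    for platform in ("cpu", None):
        label = platform or "default (GPU if available)"
        ok, err = try_backend(platform)
        print(f"{label}: {'ok' if ok else 'FAILED'} {err}")
```

A CPU success combined with a default-backend failure points toward reinstalling a jaxlib CUDA build whose cuDNN requirement matches what is actually installed, rather than toward the Recall2Imagine code itself.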
I hope this matter can be resolved soon. Thank you for your attention to this issue.