chandar-lab / Recall2Imagine

Recall to Imagine, a model-based RL algorithm with superhuman memory. Oral (1.2%) @ ICLR 2024
https://recall2imagine.github.io/
MIT License

jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details. #12

Open jmSNU opened 1 month ago

jmSNU commented 1 month ago

Hello,

I would like to report a serious issue I have encountered.

To provide context, my local environment is configured as follows:

GPU: RTX 4090
CUDA version: 12.5
cuDNN version: 8.9
OS: Ubuntu 22.04

When I attempted to run the provided example in a Conda virtual environment, I encountered the following error.

Traceback (most recent call last):
  File "recall2imagine/train.py", line 241, in <module>
    main()
  File "recall2imagine/train.py", line 63, in main
    agent = agt.Agent(env.obs_space, env.act_space, step, config)
  File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/jaxagent.py", line 20, in __init__
    super().__init__(agent_cls, *args, **kwargs)
  File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/jaxagent.py", line 50, in __init__
    self.varibs = self._init_varibs(obs_space, act_space)
  File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/jaxagent.py", line 245, in _init_varibs
    state, varibs = self._init_train(varibs, rng, data['is_first'])
  File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/ninjax.py", line 199, in wrapper
    created = init(statics, rng, *args, **kw)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1120, in _pjit_call_impl_python
    compiled = _pjit_lower(
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.

Additionally, even a simple JAX operation (e.g., jnp.ones((3,))) produces the same error inside the Conda environment, so the failure is not specific to the Recall2Imagine code. I therefore suspect a version incompatibility between JAX, CUDA, and cuDNN.
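For anyone trying to reproduce this, a small report script may help narrow down where the mismatch is. The following is only a sketch: the function name jax_gpu_report is hypothetical, and the choice of which settings to print is an assumption about what is useful here. It records the installed jax and jaxlib versions, a few environment variables that affect GPU use (XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION are JAX's documented GPU memory settings), the visible devices, and the result of the same jnp.ones((3,)) smoke test.

```python
import importlib
import os


def jax_gpu_report():
    """Collect version and device details relevant to a DNN init failure."""
    report = {"devices": None}

    # Environment variables that influence how JAX/XLA finds and uses the GPU.
    for var in (
        "XLA_PYTHON_CLIENT_PREALLOCATE",
        "XLA_PYTHON_CLIENT_MEM_FRACTION",
        "CUDA_VISIBLE_DEVICES",
        "LD_LIBRARY_PATH",
    ):
        report[var] = os.environ.get(var)

    # Installed package versions; a jax/jaxlib mismatch is one thing to rule out.
    for name in ("jax", "jaxlib"):
        try:
            module = importlib.import_module(name)
            report[name] = getattr(module, "__version__", "unknown")
        except ImportError as err:
            report[name] = f"not importable: {err}"

    # The smoke test from the report above: a trivial array operation.
    try:
        import jax
        import jax.numpy as jnp

        report["devices"] = [str(d) for d in jax.devices()]
        report["smoke_test"] = str(jnp.ones((3,)).sum())
    except Exception as err:  # keep the error text instead of crashing
        report["smoke_test"] = f"failed: {type(err).__name__}: {err}"

    return report


if __name__ == "__main__":
    for key, value in jax_gpu_report().items():
        print(f"{key}: {value}")
```

Posting this output together with the output of nvidia-smi should make it easier to tell whether the installed jaxlib build matches the local CUDA and cuDNN versions.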

I hope this matter can be resolved soon. Thank you for your attention to this issue.

artemZholus commented 1 month ago

Hi @jmSNU!

I think your issue is caused by a broken JAX installation. I recommend following these steps https://github.com/chandar-lab/Recall2Imagine/tree/main?tab=readme-ov-file#conda to create a fresh conda environment. If that does not work, try building a Docker image by following the steps here https://github.com/chandar-lab/Recall2Imagine/tree/main?tab=readme-ov-file#conda . If that does not work either, I can provide you with the Docker image we used for our experiments.

jmSNU commented 1 month ago

Thank you for your reply.

However, I tried both approaches and both failed. With conda, JAX was still the problem; with Docker, the pip installation failed.