Closed: Liuxueyi closed this issue 6 months ago.
Hi, @Liuxueyi!
Could you please give more context on your setup: OS, Python version, CUDA version, and how CUDA was installed? It looks like you have conflicting CUDA/cuDNN installation paths.
Also, can you show the output of the following (or at least tell us whether it runs or fails)?
import jax
m = jax.numpy.array([1,])
m@m
Finally, is using Docker or Singularity an option for you?
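Since the symptom points to conflicting cuDNN installation paths, one quick check is to list every libcudnn shared object reachable through LD_LIBRARY_PATH. This is a minimal sketch, not part of R2I or JAX: the helper name `cudnn_libs_on_path` is hypothetical, and it only inspects LD_LIBRARY_PATH, not the ldconfig cache or any copy jaxlib locates on its own.

```python
import os
from pathlib import Path
from typing import List, Optional


def cudnn_libs_on_path(ld_library_path: Optional[str] = None) -> List[str]:
    """List libcudnn*.so* files in the directories named by LD_LIBRARY_PATH.

    Hypothetical helper. Several different cuDNN versions in the output may
    mean the loader picks up a different copy than the one jaxlib expects.
    """
    if ld_library_path is None:
        ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
    hits: List[str] = []
    for entry in ld_library_path.split(os.pathsep):
        # Skip empty entries and directories that do not exist.
        if not entry or not Path(entry).is_dir():
            continue
        hits.extend(sorted(str(p) for p in Path(entry).glob("libcudnn*.so*")))
    return hits


for lib in cudnn_libs_on_path():
    print(lib)
```

Running it in both the base environment and the R2I conda environment would show whether an older cuDNN directory is on the search path in one of them.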
Thank you very much for your quick reply! Here is my setup: Ubuntu 20.04, Python 3.8, CUDA 11.8. CUDA and cuDNN were installed with jaxlib. I have also built the Docker image following your instructions and tried to run the code in Docker. I ran your example code, and it fails with the following error:
>>> import jax
>>> m = jax.numpy.array([1,])
>>> m@m
2024-03-27 00:14:27.594128: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:427] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/numpy/array_methods.py", line 258, in deferring_binary_op
return binary_op(*args)
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
return self.bind_with_trace(top_trace, args, params)
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/pjit.py", line 1120, in _pjit_call_impl_python
compiled = _pjit_lower(
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/numpy/array_methods.py", line 258, in deferring_binary_op
return binary_op(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
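The first log line states the rule XLA applies: the runtime cuDNN must have the same major version and an equal or higher minor version than the one jaxlib was compiled with. A minimal sketch that parses that log line and applies the rule (the function names are hypothetical, for illustration only):

```python
import re
from typing import Optional, Tuple

Version = Tuple[int, int, int]

# Matches the XLA message, e.g.
# "Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0."
_MISMATCH = re.compile(
    r"Loaded runtime CuDNN library: (\d+)\.(\d+)\.(\d+) "
    r"but source was compiled with: (\d+)\.(\d+)\.(\d+)"
)


def parse_cudnn_mismatch(log_line: str) -> Optional[Tuple[Version, Version]]:
    """Return (runtime_version, compiled_version), or None if the line does not match."""
    m = _MISMATCH.search(log_line)
    if m is None:
        return None
    n = [int(g) for g in m.groups()]
    return (n[0], n[1], n[2]), (n[3], n[4], n[5])


def is_compatible(runtime: Version, compiled: Version) -> bool:
    """Rule from the log: same major version, runtime minor >= compiled minor."""
    return runtime[0] == compiled[0] and runtime[1] >= compiled[1]


line = "Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0."
runtime, compiled = parse_cudnn_mismatch(line)
print(runtime, compiled, is_compatible(runtime, compiled))  # (8, 5, 0) (8, 6, 0) False
```

By this rule, 8.5.0 fails against 8.6.0 only because of the minor version, so any cuDNN 8.x with minor version 6 or higher would satisfy it.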
When I run the code in Docker, there is also a GPU-related error:
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
File "recall2imagine/train.py", line 241, in <module>
main()
File "recall2imagine/train.py", line 63, in main
agent = agt.Agent(env.obs_space, env.act_space, step, config)
File "/code/recall2imagine/jaxagent.py", line 20, in __init__
super().__init__(agent_cls, *args, **kwargs)
File "/code/recall2imagine/jaxagent.py", line 37, in __init__
available = jax.devices(self.config.platform)
File "/conda/miniconda/envs/py38/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 758, in devices
return get_backend(backend).devices()
File "/conda/miniconda/envs/py38/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 692, in get_backend
return _get_backend_uncached(platform)
File "/conda/miniconda/envs/py38/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 675, in _get_backend_uncached
platform = canonicalize_platform(platform)
File "/conda/miniconda/envs/py38/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 548, in canonicalize_platform
raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: cpu
Gracefully stopping...
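The traceback shows `jaxagent.py` calling `jax.devices(self.config.platform)` and getting a `RuntimeError` because only the CPU platform is present. A minimal sketch of a pre-flight check that reports whether JAX can see a GPU before the agent is constructed; `gpu_devices` is a hypothetical helper, not part of R2I or JAX, and it assumes `jax.devices("gpu")` raises `RuntimeError` when no GPU backend exists, as in this traceback.

```python
from typing import List


def gpu_devices() -> List[object]:
    """Return the GPU devices JAX can see, or [] if none (or if JAX is missing)."""
    try:
        import jax
    except ImportError:  # JAX is not installed in this environment
        return []
    try:
        # With only a CPU backend this raises
        # RuntimeError: Unknown backend: 'gpu' requested, but no platforms ...
        return list(jax.devices("gpu"))
    except RuntimeError:
        return []


devices = gpu_devices()
print(devices if devices else "No GPU visible to JAX")
```

Inside a container, an empty list is consistent with the container not having been given access to the host GPU.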
Could this be caused by a problem with the CUDA installation in my base environment?
Thank you for your patience.
@Liuxueyi, did you create the conda environment from scratch, or did you use an existing one? Also, if it's a new environment created just for R2I, did you follow the installation instructions in our repo? This is the link: https://github.com/chandar-lab/Recall2Imagine?tab=readme-ov-file#conda
Yes, I just followed the instructions and created a new conda environment for R2I. I suspect there may be a problem with the CUDA installation in my base environment. I will check it and share any progress. Thank you very much!
Sure. Will be happy to help.
@Liuxueyi, I am closing the issue for now since it's not related to our codebase. Please reopen it if you run into an issue related to the codebase.
When I add the --gpus all flag to the docker run command, the Docker error is resolved.
In the conda environment, I have installed jaxlib 0.4.13 built for cuda11.cudnn86. When I run the code, the following error occurs:
E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:427] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
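The log says the runtime cuDNN is 8.5.0, while a jaxlib built for cudnn86 needs at least 8.6. A minimal sketch to see which version a given libcudnn reports, using `ctypes` to call cuDNN's `cudnnGetVersion()`. Assumptions: `ctypes.util.find_library("cudnn")` uses the system loader search and may not find the same copy jaxlib loads (for example one installed by pip inside the conda env), so passing an explicit path is safer; the helper names are hypothetical.

```python
import ctypes
import ctypes.util
from typing import Optional, Tuple

Version = Tuple[int, int, int]


def decode_cudnn_version(v: int) -> Version:
    """Decode cudnnGetVersion(): major*1000 + minor*100 + patch up to cuDNN 8,
    major*10000 + minor*100 + patch from cuDNN 9 on."""
    if v >= 10000:
        return v // 10000, (v % 10000) // 100, v % 100
    return v // 1000, (v % 1000) // 100, v % 100


def runtime_cudnn_version(path: Optional[str] = None) -> Optional[Version]:
    """Load libcudnn via ctypes and query its version; None if it cannot be loaded."""
    name = path or ctypes.util.find_library("cudnn")
    if name is None:
        return None
    try:
        lib = ctypes.CDLL(name)
    except OSError:  # file missing or not loadable
        return None
    lib.cudnnGetVersion.restype = ctypes.c_size_t
    lib.cudnnGetVersion.argtypes = []
    return decode_cudnn_version(lib.cudnnGetVersion())


print(runtime_cudnn_version())
```

If this reports 8.5.x for the library on your search path, that copy would explain the mismatch in the log.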