chandar-lab / Recall2Imagine

Recall to Imagine, a model-based RL algorithm with superhuman memory. Oral (1.2%) @ ICLR 2024
https://recall2imagine.github.io/
MIT License

CuDNN library incompatible error #1

Closed · Liuxueyi closed this issue 6 months ago

Liuxueyi commented 6 months ago

I used the conda environment and installed jaxlib 0.4.13 (the cuda11.cudnn86 build). When I run the code, the following error occurs:

E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:427] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
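The compatibility rule stated in the log (same major version, runtime minor version equal to or higher than the compiled one) can be written as a small check. This is a minimal sketch that encodes only the rule as quoted in the message; `parse_version` and `cudnn_compatible` are hypothetical helpers, not part of JAX or XLA.

```python
def parse_version(text):
    """Turn a dotted version string such as '8.5.0' into a tuple of ints."""
    return tuple(int(part) for part in text.strip().split("."))


def cudnn_compatible(runtime, compiled):
    """Apply the rule from the XLA log message: the runtime cuDNN must have the
    same major version and an equal or higher minor version than the version
    the binary was compiled against."""
    rt_major, rt_minor = parse_version(runtime)[:2]
    c_major, c_minor = parse_version(compiled)[:2]
    return rt_major == c_major and rt_minor >= c_minor


if __name__ == "__main__":
    # Versions from the error above: loaded 8.5.0, compiled with 8.6.0.
    print(cudnn_compatible("8.5.0", "8.6.0"))  # False: minor 5 < 6
    print(cudnn_compatible("8.9.2", "8.6.0"))  # True: same major, newer minor
```

Here the loaded 8.5.0 fails the check against 8.6.0, so a cuDNN 8.x release with minor version 6 or higher has to be the one found at runtime.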

artemZholus commented 6 months ago

Hi, @Liuxueyi !

Could you please give more context on your setup: OS, Python version, CUDA version, and how CUDA was installed? It looks like an issue with conflicting CUDA/cuDNN installation paths.

Also, can you show the output of the following?

import jax
m = jax.numpy.array([1,])
m@m

(Or at least say whether it runs or fails.)

Finally, is using Docker or Singularity an option for you?
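To test the conflicting-installation-paths theory directly, one can list every libcudnn copy on LD_LIBRARY_PATH and ask the copy the dynamic loader actually resolves for its version. This is a minimal sketch, not part of R2I or JAX: the helper names are hypothetical, it assumes Linux and the cuDNN 8 soname libcudnn.so.8, and XLA's own loader may consult additional search paths, so treat the result as a strong hint rather than proof.

```python
import ctypes
import glob
import os


def find_cudnn_copies(search_path=None):
    """List libcudnn shared objects found in each directory of a
    colon-separated search path (LD_LIBRARY_PATH by default), in search order."""
    if search_path is None:
        search_path = os.environ.get("LD_LIBRARY_PATH", "")
    hits = []
    for directory in search_path.split(":"):
        if directory:  # skip empty entries such as a trailing colon
            hits.extend(sorted(glob.glob(os.path.join(directory, "libcudnn.so*"))))
    return hits


def decode_cudnn_version(code):
    """Split a cuDNN 8.x version code (major*1000 + minor*100 + patch)."""
    return code // 1000, (code % 1000) // 100, code % 100


def loaded_cudnn_version(soname="libcudnn.so.8"):
    """dlopen `soname` as the dynamic loader resolves it and return the
    (major, minor, patch) it reports, or None if it cannot be loaded."""
    try:
        lib = ctypes.CDLL(soname)
    except OSError:
        return None
    lib.cudnnGetVersion.restype = ctypes.c_size_t  # size_t cudnnGetVersion(void)
    return decode_cudnn_version(lib.cudnnGetVersion())


if __name__ == "__main__":
    for path in find_cudnn_copies():
        print("on LD_LIBRARY_PATH:", path)
    print("libcudnn.so.8 resolves to version:", loaded_cudnn_version())
```

If the loader reports (8, 5, 0) while jaxlib expects 8.6, the directory holding the older copy that comes first in the search order is the likely culprit.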

Liuxueyi commented 6 months ago

Thank you very much for your quick reply! Here is my setup: Ubuntu 20.04, Python 3.8, CUDA 11.8. CUDA and cuDNN were installed together with jaxlib. I have also built the Docker image following your instructions and tried to run the code inside Docker. I ran your example code, and the error is:

>>> import jax
>>> m = jax.numpy.array([1,])
>>> m@m
2024-03-27 00:14:27.594128: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:427] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/numpy/array_methods.py", line 258, in deferring_binary_op
    return binary_op(*args)
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/pjit.py", line 1120, in _pjit_call_impl_python
    compiled = _pjit_lower(
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/lxy/anaconda3/envs/r2i/lib/python3.8/site-packages/jax/_src/numpy/array_methods.py", line 258, in deferring_binary_op
    return binary_op(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
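Because CUDA and cuDNN came in together with jaxlib in this setup, listing what pip actually installed is a quick way to answer the setup question. A minimal sketch using the standard-library importlib.metadata (Python 3.8+); `installed_versions` is a hypothetical helper, and the nvidia-*-cu11 distribution names are an assumption about a pip-wheel CUDA install that simply show up as "not installed" otherwise.

```python
from importlib import metadata


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    packages = [
        "jax",
        "jaxlib",
        # Assumed names of the NVIDIA pip wheels; absent for a system CUDA install.
        "nvidia-cudnn-cu11",
        "nvidia-cuda-runtime-cu11",
    ]
    for name, version in installed_versions(packages).items():
        print(f"{name:28s} {version or 'not installed'}")
```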

Liuxueyi commented 6 months ago

When I use Docker to run the code, there is also a GPU-related error:

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "recall2imagine/train.py", line 241, in <module>
    main()
  File "recall2imagine/train.py", line 63, in main
    agent = agt.Agent(env.obs_space, env.act_space, step, config)
  File "/code/recall2imagine/jaxagent.py", line 20, in __init__
    super().__init__(agent_cls, *args, **kwargs)
  File "/code/recall2imagine/jaxagent.py", line 37, in __init__
    available = jax.devices(self.config.platform)
  File "/conda/miniconda/envs/py38/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 758, in devices
    return get_backend(backend).devices()
  File "/conda/miniconda/envs/py38/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 692, in get_backend
    return _get_backend_uncached(platform)
  File "/conda/miniconda/envs/py38/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 675, in _get_backend_uncached
    platform = canonicalize_platform(platform)
  File "/conda/miniconda/envs/py38/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 548, in canonicalize_platform
    raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: cpu
Gracefully stopping...

Could there be a problem with the CUDA installation in my base environment? Thank you for your patience.
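Rather than letting the agent fail deep inside jax.devices, the platform check could fail early with a readable hint. A minimal sketch; `check_platform` and `platform_hint` are hypothetical helpers, not part of the R2I code, whose jaxagent.py calls jax.devices(self.config.platform) directly according to the traceback. Only jax.devices and jax.default_backend from JAX are used.

```python
def platform_hint(requested, default_backend):
    """Build a readable message for a JAX platform that was requested but not found."""
    msg = (f"JAX found no '{requested}' devices; it is using the "
           f"'{default_backend}' backend instead.")
    if requested == "gpu":
        # In this thread, the Docker variant of this error went away once the
        # container was started with --gpus all.
        msg += (" If this runs inside Docker, check that the container was "
                "started with GPU access (for example, docker run --gpus all).")
    return msg


def check_platform(requested):
    """Return jax.devices(requested), re-raising a missing platform with a hint."""
    import jax  # imported lazily so platform_hint stays usable without JAX

    try:
        return jax.devices(requested)
    except RuntimeError as err:
        # jax.devices raises RuntimeError("Unknown backend: ...") in this case,
        # as in the traceback above.
        raise RuntimeError(platform_hint(requested, jax.default_backend())) from err
```

A caller could then use check_platform("gpu") in place of a bare jax.devices("gpu"), so a container started without GPU access stops with a pointer to the likely cause.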

artemZholus commented 6 months ago

@Liuxueyi, did you create the conda environment from scratch, or did you use an existing one? Also, if it's a new environment created just for R2I, did you follow the installation instructions in our repo? Here is the link: https://github.com/chandar-lab/Recall2Imagine?tab=readme-ov-file#conda

Liuxueyi commented 6 months ago

Yes, I followed the instructions and created a new conda environment for R2I. I think there may be a problem with CUDA in my base environment. I will check it and share any progress. Thank you very much!

artemZholus commented 6 months ago

Sure. Will be happy to help.

artemZholus commented 6 months ago

@Liuxueyi, I am closing the issue for now since it's not related to our codebase. Please reopen it if you run into a problem that is related to the codebase.

Liuxueyi commented 6 months ago

When I add the --gpus all option, the Docker error is resolved.
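For reference, the fix amounts to starting the container with GPU access. A minimal Python sketch that only assembles the command line; `docker_gpu_command` is a hypothetical helper, the image tag r2i is a placeholder, and the exact command from the R2I Docker instructions may include more (mounts, environment variables). Running jax.devices() inside the container is a quick way to confirm the GPU is visible before starting training.

```python
import shlex


def docker_gpu_command(image, command, gpus="all"):
    """Build a `docker run` argument list that passes host GPUs into the
    container with --gpus (the host needs the NVIDIA Container Toolkit)."""
    return ["docker", "run", "--rm", "--gpus", gpus, image, *command]


if __name__ == "__main__":
    # "r2i" is a placeholder tag; substitute the image built following the repo's Docker instructions.
    cmd = docker_gpu_command("r2i", ["python", "-c", "import jax; print(jax.devices())"])
    print(shlex.join(cmd))
    # To launch it for real: subprocess.run(cmd, check=True)
```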