jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

pip installation: GPU (CUDA, installed via pip) not working for me with Brax #15508

Closed ViktorM closed 1 year ago

ViktorM commented 1 year ago

Description

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

A JAX installation using the command above does not work with the latest Brax: https://github.com/google/brax

For some reason it complains about the runtime CuDNN version:

2023-04-09 22:29:02.848922: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:417] Loaded runtime CuDNN library: 8.3.2 but source was compiled with: 8.8.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
Traceback (most recent call last):
  File "/home/viktorm/Projects/rl_games/runner.py", line 68, in <module>
    runner.run(args)
  File "/home/viktorm/Projects/rl_games/rl_games/torch_runner.py", line 119, in run
    self.run_train(args)
  File "/home/viktorm/Projects/rl_games/rl_games/torch_runner.py", line 97, in run_train
    agent = self.algo_factory.create(self.algo_name, base_name='run', params=self.params)
  File "/home/viktorm/Projects/rl_games/rl_games/common/object_factory.py", line 15, in create
    return builder(**kwargs)
  File "/home/viktorm/Projects/rl_games/rl_games/torch_runner.py", line 36, in <lambda>
    self.algo_factory.register_builder('a2c_continuous', lambda **kwargs : a2c_continuous.A2CAgent(**kwargs))
  File "/home/viktorm/Projects/rl_games/rl_games/algos_torch/a2c_continuous.py", line 16, in __init__
    a2c_common.ContinuousA2CBase.__init__(self, base_name, params)
  File "/home/viktorm/Projects/rl_games/rl_games/common/a2c_common.py", line 1056, in __init__
    A2CBase.__init__(self, base_name, params)
  File "/home/viktorm/Projects/rl_games/rl_games/common/a2c_common.py", line 119, in __init__
    self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config)
  File "/home/viktorm/Projects/rl_games/rl_games/common/vecenv.py", line 222, in create_vec_env
    return vecenv_config[vec_env_name](config_name, num_actors, **kwargs)
  File "/home/viktorm/Projects/rl_games/rl_games/common/vecenv.py", line 227, in <lambda>
    register('BRAX', lambda config_name, num_actors, **kwargs: BraxEnv(config_name, num_actors, **kwargs))
  File "/home/viktorm/Projects/rl_games/rl_games/envs/brax.py", line 43, in __init__
    self.env = envs.create(env_name=self.env_name, batch_size=self.num_envs, backend=self.sim_backend)
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/brax/envs/__init__.py", line 95, in create
    env = _envs[env_name](**kwargs)
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/brax/envs/ant.py", line 190, in __init__
    sys = mjcf.load(path)
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/brax/io/mjcf.py", line 512, in load
    return load_model(mj)
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/brax/io/mjcf.py", line 384, in load_model
    geoms = [
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/brax/io/mjcf.py", line 385, in <listcomp>
    jax.tree_map(lambda *x: jp.stack(x), *g) for g in geom_groups.values()
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/tree_util.py", line 210, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/brax/io/mjcf.py", line 385, in <lambda>
    jax.tree_map(lambda *x: jp.stack(x), *g) for g in geom_groups.values()
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1715, in stack
    new_arrays.append(expand_dims(a, axis))
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 872, in expand_dims
    return lax.expand_dims(a, axis)
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1313, in expand_dims
    return broadcast_in_dim(array, result_shape, broadcast_dims)
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 784, in broadcast_in_dim
    return broadcast_in_dim_p.bind(
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/core.py", line 360, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/core.py", line 817, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/util.py", line 253, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/util.py", line 246, in cached
    return f(*args, **kwargs)
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
    return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/viktorm/anaconda3/envs/warp39/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

What jax/jaxlib version are you using?

jax 0.4.8

Which accelerator(s) are you using?

RTX 4090

Additional system info

Python 3.9, Ubuntu 22.04

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                Persistence-M  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:01:00.0 Off |                  Off |
| 30%   44C    P8              29W / 450W |  1535MiB / 24564MiB  |      7%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                             GPU Memory |
|        ID   ID                                                              Usage      |
|=======================================================================================|
|    0   N/A  N/A   1285367      G   ...viktorm/Downloads/Telegram/Telegram        72MiB |
|    0   N/A  N/A   1288565      G   /usr/lib/xorg/Xorg                           861MiB |
|    0   N/A  N/A   1288676      G   /usr/bin/gnome-shell                          99MiB |
|    0   N/A  N/A   1289827      G   ...AAAAAAAACAAAAAAAAAA= --shared-files        86MiB |
|    0   N/A  N/A   1290862      G   ...893162563,891623453882694089,131072       412MiB |
+---------------------------------------------------------------------------------------+

hawkinsp commented 1 year ago

Well, the message "Loaded runtime CuDNN library: 8.3.2" tells you the problem: JAX tried to load CuDNN, but found a really old version (8.3) when it did.

Usually this means that you have already loaded an older CuDNN into your process (e.g., importing PyTorch before importing JAX is one common way this can happen, since PyTorch usually bundles an older CuDNN). The other way it can happen is that an older CuDNN is first in your LD_LIBRARY_PATH, but I think that's unlikely to be the case here since you used the pip installation of JAX.
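As a hedged illustration of that first failure mode (the cuDNN versions are taken from the log above; the exact behavior depends on your PyTorch build), the import order alone can reproduce it:

# Illustrative sketch: whichever library loads libcudnn first wins for the whole process.
import torch                           # a GPU build of PyTorch loads its bundled cuDNN
print(torch.backends.cudnn.version())  # e.g. 8302, i.e. cuDNN 8.3.2

import jax.numpy as jnp                # jaxlib here was compiled against cuDNN 8.8
jnp.zeros(1)                           # first GPU op fails with "DNN library initialization failed"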

Try searching for other CuDNN installations (files named *cudnn*) on your system?
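For example, a small Python sketch (assuming a Linux setup like the one above) that lists cuDNN libraries visible in site-packages and on LD_LIBRARY_PATH:

# Hypothetical helper: walk likely locations and print any file whose name contains "cudnn".
import os
import site

search_roots = site.getsitepackages() + os.environ.get("LD_LIBRARY_PATH", "").split(":")
for root in filter(os.path.isdir, search_roots):   # skips empty/missing entries
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            if "cudnn" in name.lower():
                print(os.path.join(dirpath, name))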

This is probably not something we can fix from the JAX end.

chaojiewang94 commented 1 year ago

Same issue. It seems to be because the current version of PyTorch does not support cuDNN 8.8 or higher.

hawkinsp commented 1 year ago

@chaojiewang94 You might do well to install a CPU-only version of PyTorch, if the goal is to use that in the context of a GPU-using JAX program.
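
For reference (not from this thread; double-check against the current instructions on pytorch.org), a CPU-only PyTorch wheel can typically be installed with:

pip install torch --index-url https://download.pytorch.org/whl/cpu

A CPU-only build does not bundle a GPU cuDNN, so importing it no longer pulls an old libcudnn into the process, and the cuDNN 8.8 from the JAX pip install is the one that gets loaded.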