Description
Installing a fresh copy of jax 0.4.35 (this installs nvidia-cudnn-cu12==9.3.0.75) leads to a broken installation that fails with the error below.
The same error also appears if you remove the constraint nvidia-cudnn-cu12<9.4 and install nvidia-cudnn-cu12==9.5.1.17.
The last version of jax that works correctly is 0.4.33.
>>> jax.numpy.ones((3,4))
E1123 19:13:25.034265 814362 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
E1123 19:13:25.071507 814362 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 5949, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1615, in full
    fill_value = _convert_element_type(fill_value, dtype, weak_type)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 587, in _convert_element_type
    return convert_element_type_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 2981, in _convert_element_type_bind
    operand = core.Primitive.bind(convert_element_type_p, operand,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 438, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 955, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/dispatch.py", line 91, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
>>>
System info (python version, jaxlib version, accelerator, etc.)
>>> import jax; jax.print_environment_info()
jax: 0.4.35
jaxlib: 0.4.34
numpy: 2.0.2
python: 3.12.7 (main, Oct 16 2024, 04:37:19) [Clang 18.1.8 ]
device info: NVIDIA A100-SXM-64GB-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='lrdn3434.leonardo.local', release='4.18.0-425.19.2.el8_7.x86_64', version='#1 SMP Fri Mar 17 01:52:38 EDT 2023', machine='x86_64')
$ nvidia-smi
Sat Nov 23 19:13:17 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM-64GB            On | 00000000:1D:00.0 Off |                    0 |
| N/A   44C    P0               80W / 475W|    477MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM-64GB            On | 00000000:56:00.0 Off |                    0 |
| N/A   44C    P0               76W / 473W|    477MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM-64GB            On | 00000000:8F:00.0 Off |                    0 |
| N/A   44C    P0               73W / 453W|    477MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM-64GB            On | 00000000:C8:00.0 Off |                    0 |
| N/A   43C    P0               74W / 453W|    477MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A    814362      C   ...al/fvicenti/test1/.venv/bin/python3      474MiB |
|    1   N/A  N/A    814362      C   ...al/fvicenti/test1/.venv/bin/python3      474MiB |
|    2   N/A  N/A    814362      C   ...al/fvicenti/test1/.venv/bin/python3      474MiB |
|    3   N/A  N/A    814362      C   ...al/fvicenti/test1/.venv/bin/python3      474MiB |
+---------------------------------------------------------------------------------------+
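The output of jax.print_environment_info() does not show which cuDNN wheel ended up in the virtual environment, so for anyone trying to reproduce this, here is a minimal sketch that reports the installed versions of the packages mentioned in this report. It uses only the standard library's importlib.metadata; the helper name installed_versions and the PACKAGES tuple are hypothetical names chosen for illustration and are not part of jax.

```python
from importlib import metadata

# Distribution names taken from this report; extend the tuple as needed.
PACKAGES = ("jax", "jaxlib", "nvidia-cudnn-cu12")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent.

    Hypothetical helper for illustration only; it inspects package metadata
    and does not import jax or touch the GPU.
    """
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running this inside the affected .venv should make it easy to confirm whether the environment really contains nvidia-cudnn-cu12 9.3.0.75 or 9.5.1.17 when the error appears.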