jax-ml / jax


cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found #25075

Open PhilipVinc opened 5 days ago

PhilipVinc commented 5 days ago

Description

Installing a fresh copy of jax 0.4.35 with

pip install "jax[cuda12]==0.4.35" "nvidia-cudnn-cu12<9.4"

(which installs nvidia-cudnn-cu12==9.3.0.75) leads to a broken installation that fails with the error in the traceback below.

The same error also appears if you remove the constraint nvidia-cudnn-cu12<9.4 and install nvidia-cudnn-cu12==9.5.1.17.

The last version of jax that works correctly is 0.4.33.
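
To confirm which versions pip actually resolved in the environment, something like the following could be run. This is a minimal sketch using the standard library's importlib.metadata; the helper name installed_versions and the package tuple are illustrative assumptions, not part of jax.

```python
from importlib.metadata import PackageNotFoundError, version

# Package names taken from the pip command above; jaxlib is added because
# jax depends on it. This tuple is an assumption and can be extended.
PACKAGES = ("jax", "jaxlib", "nvidia-cudnn-cu12")


def installed_versions(packages=PACKAGES):
    """Return {package name: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

This only reads package metadata, so it does not import jax and does not touch the GPU.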

>>> jax.numpy.ones((3,4))
E1123 19:13:25.034265  814362 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
E1123 19:13:25.071507  814362 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 5949, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1615, in full
    fill_value = _convert_element_type(fill_value, dtype, weak_type)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 587, in _convert_element_type
    return convert_element_type_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 2981, in _convert_element_type_bind
    operand = core.Primitive.bind(convert_element_type_p, operand,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 438, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 955, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo/home/userexternal/fvicenti/test1/.venv/lib/python3.12/site-packages/jax/_src/dispatch.py", line 91, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
>>> 

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.0.2
python: 3.12.7 (main, Oct 16 2024, 04:37:19) [Clang 18.1.8 ]
device info: NVIDIA A100-SXM-64GB-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='lrdn3434.leonardo.local', release='4.18.0-425.19.2.el8_7.x86_64', version='#1 SMP Fri Mar 17 01:52:38 EDT 2023', machine='x86_64')

$ nvidia-smi
Sat Nov 23 19:13:17 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM-64GB            On | 00000000:1D:00.0 Off |                    0 |
| N/A   44C    P0               80W / 475W|    477MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM-64GB            On | 00000000:56:00.0 Off |                    0 |
| N/A   44C    P0               76W / 473W|    477MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM-64GB            On | 00000000:8F:00.0 Off |                    0 |
| N/A   44C    P0               73W / 453W|    477MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM-64GB            On | 00000000:C8:00.0 Off |                    0 |
| N/A   43C    P0               74W / 453W|    477MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A    814362      C   ...al/fvicenti/test1/.venv/bin/python3      474MiB |
|    1   N/A  N/A    814362      C   ...al/fvicenti/test1/.venv/bin/python3      474MiB |
|    2   N/A  N/A    814362      C   ...al/fvicenti/test1/.venv/bin/python3      474MiB |
|    3   N/A  N/A    814362      C   ...al/fvicenti/test1/.venv/bin/python3      474MiB |
+---------------------------------------------------------------------------------------+
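
The header line of the nvidia-smi output above carries the driver version and the highest CUDA version that driver supports. A minimal sketch for pulling those fields out of captured output follows; the function name parse_smi_header is hypothetical, and the regular expression assumes the header layout shown in this report.

```python
import re

# Matches the header line printed by nvidia-smi, as it appears in this report.
HEADER_RE = re.compile(
    r"NVIDIA-SMI\s+(?P<smi>[\d.]+)\s+"
    r"Driver Version:\s+(?P<driver>[\d.]+)\s+"
    r"CUDA Version:\s+(?P<cuda>[\d.]+)"
)


def parse_smi_header(text):
    """Return a dict with 'smi', 'driver' and 'cuda' versions, or None if absent."""
    match = HEADER_RE.search(text)
    return match.groupdict() if match else None


if __name__ == "__main__":
    # Header line copied from the output in this report.
    sample = (
        "| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02"
        "    CUDA Version: 12.1     |"
    )
    print(parse_smi_header(sample))
```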