jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

CUDA_ERROR_SYSTEM_NOT_READY: system not yet initialized #24866

Open carlosgmartin opened 1 week ago

carlosgmartin commented 1 week ago

Description

I used

python3 -m pip install --upgrade "jax[cuda12]"

to install JAX on a GPU node, but I am getting a CUDA_ERROR_SYSTEM_NOT_READY error:

(base) $ python3 -c "import jax; jax.numpy.array(0)"
2024-11-12 15:37:25.005059: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_SYSTEM_NOT_READY: system not yet initialized
Traceback (most recent call last):
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 896, in backends
    backend = _init_backend(platform)
              ^^^^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 982, in _init_backend
    backend = registration.factory()
              ^^^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 674, in factory
    return xla_client.make_c_api_client(plugin_name, updated_options, None)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jaxlib/xla_client.py", line 200, in make_c_api_client
    return _xla.get_c_api_client(plugin_name, options, distributed_client)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: No visible GPU devices.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5426, in array
    out_array: Array = lax_internal._convert_element_type(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 587, in _convert_element_type
    return convert_element_type_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 2981, in _convert_element_type_bind
    operand = core.Primitive.bind(convert_element_type_p, operand,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/core.py", line 438, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/core.py", line 955, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/dispatch.py", line 91, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
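
The error message above suggests setting JAX_PLATFORMS=cpu to skip the failing backend. As a rough sketch of how that could be used to check whether the JAX install itself works apart from CUDA, the script below re-runs a one-liner in a child process with only the CPU backend enabled. The helper names `cpu_only_env` and `run_on_cpu` are hypothetical and not part of JAX; the only JAX-specific pieces assumed are the `JAX_PLATFORMS` variable named in the message and `jax.devices()`.

```python
import os
import subprocess
import sys


def cpu_only_env(base_env=None):
    """Return a copy of the environment with JAX restricted to the CPU backend.

    Hypothetical helper: it only sets the JAX_PLATFORMS variable that the
    error message in this report mentions.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["JAX_PLATFORMS"] = "cpu"
    return env


def run_on_cpu():
    """Run a small JAX one-liner in a child process using the CPU-only environment."""
    cmd = [sys.executable, "-c", "import jax; print(jax.devices())"]
    return subprocess.run(cmd, env=cpu_only_env(), capture_output=True, text=True)


if __name__ == "__main__":
    result = run_on_cpu()
    # Show whatever the child process printed, including any traceback.
    print(result.stdout or result.stderr)
```

If this prints a CPU device list, the Python-side install is probably fine and the failure is limited to initializing the CUDA backend. This only sidesteps the GPU, it does not fix it.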

Here's some additional output:

(base) $ echo $CUDA_VISIBLE_DEVICES
0
(base) $ echo $LD_LIBRARY_PATH

(base) $ nvidia-smi
Tue Nov 12 15:38:53 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:07:00.0 Off |                    0 |
| N/A   25C    P0             43W /  400W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

System info (python version, jaxlib version, accelerator, etc.)

2024-11-12 15:36:48.401160: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_SYSTEM_NOT_READY: system not yet initialized
Traceback (most recent call last):
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 896, in backends
    backend = _init_backend(platform)
              ^^^^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 982, in _init_backend
    backend = registration.factory()
              ^^^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 674, in factory
    return xla_client.make_c_api_client(plugin_name, updated_options, None)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jaxlib/xla_client.py", line 200, in make_c_api_client
    return _xla.get_c_api_client(plugin_name, options, distributed_client)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: No visible GPU devices.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/environment_info.py", line 49, in print_environment_info
    device info: {xb.devices()[0].device_kind}-{xb.device_count()}, {xb.local_device_count()} local devices"
                  ^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 1094, in devices
    return get_backend(backend).devices()
           ^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 1028, in get_backend
    return _get_backend_uncached(platform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 1007, in _get_backend_uncached
    bs = backends()
         ^^^^^^^^^^
  File "/marvel/home/cgmartin/miniforge3/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 912, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

carlosgmartin commented 1 week ago

I also tried uninstalling all existing nvidia and jax-cuda packages before re-installing JAX:

$ python3 -m pip freeze --all | grep -e nvidia -e jax-cuda | xargs python3 -m pip uninstall -y jax jaxlib
...
$ conda list | awk '{ print $1 }' | grep -e nvidia -e jax-cuda | xargs conda remove -y jax jaxlib
...
$ mamba list | awk '{ print $1 }' | grep -e nvidia -e jax-cuda | xargs mamba remove -y jax jaxlib
...
$ python3 -m pip install --upgrade "jax[cuda12]"
...
$ echo $CUDA_VISIBLE_DEVICES
0,1,2,3,4,5,6,7
$ echo $LD_LIBRARY_PATH

$ python3 -c "import jax; jax.numpy.array(0)"
RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

but I still get the same error message.
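
Since the log shows the failure coming from cuInit while nvidia-smi still lists the GPU, one way to check whether the problem is independent of JAX is to call cuInit directly through the driver library. The following is a minimal sketch under some assumptions: the driver library is loadable as `libcuda.so.1` (the usual name on Linux), the function name `probe_cuinit` is hypothetical, and the numeric value 802 for CUDA_ERROR_SYSTEM_NOT_READY is taken from the CUDA driver API header.

```python
import ctypes

# Result codes from the CUDA driver API header (cuda.h).
CUDA_SUCCESS = 0
CUDA_ERROR_SYSTEM_NOT_READY = 802


def probe_cuinit(lib_name="libcuda.so.1"):
    """Call cuInit(0) directly, bypassing JAX, and return (code, name).

    Returns (None, message) if the driver library cannot be loaded.
    The default library name is an assumption about a typical Linux setup.
    """
    try:
        libcuda = ctypes.CDLL(lib_name)
    except OSError as exc:
        return None, f"could not load {lib_name}: {exc}"

    libcuda.cuInit.argtypes = [ctypes.c_uint]
    libcuda.cuInit.restype = ctypes.c_int
    code = libcuda.cuInit(0)

    # Translate the numeric result into its symbolic name.
    libcuda.cuGetErrorName.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_char_p)]
    libcuda.cuGetErrorName.restype = ctypes.c_int
    name = ctypes.c_char_p()
    if libcuda.cuGetErrorName(code, ctypes.byref(name)) == CUDA_SUCCESS and name.value:
        return code, name.value.decode()
    return code, "unknown"


if __name__ == "__main__":
    code, name = probe_cuinit()
    print(f"cuInit result: {code} ({name})")
    if code == CUDA_ERROR_SYSTEM_NOT_READY:
        print("The driver reports the same error outside of JAX.")
```

If this also reports CUDA_ERROR_SYSTEM_NOT_READY, the issue likely sits at the driver or system level rather than in the JAX or pip-installed CUDA packages. The CUDA documentation describes this code as the system not being ready for CUDA work and advises verifying that the required driver daemons are running.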