jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Orthogonal Initializer raises gpusolverDnCreate(&handle) failed: cuSolver internal error #23616

brianorbrain opened this issue 2 months ago (status: Open)

brianorbrain commented 2 months ago

Description

I am having issues initializing a Flax.linen neural network when running with GPU support. I have narrowed the problem down to flax.linen.initializers.orthogonal. Running the code below results in: RuntimeError: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error

However, the same code runs fine in another venv with CPU-only support, and it also runs fine on GPU if I drop the orthogonal kernel initializer. JAX was installed with pip install -U "jax[cuda12]".

I have attached a minimal example that reproduces the issue.


import os

os.environ['JAX_TRACEBACK_FILTERING'] = 'off'  # show the full, unfiltered traceback

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import orthogonal

class SingleLayer(nn.Module):
    @nn.compact
    def __call__(self, x):
        # The orthogonal kernel initializer is what triggers the cuSolver error.
        layer = nn.Dense(64, kernel_init=orthogonal())(x)
        return layer

network = SingleLayer()
init_x = jnp.zeros(128)
network_params = network.init(rngs=jax.random.PRNGKey(0), x=init_x)
print(network_params)

/home/brain/Tensor/JaxRL/.venv/bin/python /home/brain/Tensor/JaxRL/flax_lax.py 
Traceback (most recent call last):
  File "/home/brain/Tensor/JaxRL/flax_lax.py", line 22, in <module>
    network_params = network.init(rngs=jax.random.PRNGKey(0), x=init_x)
  File "/home/brain/Tensor/JaxRL/flax_lax.py", line 16, in __call__
    layer = nn.Dense(64, kernel_init=orthogonal())(x)
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/linear.py", line 251, in __call__
    kernel = self.param(
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/nn/initializers.py", line 611, in init
    Q, R = jnp.linalg.qr(A)
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/numpy/linalg.py", line 1300, in qr
    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/brain/Tensor/JaxRL/flax_lax.py", line 22, in <module>
    network_params = network.init(rngs=jax.random.PRNGKey(0), x=init_x)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 2442, in init
    _, v_out = self.init_with_output(
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 2294, in init_with_output
    return init_with_output(
           ^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/core/scope.py", line 1144, in wrapper
    return apply(fn, mutable=mutable, flags=init_flags)(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/core/scope.py", line 1108, in wrapper
    y = fn(root, *args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 3081, in scope_fn
    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 1211, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/flax_lax.py", line 16, in __call__
    layer = nn.Dense(64, kernel_init=orthogonal())(x)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 1211, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/linear.py", line 251, in __call__
    kernel = self.param(
             ^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 1867, in param
    v = self.scope.param(name, init_fn, *init_args, unbox=unbox, **init_kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/core/scope.py", line 997, in param
    value = init_fn(self.make_rng('params'), *init_args, **init_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/nn/initializers.py", line 611, in init
    Q, R = jnp.linalg.qr(A)
           ^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 332, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **p.params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 2782, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 443, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 949, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1739, in _pjit_call_impl
    return xc._xla.pjit(
           ^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1721, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1643, in _pjit_call_impl_python
    compiled = _resolve_and_lower(
               ^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1610, in _resolve_and_lower
    lowered = _pjit_lower(
              ^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1748, in _pjit_lower
    return _pjit_lower_cached(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1769, in _pjit_lower_cached
    return pxla.lower_sharding_computation(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2230, in lower_sharding_computation
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1950, in _cached_lowering_to_hlo
    lowering_result = mlir.lower_jaxpr_to_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1132, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1590, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
                           ^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1805, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1921, in lower_per_platform
    output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 2036, in f_lowered
    out, tokens = jaxpr_subcomp(
                  ^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1805, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1921, in lower_per_platform
    output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/lax/linalg.py", line 1757, in _geqrf_cpu_gpu_lowering
    a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib64/python3.12/site-packages/jaxlib/gpu_solver.py", line 164, in _geqrf_hlo
    lwork, opaque = gpu_solver.build_geqrf_descriptor(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error

Thanks, Brian

System info (python version, jaxlib version, accelerator, etc.)

jax.print_environment_info()
jax:    0.4.32
jaxlib: 0.4.32
numpy:  2.1.1
python: 3.12.5 (main, Aug 23 2024, 00:00:00) [GCC 14.2.1 20240801 (Red Hat 14.2.1-1)]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='programming-desktop', release='6.10.8-200.fc40.x86_64', version='#1 SMP PREEMPT_DYNAMIC Wed Sep  4 21:41:11 UTC 2024', machine='x86_64')
$ nvidia-smi
Thu Sep 12 22:41:39 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0  On |                  N/A |
| 59%   57C    P0            180W /  390W |    1759MiB /  24576MiB |     34%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2921      G   /usr/libexec/Xorg                             456MiB |
|    0   N/A  N/A      3172    C+G   ...libexec/gnome-remote-desktop-daemon        258MiB |
|    0   N/A  N/A      3255      G   /usr/bin/gnome-shell                           76MiB |
|    0   N/A  N/A      4696      G   /usr/bin/nautilus                              24MiB |
|    0   N/A  N/A      4978      G   /usr/lib64/firefox/firefox                    190MiB |
|    0   N/A  N/A     64410      C   ...ensor/Sin_PPO_test/.venv/bin/python        366MiB |
|    0   N/A  N/A     98376      G   ...erProcess --variations-seed-version         16MiB |
|    0   N/A  N/A    111528      C   ...brain/Tensor/JaxRL/.venv/bin/python        256MiB |
+-----------------------------------------------------------------------------------------+
nvidia-cublas-cu12==12.6.1.4
nvidia-cuda-cupti-cu12==12.6.68
nvidia-cuda-nvcc-cu12==12.6.68
nvidia-cuda-runtime-cu12==12.6.68
nvidia-cudnn-cu12==9.4.0.58
nvidia-cufft-cu12==11.2.6.59
nvidia-cusolver-cu12==11.6.4.69
nvidia-cusparse-cu12==12.5.3.3
nvidia-nccl-cu12==2.23.4
nvidia-nvjitlink-cu12==12.6.68
dfm commented 2 months ago

Thanks for the report. I'm not too sure what the issue is here, but I'm happy to help dig into it. Two requests:

  1. Would you be able to put together an (even!) smaller example that doesn't depend on flax, i.e. just calling QR on an array you construct yourself, to make it easier to reproduce? See the sketch below for the kind of thing I mean.
  2. Can you try downgrading to jax/jaxlib 0.4.31? v0.4.32 got yanked because of TPU issues (unrelated to this conversation!), but it'll be easier for me to investigate if you still see the issue on v0.4.31.
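
Something along these lines (an untested sketch, just to illustrate):

import jax
import jax.numpy as jnp

# Construct a small matrix directly and run QR on it; this should exercise
# the same cuSolver geqrf path that the orthogonal initializer hits.
A = jax.random.normal(jax.random.PRNGKey(0), (3, 2), dtype=jnp.float32)
Q, R = jnp.linalg.qr(A)
print(Q)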
hawkinsp commented 2 months ago

If I had to guess, you're probably running out of GPU memory. Try lowering the preallocation fraction as described here: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
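
For example (these are the knobs described on that page; they must be set before JAX is imported, and the .50 value here is just an illustration):

import os

# Either disable preallocation entirely...
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# ...or cap it at a fraction of total GPU memory.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.50'

import jax  # import only after setting the variables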

brianorbrain commented 2 months ago

I have an even smaller example, pulled straight from the documentation: https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.initializers.orthogonal.html#jax.nn.initializers.orthogonal

I am now running jax 0.4.31 and still hit the exact same issue, even with the environment variable set to restrict preallocation to 0.25:

import os

os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.25'  # must be set before importing jax
import jax, jax.numpy as jnp

initializer = jax.nn.initializers.orthogonal()

initializer(jax.random.key(42), (2, 3), jnp.float32)

  File "/home/brain/Tensor/jax/jax_orthogonal.py", line 8, in <module>
    initializer(jax.random.key(42), (2, 3), jnp.float32)
  File "/home/brain/Tensor/jax/.venv/lib/python3.11/site-packages/jax/_src/nn/initializers.py", line 611, in init
    Q, R = jnp.linalg.qr(A)
  File "/home/brain/Tensor/jax/.venv/lib/python3.11/site-packages/jax/_src/numpy/linalg.py", line 1291, in qr
    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/brain/Tensor/jax/jax_orthogonal.py", line 8, in <module>
    initializer(jax.random.key(42), (2, 3), jnp.float32)
  File "/home/brain/Tensor/jax/.venv/lib/python3.11/site-packages/jax/_src/nn/initializers.py", line 611, in init
    Q, R = jnp.linalg.qr(A)
           ^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/jax/.venv/lib64/python3.11/site-packages/jaxlib/gpu_solver.py", line 156, in _geqrf_hlo
    lwork, opaque = gpu_solver.build_geqrf_descriptor(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
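
In case it's useful, here is a rough way I can check how much memory JAX actually reserved (assuming memory_stats() is supported by this jaxlib build; I believe it returns None otherwise):

import jax

dev = jax.local_devices()[0]
stats = dev.memory_stats()  # may be None if the backend doesn't expose stats
if stats is not None:
    # 'bytes_limit' should reflect the XLA_PYTHON_CLIENT_MEM_FRACTION cap
    print(stats.get('bytes_limit'), stats.get('bytes_in_use'))
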
jax:    0.4.31
jaxlib: 0.4.31
numpy:  2.1.1
python: 3.11.9 (main, Aug 23 2024, 00:00:00) [GCC 14.2.1 20240801 (Red Hat 14.2.1-1)]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='programming-desktop', release='6.10.8-200.fc40.x86_64', version='#1 SMP PREEMPT_DYNAMIC Wed Sep  4 21:41:11 UTC 2024', machine='x86_64')
$ nvidia-smi
Sun Sep 15 14:55:58 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0  On |                  N/A |
| 59%   50C    P8             58W /  390W |    2585MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2921      G   /usr/libexec/Xorg                             490MiB |
|    0   N/A  N/A      3172    C+G   ...libexec/gnome-remote-desktop-daemon        258MiB |
|    0   N/A  N/A      3255      G   /usr/bin/gnome-shell                          174MiB |
|    0   N/A  N/A      4696      G   /usr/bin/nautilus                              15MiB |
|    0   N/A  N/A     98376      G   ...erProcess --variations-seed-version         16MiB |
|    0   N/A  N/A    111528      C   ...brain/Tensor/JaxRL/.venv/bin/python        256MiB |
|    0   N/A  N/A    158963      G   ...local/share/Steam/ubuntu12_32/steam          6MiB |
|    0   N/A  N/A    159598      G   ./steamwebhelper                                5MiB |
|    0   N/A  N/A    176468      G   /usr/lib64/firefox/firefox                    176MiB |
|    0   N/A  N/A    301901      C   ...ensor/Sin_PPO_test/.venv/bin/python        810MiB |
|    0   N/A  N/A    304334      C   ...e/brain/Tensor/jax/.venv/bin/python        256MiB |
+-----------------------------------------------------------------------------------------+