Open brianorbrain opened 2 months ago
Thanks for the report. I'm not too sure what the issue is here, but I'm happy to help dig into it. 2 requests:
If I had to guess, you're probably running out of GPU memory. Could you try lowering the preallocation fraction as described here: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
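For reference, the settings on that page are environment variables that need to be in place before JAX initializes its GPU backend. A minimal sketch, assuming the documented variables `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION`; the helper name `configure_gpu_memory` is made up for illustration and is not part of JAX:

```python
import os


def configure_gpu_memory(fraction=None, preallocate=True, env=os.environ):
    """Set JAX GPU memory options; call this before the first `import jax`.

    `configure_gpu_memory` is a hypothetical helper, not part of JAX.
    """
    if not preallocate:
        # Allocate on demand instead of reserving a large block up front.
        env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    if fraction is not None:
        if not 0.0 < fraction <= 1.0:
            raise ValueError("fraction must be in (0, 1]")
        # Share of total GPU memory that JAX preallocates.
        env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"
    return env


configure_gpu_memory(fraction=0.25)
# import jax  # import only after the variables are set
```

Setting the variables after `import jax` is likely to have no effect, which is why the import is left as the last step.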
I have an even smaller example. Pulled it right from the documentation. https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.initializers.orthogonal.html#jax.nn.initializers.orthogonal
I am running this on jax 0.4.31 and still hit the exact same issue, even with the environment variable set to restrict preallocation to 0.25.
import os
# Must be set before jax is imported to take effect.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.25'
import jax, jax.numpy as jnp
initializer = jax.nn.initializers.orthogonal()
initializer(jax.random.key(42), (2, 3), jnp.float32)
File "/home/brain/Tensor/jax/jax_orthogonal.py", line 8, in <module>
initializer(jax.random.key(42), (2, 3), jnp.float32)
File "/home/brain/Tensor/jax/.venv/lib/python3.11/site-packages/jax/_src/nn/initializers.py", line 611, in init
Q, R = jnp.linalg.qr(A)
File "/home/brain/Tensor/jax/.venv/lib/python3.11/site-packages/jax/_src/numpy/linalg.py", line 1291, in qr
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/brain/Tensor/jax/jax_orthogonal.py", line 8, in <module>
initializer(jax.random.key(42), (2, 3), jnp.float32)
File "/home/brain/Tensor/jax/.venv/lib/python3.11/site-packages/jax/_src/nn/initializers.py", line 611, in init
Q, R = jnp.linalg.qr(A)
^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/jax/.venv/lib64/python3.11/site-packages/jaxlib/gpu_solver.py", line 156, in _geqrf_hlo
lwork, opaque = gpu_solver.build_geqrf_descriptor(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
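Both tracebacks end in `jnp.linalg.qr`, so the initializer itself may not be needed to trigger the failure. A minimal sketch, assuming a random matrix of the same (2, 3) shape and the same key as above, that calls the QR decomposition directly; on a working backend it should run without error and reconstruct the input:

```python
import jax
import jax.numpy as jnp

# Same shape and PRNG key as the orthogonal-initializer example.
a = jax.random.normal(jax.random.key(42), (2, 3), jnp.float32)

# This is the call the traceback points at inside the initializer.
q, r = jnp.linalg.qr(a)

# Sanity check: Q @ R should reproduce A up to float32 rounding.
print(jax.default_backend(), bool(jnp.allclose(q @ r, a, atol=1e-5)))
```

If this fails on GPU with the same `gpusolverDnCreate` error, the problem would appear to sit in the cuSolver setup rather than in Flax or the initializer.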
jax: 0.4.31
jaxlib: 0.4.31
numpy: 2.1.1
python: 3.11.9 (main, Aug 23 2024, 00:00:00) [GCC 14.2.1 20240801 (Red Hat 14.2.1-1)]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='programming-desktop', release='6.10.8-200.fc40.x86_64', version='#1 SMP PREEMPT_DYNAMIC Wed Sep 4 21:41:11 UTC 2024', machine='x86_64')
$ nvidia-smi
Sun Sep 15 14:55:58 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0  On |                  N/A |
| 59%   50C    P8             58W /  390W |    2585MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2921      G   /usr/libexec/Xorg                             490MiB |
|    0   N/A  N/A      3172    C+G   ...libexec/gnome-remote-desktop-daemon        258MiB |
|    0   N/A  N/A      3255      G   /usr/bin/gnome-shell                          174MiB |
|    0   N/A  N/A      4696      G   /usr/bin/nautilus                              15MiB |
|    0   N/A  N/A     98376      G   ...erProcess --variations-seed-version         16MiB |
|    0   N/A  N/A    111528      C   ...brain/Tensor/JaxRL/.venv/bin/python        256MiB |
|    0   N/A  N/A    158963      G   ...local/share/Steam/ubuntu12_32/steam          6MiB |
|    0   N/A  N/A    159598      G   ./steamwebhelper                                5MiB |
|    0   N/A  N/A    176468      G   /usr/lib64/firefox/firefox                    176MiB |
|    0   N/A  N/A    301901      C   ...ensor/Sin_PPO_test/.venv/bin/python        810MiB |
|    0   N/A  N/A    304334      C   ...e/brain/Tensor/jax/.venv/bin/python        256MiB |
+-----------------------------------------------------------------------------------------+
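As a rough check on the out-of-memory hypothesis, the figures in the `nvidia-smi` output can be compared against the 0.25 preallocation fraction. A back-of-the-envelope sketch, assuming the fraction is applied to total device memory and ignoring fragmentation and any later allocations by the other processes:

```python
# Figures read off the nvidia-smi output above (MiB).
total_mib = 24576
used_mib = 2585
mem_fraction = 0.25  # value of XLA_PYTHON_CLIENT_MEM_FRACTION in the report

free_mib = total_mib - used_mib
# Assumption: the fraction is applied to total device memory.
requested_mib = mem_fraction * total_mib

print(f"free: {free_mib} MiB, requested pool: {requested_mib:.0f} MiB")
# free: 21991 MiB, requested pool: 6144 MiB
```

On these numbers the requested pool fits well within free memory. Libraries such as cuSolver may allocate memory outside JAX's pool, so this does not by itself rule out a memory-related cause.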
Description
I am having issues initializing a flax.linen neural network when running with GPU support. I have narrowed it down to flax.linen.initializers.orthogonal. Running the code below results in:
RuntimeError: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
However, the same code runs fine in another venv with CPU-only support, and it also runs fine on GPU without the orthogonal kernel initializer. JAX was installed using
pip install -U "jax[cuda12]"
I have attached a minimal example that will raise the issue.
Thanks, Brian
System info (python version, jaxlib version, accelerator, etc.)