jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 11 #21621

Open PhilipVinc opened 3 months ago

PhilipVinc commented 3 months ago

Description

I am consistently getting the following error from a complicated piece of code:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 11, output: : If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.

This happens after installing jax/jaxlib in a clean environment with:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

I also made sure that nothing is set in my LD_LIBRARY_PATH.

Is there any way to debug this?

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='cholesky-gpu02', release='3.10.0-1062.el7.x86_64', version='#1 SMP Wed Aug 7 18:08:02 UTC 2019', machine='x86_64')

$ nvidia-smi
Tue Jun  4 11:20:34 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla V100-SXM2-32GB           On  | 00000000:1A:00.0 Off |                    0 |
| N/A   36C    P0              55W / 300W |    311MiB / 32768MiB |      1%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     42323      C   python                                      308MiB |
+---------------------------------------------------------------------------------------+

(The nvidia-smi that is being picked up comes from the cluster installation, but CUDA is not in my PATH, as the following shows.)

(myenv2) [filippo.vicentini@cholesky-gpu02 test2]$ which nvcc
/usr/bin/which: no nvcc in (/mnt/beegfs/softs/opt/gcc_10.2.0/openmpi/4.1.4/bin:/mnt/beegfs/softs/opt/core/gcc/10.2.0/bin:/mnt/beegfs/workdir/filippo.vicentini/mambaforge/envs/myenv2/bin:/opt/MegaRAID/perccli:/mnt/beegfs/home/CPHT/filippo.vicentini/.vscode-server/cli/servers/Stable-dc96b837cf6bb4af9cd736aa3af08cf8279f7685/server/bin/remote-cli:/mnt/beegfs/home/CPHT/filippo.vicentini/.cargo/bin:/mnt/beegfs/softs/opt/core/mambaforge/22.11.1-4/condabin:/usr/lib64/qt-3.3/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/opt/dell/srvadmin/bin:/mnt/beegfs/home/CPHT/filippo.vicentini/.local/bin:/mnt/beegfs/home/CPHT/filippo.vicentini/bin:/opt/ibutils/bin)
(myenv2) [filippo.vicentini@cholesky-gpu02 test2]$ which ptxas
/usr/bin/which: no ptxas in (/mnt/beegfs/softs/opt/gcc_10.2.0/openmpi/4.1.4/bin:/mnt/beegfs/softs/opt/core/gcc/10.2.0/bin:/mnt/beegfs/workdir/filippo.vicentini/mambaforge/envs/myenv2/bin:/opt/MegaRAID/perccli:/mnt/beegfs/home/CPHT/filippo.vicentini/.vscode-server/cli/servers/Stable-dc96b837cf6bb4af9cd736aa3af08cf8279f7685/server/bin/remote-cli:/mnt/beegfs/home/CPHT/filippo.vicentini/.cargo/bin:/mnt/beegfs/softs/opt/core/mambaforge/22.11.1-4/condabin:/usr/lib64/qt-3.3/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/opt/dell/srvadmin/bin:/mnt/beegfs/home/CPHT/filippo.vicentini/.local/bin:/mnt/beegfs/home/CPHT/filippo.vicentini/bin:/opt/ibutils/bin)
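
Since `which ptxas` finds nothing on PATH, it may help to check whether the pip installation shipped its own copy. The following is a minimal sketch, not JAX's actual lookup logic: it assumes that the CUDA pip wheels unpack their binaries under `site-packages/nvidia/<component>/bin`, and the helper name `find_ptxas_candidates` and its `extra_roots` parameter are hypothetical.

```python
import shutil
import sysconfig
from pathlib import Path


def find_ptxas_candidates(extra_roots=()):
    """Return possible ptxas locations: PATH first, then pip-installed NVIDIA wheels."""
    candidates = []

    # 1. Whatever the shell would find (equivalent to `which ptxas`).
    on_path = shutil.which("ptxas")
    if on_path:
        candidates.append(Path(on_path))

    # 2. Assumption: CUDA pip wheels live under site-packages/nvidia/<component>/bin.
    paths = sysconfig.get_paths()
    roots = list(dict.fromkeys([paths["purelib"], paths["platlib"], *map(str, extra_roots)]))
    for root in roots:
        candidates.extend(sorted(Path(root).glob("nvidia/*/bin/ptxas")))

    return candidates


if __name__ == "__main__":
    found = find_ptxas_candidates()
    if not found:
        print("no ptxas found on PATH or under site-packages/nvidia")
    for path in found:
        print(path)
```

Run with the interpreter of the affected environment (here `myenv2`) so that the site-packages directory searched is the one jaxlib was installed into.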

PhilipVinc commented 3 months ago

This is the error I get. I can also share a reproducer if that would help.

Traceback (most recent call last):
  File "/mnt/beegfs/project/ndqm/test_luca/time_evolution.py", line 570, in <module>
    obs_dict = solve_variational_evolution(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/beegfs/project/ndqm/test_luca/time_evolution.py", line 299, in solve_variational_evolution
    step_function = integration_algorithm(dt, H, exp_x)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/beegfs/project/ndqm/test_luca/time_evolution.py", line 363, in step_explicit_O2
    exp_z = nkj.operations.get_apply_exp_diagH(Hd)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/beegfs/workdir/filippo.vicentini/mambaforge/envs/ENV_NAME/lib/python3.11/site-packages/netket_pro/jumps/operations/exact_ops_on_FrozenExtendedNet.py", line 41, in get_apply_exp_diagH
    i, j = ij.T
           ^^^^
  File "/mnt/beegfs/workdir/filippo.vicentini/mambaforge/envs/ENV_NAME/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 630, in transpose
    return lax.transpose(a, axes_)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/beegfs/workdir/filippo.vicentini/mambaforge/envs/ENV_NAME/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 986, in transpose
    return transpose_p.bind(operand, permutation=permutation)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/beegfs/workdir/filippo.vicentini/mambaforge/envs/ENV_NAME/lib/python3.11/site-packages/jax/_src/core.py", line 387, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/beegfs/workdir/filippo.vicentini/mambaforge/envs/ENV_NAME/lib/python3.11/site-packages/jax/_src/core.py", line 391, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/beegfs/workdir/filippo.vicentini/mambaforge/envs/ENV_NAME/lib/python3.11/site-packages/jax/_src/core.py", line 879, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/beegfs/workdir/filippo.vicentini/mambaforge/envs/ENV_NAME/lib/python3.11/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 11, output: : If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
/mnt/beegfs/project/ndqm/test_luca/time_evolution.py:542: UserWarning: Data has no positive values, and therefore cannot be log-scaled.
  ax.set_yscale("log")
/mnt/beegfs/project/ndqm/test_luca/time_evolution.py:559: UserWarning: Data has no positive values, and therefore cannot be log-scaled.
  ax[0,i].set_yscale("log")
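
The message above names `JAX_TRACEBACK_FILTERING=off` as the way to see the internal frames. One way to apply it without touching the shell configuration is to re-run the script in a child process with that variable set. This is only a sketch: the helper names `debug_env` and `rerun` are hypothetical, and the script path in the usage note is taken from the traceback.

```python
import os
import subprocess
import sys


def debug_env(base=None):
    """Copy an environment and disable JAX traceback filtering in the copy."""
    env = dict(os.environ if base is None else base)
    # Variable name taken from the JAX message in the traceback.
    env["JAX_TRACEBACK_FILTERING"] = "off"
    return env


def rerun(script, *args):
    """Re-run `script` with the current interpreter and the debugging environment."""
    return subprocess.run([sys.executable, script, *args], env=debug_env(), check=False)
```

For example, `rerun("time_evolution.py")` would repeat the failing run with the full, unfiltered traceback. The variable has to be in place before `jax` is imported, which is why the sketch sets it for a fresh process rather than in the running one.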

mattjj commented 3 months ago

Thanks for raising this. Yes, can you share a reproducer?

The error talks about filesystem issues ("If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided"). Could there be a permissions issue?
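
A quick way to test the filesystem hypothesis is to check that the temporary directory is writable and not full. The sketch below rests on assumptions: that the compiler's scratch files go to the directory Python reports via `tempfile.gettempdir()` (which honors `TMPDIR`), and that 1 GiB is a reasonable free-space threshold. Neither is confirmed in this thread, and `check_scratch_dir` is a hypothetical helper name.

```python
import shutil
import tempfile


def check_scratch_dir(path=None, min_free_bytes=1 << 30):
    """Report whether `path` is writable and has at least `min_free_bytes` free."""
    # Assumption: scratch files land in the default temp directory.
    path = path or tempfile.gettempdir()
    report = {"path": path, "writable": False, "free_bytes": None, "enough_space": False}

    try:
        # Actually create and write a file; permission bits alone can be
        # misleading on network filesystems.
        with tempfile.NamedTemporaryFile(dir=path) as handle:
            handle.write(b"scratch-test")
            handle.flush()
        report["writable"] = True
    except OSError as exc:
        report["error"] = str(exc)

    try:
        free = shutil.disk_usage(path).free
        report["free_bytes"] = free
        report["enough_space"] = free >= min_free_bytes
    except OSError as exc:
        report.setdefault("error", str(exc))

    return report


if __name__ == "__main__":
    print(check_scratch_dir())
```

If `writable` or `enough_space` comes back `False` on the GPU node, pointing `TMPDIR` at a directory that passes the check before launching Python would be a reasonable next experiment.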