jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JIT Compilation Error with Single Vmap on GPU #22865

Open bheijden opened 3 months ago

bheijden commented 3 months ago

Description

I encountered an issue when running JAX with JIT compilation and a single vmap on a GPU. The following minimal working example (MWE) demonstrates the problem:

import jax
import jax.numpy as jnp

def fn(x):
    R1 = jnp.array([[x[0], 0, 0],
                    [0, 1, 0],
                    [0, 0, x[0]]])  # Changing x[0] to x[2] here resolves the issue...

    # Another matrix
    R2 = jnp.array([[x[0], 0, 0],
                    [0, x[1], 0],
                    [0, 0, x[2]]])
    # R2 = jnp.diag(x)  # Using jnp.diag resolves the issue...
    H = jnp.eye(4)
    H = H.at[:3, :3].set(R2.T)  # Removing .T resolves the issue
    pos = H @ jnp.concatenate([x, jnp.array([1.0])])
    # pos = H[:3, :3] @ x  # Using this line resolves the issue...
    return pos, R1  # Returning only pos or only R1 resolves the issue...

gpu = jax.devices("gpu")[0]
cpu = jax.devices("cpu")[0]

N = 5
x_v = jnp.zeros((N, 3))
fn_v = jax.vmap(fn)
fn_jv_cpu = jax.jit(jax.vmap(fn), device=cpu)
fn_jv_gpu = jax.jit(jax.vmap(fn), device=gpu)

M = 4  # changing M=5 resolves the issue
x_vv = jnp.zeros((M, N, 3))
fn_jvv_gpu = jax.jit(jax.vmap(jax.vmap(fn)), device=gpu)

res_vv_gpu = fn_jvv_gpu(x_vv)
print("Jit (GPU), double vmap: SUCCESS")
res_v = fn_v(x_v)
print("No jit, single vmap: SUCCESS")
res_v_cpu = fn_jv_cpu(x_v)
print("Jit (CPU), single vmap: SUCCESS")
res_v_gpu = fn_jv_gpu(x_v)  # Fails here...
print("Jit (GPU), single vmap: SUCCESS")

Error Message:

2024-08-05 09:03:10.802291: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Jit (GPU), double vmap: SUCCESS
No jit, single vmap: SUCCESS
Jit (CPU), single vmap: SUCCESS
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/r2ci/rex/scratch/scratch_bug.py", line 40, in <module>
    res_v_gpu = fn_jv_gpu(x_v)  # Fails here...
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Binary op with incompatible shapes: f32[4,5,4] and f32[5,4,4].

Process finished with exit code 1

Steps to Reproduce:

  1. Run the provided script on a system with JAX and a GPU.
  2. Observe the error when fn_jv_gpu(x_v) is called.
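Note that the vmapped computation itself appears to be shape-consistent: abstract evaluation with jax.eval_shape (a sketch, repeating fn from the MWE above) reports exactly the shapes that the CPU path successfully returns, which suggests the mismatch arises only during GPU lowering:

```python
import jax
import jax.numpy as jnp

def fn(x):
    # Same function as in the MWE above.
    R1 = jnp.array([[x[0], 0, 0],
                    [0, 1, 0],
                    [0, 0, x[0]]])
    R2 = jnp.array([[x[0], 0, 0],
                    [0, x[1], 0],
                    [0, 0, x[2]]])
    H = jnp.eye(4)
    H = H.at[:3, :3].set(R2.T)
    pos = H @ jnp.concatenate([x, jnp.array([1.0])])
    return pos, R1

# Abstract evaluation: no computation runs, only shape/dtype inference.
x_spec = jax.ShapeDtypeStruct((5, 3), jnp.float32)
pos_s, R1_s = jax.eval_shape(jax.vmap(fn), x_spec)
print(pos_s.shape, R1_s.shape)  # (5, 4) (5, 3, 3)
```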

Additional Information:

Any help or guidance on resolving this issue would be greatly appreciated. Thank you!
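Until this is resolved upstream, any of the single-line changes flagged in the MWE comments avoids the crash. A sketch of one such workaround, building R2 with jnp.diag instead of a nested-list jnp.array (verified here only on CPU):

```python
import jax
import jax.numpy as jnp

def fn_workaround(x):
    # Same computation as fn in the MWE, but R2 is built with jnp.diag,
    # which (per the comments in the MWE) avoids the failure.
    R1 = jnp.array([[x[0], 0, 0],
                    [0, 1, 0],
                    [0, 0, x[0]]])
    R2 = jnp.diag(x)
    H = jnp.eye(4)
    H = H.at[:3, :3].set(R2.T)
    pos = H @ jnp.concatenate([x, jnp.array([1.0])])
    return pos, R1

x_v = jnp.zeros((5, 3))
pos, R1 = jax.jit(jax.vmap(fn_workaround))(x_v)
print(pos.shape, R1.shape)  # (5, 4) (5, 3, 3)
```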

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.9.19 (main, Apr  6 2024, 17:57:55)  [GCC 9.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='r2ci-Alienware-m15-R4', release='5.15.0-117-generic', version='#127~20.04.1-Ubuntu SMP Thu Jul 11 15:36:12 UTC 2024', machine='x86_64')
$ nvidia-smi
Mon Aug  5 09:01:31 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3070 ...    Off | 00000000:01:00.0  On |                  N/A |
| N/A   58C    P0              34W / 125W |    785MiB /  8192MiB |      3%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1751      G   /usr/lib/xorg/Xorg                          139MiB |
|    0   N/A  N/A      2113      G   /usr/lib/xorg/Xorg                          234MiB |
|    0   N/A  N/A      2245      G   /usr/bin/gnome-shell                         78MiB |
|    0   N/A  N/A     10309      G   /usr/lib/firefox/firefox                    165MiB |
|    0   N/A  N/A     11171      C   /home/r2ci/rex/.venv/bin/python             138MiB |
+---------------------------------------------------------------------------------------+
jakevdp commented 3 months ago

Thanks for the clear report! This looks like an XLA-GPU issue – I'll raise it in the appropriate channels.