I encountered an issue when running JAX with JIT compilation and single vmap on a GPU. The following MWE demonstrates the problem:
import jax
import jax.numpy as jnp
def fn(x):
    R1 = jnp.array([[x[0], 0, 0],
                    [0, 1, 0],
                    [0, 0, x[0]]])  # Changing x[0] to x[2] here resolves the issue...
    # Another matrix
    R2 = jnp.array([[x[0], 0, 0],
                    [0, x[1], 0],
                    [0, 0, x[2]]])
    # R2 = jnp.diag(x)  # Using jnp.diag resolves the issue...
    H = jnp.eye(4)
    H = H.at[:3, :3].set(R2.T)  # Removing .T resolves the issue
    pos = H @ jnp.concatenate([x, jnp.array([1.0])])
    # pos = H[:3, :3] @ x  # Using this line resolves the issue...
    return pos, R1  # Returning only pos or only R1 resolves the issue...
gpu = jax.devices("gpu")[0]
cpu = jax.devices("cpu")[0]
N = 5
x_v = jnp.zeros((N, 3))
fn_v = jax.vmap(fn)
fn_jv_cpu = jax.jit(jax.vmap(fn), device=cpu)
fn_jv_gpu = jax.jit(jax.vmap(fn), device=gpu)
M = 4 # changing M=5 resolves the issue
x_vv = jnp.zeros((M, N, 3))
fn_jvv_gpu = jax.jit(jax.vmap(jax.vmap(fn)), device=gpu)
res_vv_gpu = fn_jvv_gpu(x_vv)
print("Jit (GPU), double vmap: SUCCESS")
res_v = fn_v(x_v)
print("No jit, single vmap: SUCCESS")
res_v_cpu = fn_jv_cpu(x_v)
print("Jit (CPU), single vmap: SUCCESS")
res_v_gpu = fn_jv_gpu(x_v) # Fails here...
print("Jit (GPU), single vmap: SUCCESS")
Error Message:
2024-08-05 09:03:10.802291: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Jit (GPU), double vmap: SUCCESS
No jit, single vmap: SUCCESS
Jit (CPU), single vmap: SUCCESS
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/r2ci/rex/scratch/scratch_bug.py", line 40, in <module>
res_v_gpu = fn_jv_gpu(x_v) # Fails here...
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Binary op with incompatible shapes: f32[4,5,4] and f32[5,4,4].
Process finished with exit code 1
Steps to Reproduce:
Run the provided script on a system with JAX and a GPU.
Observe the error when fn_jv_gpu(x_v) is called.
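For anyone who wants to run the same sequence on whichever backends they have, here is a rough sketch. The helpers `first_device` and `run_on` are hypothetical names introduced for illustration, not part of JAX. It also assumes that committing the input with `jax.device_put` places the computation the same way as the `device=` argument used in the script above, so it may not trigger the failure in every setup.

```python
import jax
import jax.numpy as jnp


def fn(x):
    # Same function as in the MWE above.
    R1 = jnp.array([[x[0], 0, 0],
                    [0, 1, 0],
                    [0, 0, x[0]]])
    R2 = jnp.array([[x[0], 0, 0],
                    [0, x[1], 0],
                    [0, 0, x[2]]])
    H = jnp.eye(4)
    H = H.at[:3, :3].set(R2.T)
    pos = H @ jnp.concatenate([x, jnp.array([1.0])])
    return pos, R1


def first_device(backend):
    """Hypothetical helper: first device of `backend`, or None if unavailable."""
    try:
        return jax.devices(backend)[0]
    except RuntimeError:
        return None


def run_on(device, M=4, N=5):
    """Hypothetical helper: double vmap first, then single vmap, as in the MWE."""
    x_vv = jax.device_put(jnp.zeros((M, N, 3)), device)
    x_v = jax.device_put(jnp.zeros((N, 3)), device)
    try:
        jax.jit(jax.vmap(jax.vmap(fn)))(x_vv)
        pos, R1 = jax.jit(jax.vmap(fn))(x_v)
        return "SUCCESS", pos.shape, R1.shape
    except Exception as err:  # e.g. XlaRuntimeError on an affected setup
        return f"FAILED: {type(err).__name__}", None, None


results = {}
for backend in ("cpu", "gpu"):
    dev = first_device(backend)
    results[backend] = ("SKIPPED", None, None) if dev is None else run_on(dev)
    print(backend, results[backend])
```

Backends that are not present are reported as SKIPPED rather than raising, so the same script can be run on CPU-only machines for comparison.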
Additional Information:
Changing x[0] to x[2] in R1 resolves the issue.
Using jnp.diag for R2 resolves the issue.
Removing the .T in H = H.at[:3, :3].set(R2.T) resolves the issue.
Returning only pos or only R1 (instead of both) resolves the issue.
Setting M=5 instead of M=4 resolves the issue.
The code runs without errors on JAX version 0.4.13.
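As a possible stopgap, the `jnp.diag` observation above can be written out as a variant of `fn`. This is only a sketch: `fn_workaround` is a hypothetical name, and the check below runs on the default backend against a NumPy reference, so it confirms that the variant computes the same values but not that it avoids the GPU error on every setup.

```python
import jax
import jax.numpy as jnp
import numpy as np


def fn_workaround(x):
    # R1 is unchanged from the original MWE.
    R1 = jnp.array([[x[0], 0, 0],
                    [0, 1, 0],
                    [0, 0, x[0]]])
    # Workaround listed above: build R2 with jnp.diag instead of jnp.array.
    R2 = jnp.diag(x)
    H = jnp.eye(4)
    H = H.at[:3, :3].set(R2.T)
    pos = H @ jnp.concatenate([x, jnp.array([1.0])])
    return pos, R1


# Non-zero inputs so the result is easy to check; no device is pinned here.
x_v = jnp.arange(15.0).reshape(5, 3)
pos, R1 = jax.jit(jax.vmap(fn_workaround))(x_v)

# NumPy reference: H = diag(x0, x1, x2, 1), so pos = [x0^2, x1^2, x2^2, 1].
x_np = np.asarray(x_v)
expected_pos = np.concatenate([x_np ** 2, np.ones((5, 1))], axis=1)
print(np.allclose(np.asarray(pos), expected_pos))
```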
Any help or guidance on resolving this issue would be greatly appreciated. Thank you!
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.30
jaxlib: 0.4.30
numpy: 1.26.4
python: 3.9.19 (main, Apr 6 2024, 17:57:55) [GCC 9.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='r2ci-Alienware-m15-R4', release='5.15.0-117-generic', version='#127~20.04.1-Ubuntu SMP Thu Jul 11 15:36:12 UTC 2024', machine='x86_64')
$ nvidia-smi
Mon Aug 5 09:01:31 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01 Driver Version: 535.183.01 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 3070 ... Off | 00000000:01:00.0 On | N/A |
| N/A 58C P0 34W / 125W | 785MiB / 8192MiB | 3% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 1751 G /usr/lib/xorg/Xorg 139MiB |
| 0 N/A N/A 2113 G /usr/lib/xorg/Xorg 234MiB |
| 0 N/A N/A 2245 G /usr/bin/gnome-shell 78MiB |
| 0 N/A N/A 10309 G /usr/lib/firefox/firefox 165MiB |
| 0 N/A N/A 11171 C /home/r2ci/rex/.venv/bin/python 138MiB |
+---------------------------------------------------------------------------------------+