jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

CUDA `XlaRuntimeError` with MPI on `jax==0.4.31` #22995

Open MasterSkepticista opened 2 months ago

MasterSkepticista commented 2 months ago

Description

Hi,

`jax.jit` on a function seems to fail when running in an OpenMPI environment. A minimal working example (MWE) is shown below:

```python
# error.py
# Run as: mpirun -n 8 python error.py

import os
from absl import logging
import jax, jax.numpy as jnp

logging.set_verbosity("info")
os.environ["no_proxy"] = "x.x.x.x"  # Internal use.
jax.distributed.initialize()

print("Hello from process %d holding %d device(s)" % (jax.process_index(), jax.local_device_count()))

def dot_product_attention(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
  depth = query.shape[-1]
  query = query / jnp.sqrt(depth).astype(dtype)
  attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key)
  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
  return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value)

x = jnp.ones((1, 512, 8, 32), dtype=jnp.bfloat16)
f = lambda x: dot_product_attention(x, x, x)

print(jax.jit(f)(x))
```

The error can occur on a subset of processes (in which case I still see the output tensor) or on all processes (in which case the job hangs or exits). I can confirm this error does not appear on jax==0.4.30.
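For reference, the attention math itself can be checked outside JAX with a plain NumPy mirror of the MWE's function (a sketch; `np_attention` is a hypothetical helper, not part of the original report), which suggests the failure is in the distributed autotuning exchange rather than in the computation:

```python
import numpy as np

def np_attention(query, key, value):
    """NumPy mirror of dot_product_attention from the MWE above."""
    depth = query.shape[-1]
    query = query / np.sqrt(depth)
    # (..., q, h, d) x (..., k, h, d) -> (..., h, q, k)
    w = np.einsum('...qhd,...khd->...hqk', query, key)
    # Numerically stable softmax over the last (key) axis.
    w = np.exp(w - w.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)
    return np.einsum('...hqk,...khd->...qhd', w, value)

x = np.ones((1, 512, 8, 32), dtype=np.float32)
out = np_attention(x, x, x)
print(out.shape)  # (1, 512, 8, 32)
```

With an all-ones input every attention weight is uniform, so the output is all ones with the same shape as the input.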

System info (python version, jaxlib version, accelerator, etc.)

Error log:

```shell
JAX detected proxy variable(s) in the environment as distributed setup: no_proxy https_proxy HTTPS_PROXY HTTP_PROXY http_proxy. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
(the warning above is printed once per process; repeats omitted)
Hello from process 3 holding 1 device(s)
Hello from process 5 holding 1 device(s)
Hello from process 1 holding 1 device(s)
Hello from process 7 holding 1 device(s)
Hello from process 0 holding 1 device(s)
Hello from process 4 holding 1 device(s)
Hello from process 6 holding 1 device(s)
Hello from process 2 holding 1 device(s)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/karan/workspace/jax_gpt2/error.py", line 14, in <module>
    print(jax.jit(f)(x))
jaxlib.xla_extension.XlaRuntimeError: DEADLINE_EXCEEDED: GetKeyValue() timed out with key: cuda:gemm_fusion_autotuning_results_5_0 and duration: -1ms
(the traceback above repeats on each failing process; one process times out on key cuda:gemm_fusion_autotuning_results_5_1 instead)
--------------------------------------------------------------------------
Primary job terminated normally, but 1 process returned a non-zero exit code.
Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpirun detected that one or more processes exited with non-zero status, thus causing the job to be terminated.
The first process to do so was:

  Process name: [[53590,1],2]
  Exit code:    1
--------------------------------------------------------------------------
```

System info:

```shell
jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='ubuntu', release='6.5.0-35-generic', version='#35~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue May 7 09:00:52 UTC 2', machine='x86_64')

Truncated nvidia-smi info:
NVIDIA-SMI 555.42.06   Driver Version: 555.42.06   CUDA Version: 12.5
GPU: RTX A6000
```
vfdev-5 commented 2 months ago

@MasterSkepticista the error is related to fetching the `cuda:gemm_fusion_autotuning_results` key on shards and may be related to https://github.com/openxla/xla/pull/13108 (cc @sergachev). To disable the autotuning and make your MWE work, you could try running it with:

```shell
XLA_FLAGS=--xla_gpu_shard_autotuning=false mpirun -n 8 python error.py
```

Let me know if this workaround helps.
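If modifying the launch command is inconvenient, the same flag can also be appended from Python (a sketch; this must run before the first `import jax`, otherwise XLA_FLAGS is read too late):

```python
import os

# Append the workaround flag, preserving any XLA_FLAGS already set.
existing = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (existing + " --xla_gpu_shard_autotuning=false").strip()

# Only now import jax, so the flag is picked up by the XLA backend.
# import jax
```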

sergachev commented 2 months ago

https://github.com/openxla/xla/pull/13108 was reverted.

`--xla_gpu_shard_autotuning=false` disables sharding of autotuning, not the autotuning itself.

sergachev commented 2 months ago

I can reproduce this with jax==0.4.31, and `--xla_gpu_shard_autotuning=false` helps; it looks like https://github.com/openxla/xla/pull/13108 got into this JAX release before it was reverted. Thank you for cc'ing me, I'll investigate why it fails.

MasterSkepticista commented 2 months ago

@vfdev-5 Your suggestion worked. @sergachev I observed that JAX was built against https://github.com/openxla/xla/commit/95e3eea8d2aebd55160ed4185a38345ae98ab500, which predates the revert.

sergachev commented 2 months ago

I sent a fix to XLA that makes the reproducer from this bug work. Independently of that, sharded autotuning was enabled again yesterday, and it will likely get into the next JAX release.