MasterSkepticista opened 2 months ago
@MasterSkepticista the error is related to getting `cuda:gemm_fusion_autotuning_results` on shards and may be related to https://github.com/openxla/xla/pull/13108 (cc @sergachev). To disable the autotuning and make your MWE work, you could try running it with:
```shell
XLA_FLAGS=--xla_gpu_shard_autotuning=false mpirun -n 8 python error.py
```
Let me know if this workaround helps.
https://github.com/openxla/xla/pull/13108 was reverted.
--xla_gpu_shard_autotuning=false disables sharding of autotuning, not the autotuning itself.
I can reproduce this with `jax==0.4.31`, and `--xla_gpu_shard_autotuning=false` helps; it looks like https://github.com/openxla/xla/pull/13108 made it into this JAX release before it was reverted. Thank you for cc'ing me; I'll investigate why it fails.
@vfdev-5 your suggestion worked. @sergachev I observed that JAX was built against https://github.com/openxla/xla/commit/95e3eea8d2aebd55160ed4185a38345ae98ab500, which predates the revert.
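If editing the `mpirun` command line is inconvenient, the same workaround could in principle be applied from inside the script. The sketch below is a hedged illustration, not something proposed in the thread. It assumes that `XLA_FLAGS` only needs to be set before JAX initializes its backends (doing it before `import jax` is the safe choice), and it assumes, based on the reports above, that only `jax==0.4.31` is affected. The helpers `installed_jax_version`, `apply_workaround`, and `AFFECTED_JAX_VERSIONS` are hypothetical names introduced for illustration.

```python
from __future__ import annotations

import os
from collections.abc import MutableMapping
from importlib.metadata import PackageNotFoundError, version

# Workaround from the thread: disable sharding of autotuning in XLA's GPU backend.
WORKAROUND_FLAG = "--xla_gpu_shard_autotuning=false"
# Assumption: only this release is affected (jax==0.4.30 is reported as fine).
AFFECTED_JAX_VERSIONS = frozenset({"0.4.31"})


def installed_jax_version() -> str | None:
    """Return the installed jax version from package metadata without importing jax."""
    try:
        return version("jax")
    except PackageNotFoundError:
        return None


def apply_workaround(jax_version: str | None, env: MutableMapping[str, str]) -> bool:
    """Append WORKAROUND_FLAG to env["XLA_FLAGS"] when jax_version is affected.

    Flags already present in XLA_FLAGS are kept, and the workaround flag is
    never added twice. Returns False (leaving env untouched) for unaffected
    or missing versions, True otherwise.
    """
    if jax_version not in AFFECTED_JAX_VERSIONS:
        return False
    flags = env.get("XLA_FLAGS", "").split()
    if WORKAROUND_FLAG not in flags:
        flags.append(WORKAROUND_FLAG)
    env["XLA_FLAGS"] = " ".join(flags)
    return True


if __name__ == "__main__":
    # Hypothetical usage: run this before JAX initializes its backends;
    # doing it before `import jax` is the safest choice.
    applied = apply_workaround(installed_jax_version(), os.environ)
    print(f"workaround applied: {applied}")
    # ...import jax and call jax.distributed.initialize() after this point.
```

Under `mpirun` every rank executes the script, so each process sets its own `XLA_FLAGS`. Gating on the version means the flag stops being applied after an upgrade; whether later releases are actually fixed is not established in this thread, so widen `AFFECTED_JAX_VERSIONS` if the error persists.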
Description
Hi,
`jax.jit` on a function seems to fail when running in an OpenMPI environment. An MWE is shown below:

The error can occur on select processes (in which case I see the output tensor) or on all processes (it hangs/exits). I can confirm this error does not appear in `jax==0.4.30`.

System info (python version, jaxlib version, accelerator, etc.)
Error log
```shell
JAX detected proxy variable(s) in the environment as distributed setup: no_proxy https_proxy HTTPS_PROXY HTTP_PROXY http_proxy. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
JAX detected proxy variable(s) in the environment as distributed setup: no_proxy https_proxy HTTPS_PROXY HTTP_PROXY http_proxy. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
JAX detected proxy variable(s) in the environment as distributed setup: no_proxy https_proxy HTTPS_PROXY HTTP_PROXY http_proxy. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
JAX detected proxy variable(s) in the environment as distributed setup: no_proxy https_proxy HTTPS_PROXY HTTP_PROXY http_proxy. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
JAX detected proxy variable(s) in the environment as distributed setup: no_proxy https_proxy HTTPS_PROXY HTTP_PROXY http_proxy. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
JAX detected proxy variable(s) in the environment as distributed setup: no_proxy https_proxy HTTPS_PROXY HTTP_PROXY http_proxy. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
JAX detected proxy variable(s) in the environment as distributed setup: no_proxy https_proxy HTTPS_PROXY HTTP_PROXY http_proxy. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
JAX detected proxy variable(s) in the environment as distributed setup: no_proxy https_proxy HTTPS_PROXY HTTP_PROXY http_proxy. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
Hello from process 3 holding 1 device(s)
Hello from process 5 holding 1 device(s)
Hello from process 1 holding 1 device(s)
Hello from process 7 holding 1 device(s)
Hello from process 0 holding 1 device(s)
Hello from process 4 holding 1 device(s)
Hello from process 6 holding 1 device(s)
Hello from process 2 holding 1 device(s)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/karan/workspace/jax_gpt2/error.py", line 14, in