google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Jax multi-gpu randomly hangs forever #10969

Closed hawkinsp closed 2 years ago

hawkinsp commented 2 years ago

Discussed in https://github.com/google/jax/discussions/10763

Originally posted by **jrabary** May 19, 2022

Hi, we are facing a problem where training and validation code based on jax/flax hangs randomly on a multi-GPU host. Using a single GPU works correctly, but once we add multi-GPU support it hangs unpredictably. GPU usage is at 0% on all GPUs, but the CPU is being used. What could be the problem?
hawkinsp commented 2 years ago

It's hard to be certain, but setting NCCL_LAUNCH_MODE=GROUP seems to be helping in my testing. (I can't be sure, because the reproduction I have isn't deterministic and can take a while to exhibit the problem.)

edit: Unfortunately it eventually deadlocked. It just took longer.
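To try the NCCL_LAUNCH_MODE=GROUP setting from the comment above, the variable must be in the environment before NCCL initializes. A minimal sketch, assuming you set it from inside the training script rather than in the shell (the shell equivalent is `NCCL_LAUNCH_MODE=GROUP python train.py`, where `train.py` is a placeholder). As the edit notes, this only delayed the deadlock here.

```python
import os

# NCCL reads its environment variables when it initializes, so set this before
# importing jax (and therefore before any multi-GPU collective can run).
os.environ["NCCL_LAUNCH_MODE"] = "GROUP"

# Import jax only after this point.
```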

hawkinsp commented 2 years ago

After locally upgrading NCCL to 2.12.12 the repro hasn't deadlocked yet for me on 8xV100 (...of course that doesn't mean it won't deadlock, but it's encouraging.) I'm using the repro from https://github.com/google/jax/discussions/10763#discussioncomment-2812199

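If you'd like to try this yourself, then you can either install the community-authored Conda build of JAX, which I believe uses its own copy of NCCL, or you can build jaxlib from source using your locally-installed copy of NCCL instead of its own vendored copy, like this:

  • Install CUDA, cuDNN and NCCL, including the development headers, etc.
  • Build jaxlib from source, passing TF_NCCL_VERSION matching your installed NCCL version:
    TF_NCCL_VERSION=2.12.12 python build/build.py --enable_cuda
  • Install the newly built jaxlib (and you'll probably need to install jax from your source tree as well so the jax and jaxlib versions are compatible).
  • Profit?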

(I'm personally trying with XLA_FLAGS=--xla_gpu_enable_async_all_reduce=false just to rule that out as a problem, but that's probably not needed.)
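If you want to rule out async all-reduce the same way, the flag has to be in XLA_FLAGS before JAX starts its backend. A minimal sketch that appends the flag without clobbering any XLA flags you already set; `append_xla_flag` is a hypothetical helper, and the flag itself is copied from the comment above, so it may no longer exist in newer XLA releases.

```python
import os


def append_xla_flag(flag: str) -> str:
    """Add `flag` to XLA_FLAGS, keeping any flags that are already set."""
    flags = os.environ.get("XLA_FLAGS", "").split()
    if flag not in flags:  # avoid duplicating the flag on repeated calls
        flags.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(flags)
    return os.environ["XLA_FLAGS"]


# Flag quoted from the comment above (mid-2022 XLA). Newer XLA releases may not
# recognize it, and XLA may refuse to start when XLA_FLAGS has an unknown flag.
append_xla_flag("--xla_gpu_enable_async_all_reduce=false")
# Import jax only after XLA_FLAGS has its final value.
```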

A NCCL 2.12.12 upgrade is already in progress although it's currently blocked on some unrelated test breakage.

jrabary commented 2 years ago

I have made some observations that may help solve this issue. In my code, the hang seems to appear when I switch more often between train_step and eval_step. I have code that hangs after approximately the same number of steps. In the first version of this code, my evaluation dataset was a finite tf.data Dataset, so I had to create the data iterator before every evaluation step. Then I changed to an infinite dataset and created the evaluation data iterator only once, before the training and evaluation loop. The process ran to completion. I don't know whether it was luck or really the issue, but I'm running more experiments to see if it's repeatable.
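The change described above amounts to creating the evaluation iterator once and drawing a fixed number of batches from it each round. A minimal, framework-agnostic sketch of that loop structure; `train_step`, `eval_step` and `make_eval_batches` are hypothetical stand-ins for your own functions and data source (with tf.data you would typically use `Dataset.repeat()` instead of `repeat_forever`). As YOUSIKI reports below, this structure alone did not prevent the hang in every setup.

```python
import itertools
from typing import Any, Callable, Iterable, Iterator, List


def repeat_forever(make_batches: Callable[[], Iterable[Any]]) -> Iterator[Any]:
    """Turn a finite, re-creatable batch source into an endless iterator."""
    while True:
        empty = True
        for batch in make_batches():
            empty = False
            yield batch
        if empty:  # guard against spinning forever on an empty source
            raise ValueError("make_batches() produced no batches")


def train_and_eval(train_step: Callable[[Any], Any], eval_step: Callable[[Any], Any],
                   train_batches: Iterable[Any],
                   make_eval_batches: Callable[[], Iterable[Any]],
                   eval_every: int, eval_batches: int) -> List[List[Any]]:
    # Create the evaluation iterator once, before the loop, instead of
    # re-creating it at the start of every evaluation round.
    eval_iter = repeat_forever(make_eval_batches)
    results = []
    for step, batch in enumerate(train_batches, start=1):
        train_step(batch)
        if step % eval_every == 0:
            # islice pulls exactly `eval_batches` items from the shared iterator.
            results.append([eval_step(b) for b in itertools.islice(eval_iter, eval_batches)])
    return results
```

For example, with 6 training batches, `eval_every=2`, `eval_batches=2` and an evaluation source `[10, 20, 30]`, the three evaluation rounds see `[10, 20]`, `[30, 10]` and `[20, 30]`, continuing where the previous round stopped.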

jrabary commented 2 years ago

@hawkinsp Is there an (official) Docker image that we can use to build JAX locally?

YOUSIKI commented 2 years ago

@jrabary commented on June 4, 2022, 16:04 GMT+8:

> I have made some observations that may help solve this issue. In my code, the hang seems to appear when I switch more often between train_step and eval_step. I have code that hangs after approximately the same number of steps.
> In the first version of this code, my evaluation dataset was a finite tf.data Dataset, so I had to create the data iterator before every evaluation step.
> Then I changed to an infinite dataset and created the evaluation data iterator only once, before the training and evaluation loop. The process ran to completion.
> I don't know whether it was luck or really the issue, but I'm running more experiments to see if it's repeatable.

Well, I use a TensorFlow data pipeline to load data in my experiments, and I always create two infinite datasets only once (one for training and one for evaluation). But my training script still runs into this bug after several hours. I also encountered this bug when using a self-made data loader (written with threading.Thread in Python). So I don't think tf.data is the cause.

hawkinsp commented 2 years ago

I left the repro running overnight and it's still going well. Given that it used to deadlock in 10s of minutes on 8xV100 I'm now reasonably confident it has fixed the problem.

Here's a TF fork with a NCCL upgrade (courtesy of the engineer who is doing that upgrade; I didn't do the work here): https://github.com/hawkinsp/tensorflow/tree/nccl

You should be able to check out that branch/fork of TF, together with JAX and build a jaxlib with an updated NCCL like this:

python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/the/nccl/fork/of/tensorflow  --enable_cuda

We don't have a docker image for builds that's easy to use unfortunately.

jrabary commented 2 years ago

@hawkinsp To make sure I understand: should I first build this TF fork, install it, and then build JAX?

I got the following error when executing the build command above:

ERROR: @org_tensorflow//tensorflow/compiler/xla/python:enable_gpu :: Error loading option @org_tensorflow//tensorflow/compiler/xla/python:enable_gpu: error loading package '': cannot load '@org_tensorflow//tensorflow:workspace3.bzl': no such file

jrabary commented 2 years ago

Finally, it works. I forgot to check out the nccl branch.

jrabary commented 2 years ago

Got the following error after the build:

AttributeError: module 'jaxlib.gpu_linalg' has no attribute 'rocm_lu_pivots_to_permutation'

@hawkinsp Any idea ?

ibulu commented 2 years ago

Experiencing the same issue. It happens on compute instances with 8xV100 (SXM2) GPUs, but not on 2xV100 (PCIe) or 4xV100 (PCIe) instances.

mjsML commented 2 years ago

> Experiencing the same issue. The issue is happening with compute instances with 8xV100 (SXM2) GPUs. But not with 2xV100 (PCIE) or 4xV100 (PCIE) instances.

@hawkinsp we saw similar behavior with t5x. I'll tag a few people later, once I've been able to DM them (to find out who the most relevant person is).

ibulu commented 2 years ago

> After locally upgrading NCCL to 2.12.12 the repro hasn't deadlocked yet for me on 8xV100 (...of course that doesn't mean it won't deadlock, but it's encouraging.) I'm using the repro from #10763 (comment)
>
> If you'd like to try this yourself, then you can either install the community-authored Conda build of JAX, which I believe uses its own copy of NCCL, or you can build jaxlib from source using your locally-installed copy of NCCL instead of its own vendored copy, like this:
>
>   • Install CUDA, cuDNN and NCCL, including the development headers, etc.
>   • Build jaxlib from source, passing TF_NCCL_VERSION matching your installed NCCL version:
>     TF_NCCL_VERSION=2.12.12 python build/build.py --enable_cuda
>   • Install the newly built jaxlib (and you'll probably need to install jax from your source tree as well so the jax and jaxlib versions are compatible).
>   • Profit?
>
> (I'm personally trying with XLA_FLAGS=--xla_gpu_enable_async_all_reduce=false just to rule that out as a problem, but that's probably not needed.)
>
> A NCCL 2.12.12 upgrade is already in progress although it's currently blocked on some unrelated test breakage.

Is the community-authored build in the conda-forge channel?

ibulu commented 2 years ago

> Experiencing the same issue. The issue is happening with compute instances with 8xV100 (SXM2) GPUs. But not with 2xV100 (PCIE) or 4xV100 (PCIE) instances.

Spoke too soon :-) It is happening with the 4xV100 instance too.

ibulu commented 2 years ago

> I left the repro running overnight and it's still going well. Given that it used to deadlock in 10s of minutes on 8xV100 I'm now reasonably confident it has fixed the problem.
>
> Here's a TF fork with a NCCL upgrade (courtesy of the engineer who is doing that upgrade; I didn't do the work here): https://github.com/hawkinsp/tensorflow/tree/nccl
>
> You should be able to check out that branch/fork of TF, together with JAX and build a jaxlib with an updated NCCL like this:
>
> python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/the/nccl/fork/of/tensorflow --enable_cuda
>
> We don't have a docker image for builds that's easy to use unfortunately.

This solution seems to be working so far... fingers crossed ;-) The training loop has been running successfully for the last two days! Thank you @hawkinsp

I used the following command from the jax directory:

TF_NCCL_VERSION=2.12.12 python build/build.py --bazel_options=--override_repository=org_tensorflow=path-to-cloned-tensorflow-dir --bazel_options=--symlink_prefix=/ --enable_cuda
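TF_NCCL_VERSION only needs to be in the environment of the build process, which is all the `TF_NCCL_VERSION=...` shell prefix does. If you script your builds, a minimal sketch of the same invocation, assuming it is run from the root of a jax checkout; `jaxlib_build_command` is a hypothetical helper that only assembles the argument list and environment from the flags used in this thread, and the TensorFlow path is a placeholder.

```python
import os
import sys
from typing import Dict, List, Sequence, Tuple


def jaxlib_build_command(tf_source_dir: str, nccl_version: str,
                         extra_bazel_options: Sequence[str] = ()) -> Tuple[List[str], Dict[str, str]]:
    """Build argv and env for build/build.py against a local TF checkout."""
    argv = [
        sys.executable, "build/build.py",
        # Point Bazel's org_tensorflow repository at the local (NCCL-upgraded) TF fork.
        f"--bazel_options=--override_repository=org_tensorflow={tf_source_dir}",
    ]
    argv += [f"--bazel_options={opt}" for opt in extra_bazel_options]
    argv.append("--enable_cuda")
    # Copy the current environment and add TF_NCCL_VERSION for this build only.
    env = dict(os.environ, TF_NCCL_VERSION=nccl_version)
    return argv, env


# Usage (paths are placeholders):
#   import subprocess
#   argv, env = jaxlib_build_command("/path/to/tensorflow", "2.12.12", ["--symlink_prefix=/"])
#   subprocess.run(argv, env=env, check=True)
```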

jrabary commented 2 years ago

Are you using the main branch of jax for the build? My build succeeds, but when I import jax I get an error about rocm_lu_pivots_to_permutation.

ibulu commented 2 years ago

> Are you using the main branch of jax for the build? My build succeeds, but when I import jax I get an error about rocm_lu_pivots_to_permutation.

Yep, I used the main branch. In case it helps, I installed CUDA 11.6, the latest cuDNN version, and NCCL 2.12.12 before building jax/jaxlib.

hawkinsp commented 2 years ago

@ibulu Are you using jax from head together with your newly-built jaxlib? Are you sure? I think the symptom you have here is a version mismatch.

At head, the function is called hip_lu_pivots_to_permutation, so my guess is your jax is too old for your jaxlib. What versions do you have installed?
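To check for the version mismatch suspected above, it helps to print what is actually installed. A minimal sketch using only the standard library's `importlib.metadata`, assuming jax and jaxlib were installed with pip so their package metadata is present; if you import jax directly from a source tree instead, `jax.__version__` reflects what Python actually imports. `installed_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str] = ("jax", "jaxlib")) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# A jax that is much older than a freshly built jaxlib (or vice versa) can fail
# at import time, e.g. with a missing attribute in jaxlib.gpu_linalg.
print(installed_versions())
```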

ibulu commented 2 years ago

> @ibulu Are you using jax from head together with your newly-built jaxlib? Are you sure? I think the symptom you have here is a version mismatch.
>
> At head, the function is called hip_lu_pivots_to_permutation, so my guess is your jax is too old for your jaxlib. What versions do you have installed?

No issues on my end, @hawkinsp. The solution you suggested is working for me!

jrabary commented 2 years ago

> At head, the function is called hip_lu_pivots_to_permutation, so my guess is your jax is too old for your jaxlib. What versions do you have installed?

Finally managed to build it. So far, there is no hang.

srush commented 2 years ago

Thanks for all the help.

Having trouble getting this to build. Running into this error.

```
nccl/enhcompat.pic.d '-frandom-seed=bazel-out/k8-opt/bin/external/nccl_archive/_objs/nccl/enhcompat.pic.o' -iquote external/nccl_archive -iquote bazel-out/k8-opt/bin/external/nccl_archive -iquote external/local_config_cuda -iquote bazel-out/k8-opt/bin/external/local_config_cuda -Ibazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/nccl -Ibazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/include_hdrs -Ibazel-out/k8-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual -Ibazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/src_hdrs -Ibazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/wc_store_fence_wa -isystem external/local_config_cuda/cuda -isystem bazel-out/k8-opt/bin/external/local_config_cuda/cuda -isystem external/local_config_cuda/cuda/cuda/include -isystem bazel-out/k8-opt/bin/external/local_config_cuda/cuda/cuda/include -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -fno-omit-frame-pointer -no-canonical-prefixes -fno-canonical-system-headers -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections '-fvisibility=hidden' -Wno-sign-compare -Wno-stringop-truncation -Wno-array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' -mavx '-std=c++17' -x cuda '-DGOOGLE_CUDA=1' '-Xcuda-fatbinary=--compress-all' '--cuda-gpu-arch=sm_35' '--cuda-gpu-arch=sm_52' '--cuda-gpu-arch=sm_60' '--cuda-gpu-arch=sm_70' '--cuda-include-ptx=sm_80' '--cuda-gpu-arch=sm_80' -c external/nccl_archive/src/enhcompat.cc -o bazel-out/k8-opt/bin/external/nccl_archive/_objs/nccl/enhcompat.pic.o)
# Configuration: 65ddd241d57ae07a06fa3c2dfe98cea411398d066481c7622c029211f161d012
# Execution platform: @local_execution_config_platform//:platform
nvcc warning : The 'compute_35', 'compute_37', 'compute_50', 'sm_35', 'sm_37' and 'sm_50' architectures are deprecated, and may be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
nvcc warning : The 'compute_35', 'compute_37', 'compute_50', 'sm_35', 'sm_37' and 'sm_50' architectures are deprecated, and may be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
external/nccl_archive/src/enhcompat.cc(13): error: more than one instance of overloaded function "cudaStreamGetCaptureInfo_v2" has "C" linkage

external/nccl_archive/src/enhcompat.cc(15): error: more than one instance of overloaded function "cudaUserObjectCreate" has "C" linkage

external/nccl_archive/src/enhcompat.cc(17): error: more than one instance of overloaded function "cudaGraphRetainUserObject" has "C" linkage

external/nccl_archive/src/enhcompat.cc(19): error: more than one instance of overloaded function "cudaStreamUpdateCaptureDependencies" has "C" linkage

external/nccl_archive/src/enhcompat.cc(21): error: more than one instance of overloaded function "cudaGetDriverEntryPoint" has "C" linkage

5 errors detected in the compilation of "external/nccl_archive/src/enhcompat.cc".
```
ibulu commented 2 years ago

> Thanks for all the help.
>
> Having trouble getting this to build. Running into this error.

I had the same error. As the error suggests, the old NCCL version was still being used. Adding TF_NCCL_VERSION=2.12.12 in front of python build/build.py ... solved the issue for me.

srush commented 2 years ago

@ibulu You're a hero. Thanks so much.

One other change we had to make was to fix this line in the jax build.

```diff
--- a/jaxlib/mlir/_mlir_libs/BUILD.bazel
+++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel
-        "@org_tensorflow//tensorflow/compiler/mlir/hlo:python/MlirHloModule.cc",
+        "@org_tensorflow//tensorflow/compiler/mlir/hlo:python/MlirHloModule.cpp",
```

Running now :crossed_fingers:

hawkinsp commented 2 years ago

We've released jax and jaxlib 0.3.14 which contain a new version of NCCL. To the best of my knowledge, this should fix this problem. Hope that helps!
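To tell whether an environment already has the jaxlib release mentioned above, you can compare version numbers. A minimal sketch, assuming plain numeric release strings: local suffixes such as `+cuda11.cudnn82` and `.dev` tags are ignored, so a 0.3.14 pre-release would count as 0.3.14 (use `packaging.version` for full PEP 440 handling). `has_new_nccl` is a hypothetical helper, and passing the check does not guarantee the hang is gone (see the 16-GPU report below).

```python
import re
from typing import Tuple


def release_tuple(ver: str) -> Tuple[int, ...]:
    """Leading numeric components, e.g. '0.3.14+cuda11.cudnn82' -> (0, 3, 14)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"not a version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def has_new_nccl(jaxlib_version: str, minimum: str = "0.3.14") -> bool:
    """True if jaxlib_version is at least the release said to bundle the newer NCCL."""
    return release_tuple(jaxlib_version) >= release_tuple(minimum)
```

For an installed jaxlib, pass `importlib.metadata.version("jaxlib")` as the argument.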

srxzr commented 2 years ago

Strangely, the new version still hangs with 16 GPUs, but it works well with 8 GPUs. Setup: 16x A100, cuDNN 8.2, CUDA 11.3.

jrabary commented 2 years ago

@srxzr 16 GPUs in a single node, or 2 nodes with 8 GPUs each? As far as I know, multi-node GPU support does not work with JAX yet.

mjsML commented 2 years ago

@hawkinsp based on @srxzr's input, it seems this is not fixed?

hawkinsp commented 2 years ago

@srxzr Could you please open a new issue with instructions on how to reproduce?

@jrabary Actually, multi-node GPU should work: to my knowledge, the main thing that needs work is the documentation. The key point is to use jax.distributed to initialize the cluster, and then follow the usual rules for multi-process programming in JAX (see the docs).
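A minimal sketch of the jax.distributed initialization step mentioned above, assuming one process per node and that your launcher exports the coordinator address, process count and process index. The environment variable names (`COORDINATOR_ADDRESS`, `NUM_PROCESSES`, `PROCESS_ID`) and both helper functions are hypothetical, chosen for this sketch; newer JAX releases may also detect these settings automatically on some cluster managers, so check the jax.distributed documentation for your version.

```python
import os
from typing import Any, Dict


def distributed_config_from_env() -> Dict[str, Any]:
    """Read cluster settings from hypothetical environment variables.

    COORDINATOR_ADDRESS is the "host:port" of process 0; NUM_PROCESSES and
    PROCESS_ID are integers. These names are assumptions for this sketch.
    """
    return {
        "coordinator_address": os.environ["COORDINATOR_ADDRESS"],
        "num_processes": int(os.environ["NUM_PROCESSES"]),
        "process_id": int(os.environ["PROCESS_ID"]),
    }


def init_multi_node() -> None:
    import jax  # imported lazily so the config helper works without JAX installed

    # Call this once in every process, before any other JAX computation.
    jax.distributed.initialize(**distributed_config_from_env())
    # After initialization, jax.device_count() counts the GPUs of all processes,
    # while jax.local_device_count() counts only this process's GPUs.
    print(f"process {jax.process_index()}: "
          f"{jax.local_device_count()} local / {jax.device_count()} global devices")
```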

ntenenz commented 2 years ago

@hawkinsp, I may have encountered the same issue (unless there's a coding mistake I'm missing). Executing the code on a single device runs as expected. However, if n_devices > 1, I observe that all devices allocate memory but only an inconsistent subset of them performs computation (as displayed via nvtop / nvidia-smi), leading to an indefinite hang.

Environment:

While the actual codebase is far more complex, I believe I've identified a simple repro:

```python
import jax
import jax.numpy as jnp


def func(num_rows_per_device, num_devices, num_features):
    num_entries = num_rows_per_device * num_devices * num_features

    # often hangs
    def fn(x_):
        return jax.lax.cond(
            x_.sum() > num_entries,
            lambda: 0.,
            lambda: jnp.sum(x_**2)
        )

    # doesn't seem to hang
    # def fn(x_):
    #     return x_.sum()

    data = jnp.arange(num_entries).reshape(num_rows_per_device, num_devices, num_features).astype(jnp.float32)
    # lax.map loops over the leading axis; pmap splits each (num_devices, num_features) slice across devices
    return jax.lax.map(jax.pmap(fn), data)


# Needs at least 4 devices, since num_devices=4 is the pmapped axis size.
print(jax.jit(func, static_argnames=["num_rows_per_device", "num_devices", "num_features"])(5, 4, 2))
```

Thoughts?

hawkinsp commented 2 years ago

@ntenenz Can you open a new issue for this?

And for what it's worth, I can't reproduce a hang with the script you provided on a 4xT4 GPU VM.