Closed hawkinsp closed 2 years ago
It's hard to be certain, but setting NCCL_LAUNCH_MODE=GROUP seems to be helping in my testing. (I can't be sure, because the reproduction I have isn't deterministic and can take a while to exhibit the problem.)
edit: Unfortunately it eventually deadlocked. It just took longer.
After locally upgrading NCCL to 2.12.12 the repro hasn't deadlocked yet for me on 8xV100 (...of course that doesn't mean it won't deadlock, but it's encouraging.) I'm using the repro from https://github.com/google/jax/discussions/10763#discussioncomment-2812199
If you'd like to try this yourself, then you can either install the community-authored Conda build of JAX, which I believe uses its own copy of NCCL, or you can build jaxlib from source using your locally-installed copy of NCCL instead of its own vendored copy, like this:
- Install CUDA, CuDNN and NCCL, including the development headers, etc.
- Build jaxlib from source, passing TF_NCCL_VERSION matching your installed NCCL version:
TF_NCCL_VERSION=2.12.12 python build/build.py --enable_cuda
- Install the newly built jaxlib (and you'll probably need to install jax from your source tree as well so the jax and jaxlib versions are compatible).
- Profit?
(I'm personally trying with XLA_FLAGS=--xla_gpu_enable_async_all_reduce=false just to rule that out as a problem, but that's probably not needed.)
An NCCL 2.12.12 upgrade is already in progress, although it's currently blocked on some unrelated test breakage.
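For anyone who would rather set the XLA flag from Python than from the shell, here is a minimal sketch. The flag itself is the one quoted above; the helper name with_xla_flag is made up for illustration, and the sketch assumes (as far as I know this is the usual advice) that XLA_FLAGS has to be set before jax is first imported.

```python
import os

# Flag taken from the comment above; everything else here is illustrative.
ASYNC_ALL_REDUCE_OFF = "--xla_gpu_enable_async_all_reduce=false"


def with_xla_flag(env, flag):
    """Return a copy of env with flag appended to XLA_FLAGS (hypothetical helper)."""
    new_env = dict(env)
    existing = new_env.get("XLA_FLAGS", "").split()
    if flag not in existing:
        # Keep whatever flags were already set and add ours once.
        existing.append(flag)
    new_env["XLA_FLAGS"] = " ".join(existing)
    return new_env


# Apply to the current process; do this before the first `import jax`.
os.environ.update(with_xla_flag(os.environ, ASYNC_ALL_REDUCE_OFF))
```

Appending instead of overwriting keeps any XLA flags that were already exported in the shell.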
I have made some observations that may help to solve this issue. In my code, it looks like the hang appears when I switch more often between the train_step and eval_step. My code hangs after approximately the same number of steps each time.
In the first version of this code, my evaluation dataset was a finite tf.data.Dataset, so I had to create the data iterator before running the evaluation step, every time.
Then I changed to an infinite dataset and created the evaluation data iterator only once, before the training and evaluation loop. The process runs to completion.
I don't know whether it is luck or really the issue, but I'm running more experiments to see if it is repeatable.
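To make the change concrete, here is a rough sketch of the "create the iterators once" loop structure. It is plain Python with itertools.cycle standing in for an infinite dataset, and all the names (make_infinite_iterator, train_and_eval, the step callables) are placeholders, not code from my training script. Note that a later reply in this thread reports the hang even with this pattern, so treat it as an experiment, not a fix.

```python
import itertools


def make_infinite_iterator(batches):
    """Stand-in for an infinite dataset: cycle over a fixed list of batches."""
    return itertools.cycle(batches)


def train_and_eval(num_rounds, train_steps, eval_steps,
                   train_step, eval_step, train_batches, eval_batches):
    # Both iterators are created once, outside the loop.
    train_iter = make_infinite_iterator(train_batches)
    eval_iter = make_infinite_iterator(eval_batches)
    results = []
    for _ in range(num_rounds):
        for _ in range(train_steps):
            train_step(next(train_iter))
        # Evaluation draws a fixed number of batches from the same iterator
        # instead of building a new iterator over a finite dataset each round.
        results.append([eval_step(next(eval_iter)) for _ in range(eval_steps)])
    return results
```

The only difference from my first version is where the evaluation iterator is built: once before the loop, not once per evaluation round.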
@hawkinsp Is there an (official) Docker image that we can use to build jax locally?
@jrabary commented on June 4, 2022, 16:04 GMT+8:
I have made some observations that may help to solve this issue. In my code, it looks like the hang appears when I switch more often between the train_step and eval_step. My code hangs after approximately the same number of steps each time.
In the first version of this code, my evaluation dataset was a finite tf.data.Dataset, so I had to create the data iterator before running the evaluation step, every time.
Then I changed to an infinite dataset and created the evaluation data iterator only once, before the training and evaluation loop. The process runs to completion.
I don't know whether it is luck or really the issue, but I'm running more experiments to see if it is repeatable.
Well, I used a TensorFlow data pipeline to load data in my experiments, and I always create the two infinite datasets only once (one for training and one for evaluation). My training script still runs into this bug after several hours. I also encountered this bug when using a self-made data loader (written with threading.Thread in Python). So I don't think TensorFlow data is the cause.
I left the repro running overnight and it's still going well. Given that it used to deadlock within tens of minutes on 8xV100, I'm now reasonably confident the NCCL upgrade has fixed the problem.
Here's a TF fork with an NCCL upgrade (courtesy of the engineer who is doing that upgrade; I didn't do the work here): https://github.com/hawkinsp/tensorflow/tree/nccl
You should be able to check out that branch/fork of TF, together with JAX, and build a jaxlib with an updated NCCL like this:
python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/the/nccl/fork/of/tensorflow --enable_cuda
We don't have a docker image for builds that's easy to use unfortunately.
@hawkinsp To make sure I understand: should I first build this TF fork, install it, and then build JAX?
I got the following error when executing the build command above:
ERROR: @org_tensorflow//tensorflow/compiler/xla/python:enable_gpu :: Error loading option @org_tensorflow//tensorflow/compiler/xla/python:enable_gpu: error loading package '': cannot load '@org_tensorflow//tensorflow:workspace3.bzl': no such file
It finally works. I had forgotten to check out the nccl branch.
I got the following error after the build:
AttributeError: module 'jaxlib.gpu_linalg' has no attribute 'rocm_lu_pivots_to_permutation'
@hawkinsp Any idea?
Experiencing the same issue. The issue is happening with compute instances with 8xV100 (SXM2) GPUs. But not with 2xV100 (PCIE) or 4xV100 (PCIE) instances.
@hawkinsp we faced a similar-ish behavior with t5x ... I'll tag few people later when I'm able to DM them (to know who is the most relevant person)
After locally upgrading NCCL to 2.12.12 the repro hasn't deadlocked yet for me on 8xV100 (...of course that doesn't mean it won't deadlock, but it's encouraging.) I'm using the repro from #10763 (comment)
If you'd like to try this yourself, then you can either install the community-authored Conda build of JAX, which I believe uses its own copy of NCCL, or you can build jaxlib from source using your locally-installed copy of NCCL instead of its own vendored copy, like this:
- Install CUDA, CuDNN and NCCL, including the development headers, etc.
- Build jaxlib from source, passing TF_NCCL_VERSION matching your installed NCCL version:
TF_NCCL_VERSION=2.12.12 python build/build.py --enable_cuda
- Install the newly built jaxlib (and you'll probably need to install jax from your source tree as well so the jax and jaxlib versions are compatible).
- Profit?
(I'm personally trying with XLA_FLAGS=--xla_gpu_enable_async_all_reduce=false just to rule that out as a problem, but that's probably not needed.)
A NCCL 2.12.12 upgrade is already in progress although it's currently blocked on some unrelated test breakage.
Is the community-authored build in the conda-forge channel?
Spoke too soon :-) It is happening with the 4xV100 instance too.
This solution seems to be working so far... fingers crossed ;-) The training loop has been running successfully for the last two days! Thank you @hawkinsp
I used the following command from the JAX directory:
TF_NCCL_VERSION=2.12.12 python build/build.py --bazel_options=--override_repository=org_tensorflow=path-to-cloned-tensorflow-dir --bazel_options=--symlink_prefix=/ --enable_cuda
Are you using the main branch of jax for the build?
My build is successful, but when I import jax I get the rocm_lu_pivots_to_permutation AttributeError.
Yep, I used the main branch. If helpful: I installed CUDA 11.6, the latest cuDNN version, and NCCL 2.12.12 before building jax/jaxlib.
@ibulu Are you using jax from head together with your newly-built jaxlib? Are you sure? I think the symptom you have here is a version mismatch.
At head, the function is called hip_lu_pivots_to_permutation, so my guess is your jax is too old for your jaxlib. What versions do you have installed?
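If `import jax` itself fails, one way to answer the version question is to read the installed package metadata instead of importing the packages. This is a small sketch using the standard library's importlib.metadata; the function name installed_versions is just for illustration.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Map each package name to its installed version, or None if it is not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version}")
```

A jax that is noticeably older than the freshly built jaxlib would be consistent with the mismatch guess above.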
No issues on my end, @hawkinsp. The solution you suggested is working for me!
Finally managed to build it. So far, there is no hang.
Thanks for all the help.
Having trouble getting this to build. Running into this error.
nccl/enhcompat.pic.d '-frandom-seed=bazel-out/k8-opt/bin/external/nccl_archive/_objs/nccl/enhcompat.pic.o' -iquote external/nccl_arch[72/1930]
te bazel-out/k8-opt/bin/external/nccl_archive -iquote external/local_config_cuda -iquote bazel-out/k8-opt/bin/external/local_config_cuda -Ibaz
el-out/k8-opt/bin/external/nccl_archive/_virtual_includes/nccl -Ibazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/include_hdrs -Ib
azel-out/k8-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual -Ibazel-out/k8-opt/bin/external/nccl_archive/_virtu
al_includes/src_hdrs -Ibazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/wc_store_fence_wa -isystem external/local_config_cuda/cuda
-isystem bazel-out/k8-opt/bin/external/local_config_cuda/cuda -isystem external/local_config_cuda/cuda/cuda/include -isystem bazel-out/k8-opt
/bin/external/local_config_cuda/cuda/cuda/include -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME_
_="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -fno-omit-frame-pointer -no-canonical-prefixes -fno-canoni
cal-system-headers -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections '-fvisibility=hidden' -Wno-sign-compare -Wno-stringop-truncation -Wno-
array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' -mavx '-std=c++17' -x cuda '-DGOOGLE_CUDA=1' '-Xcuda-fatbinary=--compress-all' '--
cuda-gpu-arch=sm_35' '--cuda-gpu-arch=sm_52' '--cuda-gpu-arch=sm_60' '--cuda-gpu-arch=sm_70' '--cuda-include-ptx=sm_80' '--cuda-gpu-arch=sm_80
' -c external/nccl_archive/src/enhcompat.cc -o bazel-out/k8-opt/bin/external/nccl_archive/_objs/nccl/enhcompat.pic.o)
# Configuration: 65ddd241d57ae07a06fa3c2dfe98cea411398d066481c7622c029211f161d012
# Execution platform: @local_execution_config_platform//:platform
nvcc warning : The 'compute_35', 'compute_37', 'compute_50', 'sm_35', 'sm_37' and 'sm_50' architectures are deprecated, and may be removed in
a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
nvcc warning : The 'compute_35', 'compute_37', 'compute_50', 'sm_35', 'sm_37' and 'sm_50' architectures are deprecated, and may be removed in
a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
external/nccl_archive/src/enhcompat.cc(13): error: more than one instance of overloaded function "cudaStreamGetCaptureInfo_v2" has "C" linkage
external/nccl_archive/src/enhcompat.cc(15): error: more than one instance of overloaded function "cudaUserObjectCreate" has "C" linkage
external/nccl_archive/src/enhcompat.cc(17): error: more than one instance of overloaded function "cudaGraphRetainUserObject" has "C" linkage
external/nccl_archive/src/enhcompat.cc(19): error: more than one instance of overloaded function "cudaStreamUpdateCaptureDependencies" has "C"
linkage
external/nccl_archive/src/enhcompat.cc(21): error: more than one instance of overloaded function "cudaGetDriverEntryPoint" has "C" linkage
5 errors detected in the compilation of "external/nccl_archive/src/enhcompat.cc".
I had the same error. As the error suggests, the old NCCL version is still installed; adding TF_NCCL_VERSION=2.12.12 in front of the python build/build.py command solved the issue for me.
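In case it helps anyone scripting the build, here is a sketch that assembles the same invocation from Python. The flags and the TF_NCCL_VERSION variable are the ones quoted in this thread; the helper names (jaxlib_build_invocation, run_build) and the idea of driving the build through subprocess are my own assumptions, not something build.py requires.

```python
import os
import subprocess


def jaxlib_build_invocation(nccl_version, tf_checkout=None):
    """Return (argv, extra_env) for the build described in this thread (hypothetical helper)."""
    argv = ["python", "build/build.py", "--enable_cuda"]
    if tf_checkout is not None:
        # Point Bazel at a local TensorFlow checkout (e.g. the nccl fork).
        argv.append(
            "--bazel_options=--override_repository=org_tensorflow=" + tf_checkout
        )
    # TF_NCCL_VERSION should match the NCCL version installed on the machine.
    return argv, {"TF_NCCL_VERSION": nccl_version}


def run_build(nccl_version, tf_checkout=None, jax_dir="."):
    """Run the build from the JAX source directory with the extra variable set."""
    argv, extra_env = jaxlib_build_invocation(nccl_version, tf_checkout)
    env = {**os.environ, **extra_env}
    subprocess.run(argv, cwd=jax_dir, env=env, check=True)
```

Keeping the variable in the same helper as the command makes it harder to forget, which is what caused the enhcompat.cc errors above.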
@ibulu You're a hero. Thanks so much.
One other change we had to make was to fix this line in the jax build.
--- a/jaxlib/mlir/_mlir_libs/BUILD.bazel
+++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel
- "@org_tensorflow//tensorflow/compiler/mlir/hlo:python/MlirHloModule.cc",
+ "@org_tensorflow//tensorflow/compiler/mlir/hlo:python/MlirHloModule.cpp",
Running now :crossed_fingers:
We've released jax and jaxlib 0.3.14, which contain a new version of NCCL. To the best of my knowledge, this should fix this problem. Hope that helps!
Strangely, the new version still hangs with 16 GPUs, but it works well with 8 GPUs. Setup: 16x A100, cuDNN 8.2, CUDA 11.3.
@srxzr 16 GPUs in one single node, or 2 nodes with 8 GPUs each? As far as I know, multi-node GPU support is not working yet in JAX.
@hawkinsp Based on @srxzr's input, it seems this is not fixed?
@srxzr Could you please open a new issue with instructions on how to reproduce?
@jrabary Actually, multi-node GPU should work: the main thing that needs work, to my knowledge, is the documentation. The main thing you need to know is to use jax.distributed to initialize the cluster, and then follow the usual rules for multiprocess programming in JAX (see the docs).
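As a rough sketch of what that initialization can look like: every process calls jax.distributed.initialize with the coordinator's address, the total number of processes, and its own process id. The helper names below (distributed_init_kwargs, init_cluster) are illustrative, how you obtain the host, port and process id depends on your launcher, and the docs remain the authority here.

```python
def distributed_init_kwargs(coordinator_host, coordinator_port, num_processes, process_id):
    """Build the keyword arguments for jax.distributed.initialize (hypothetical helper)."""
    if not 0 <= process_id < num_processes:
        raise ValueError(
            f"process_id {process_id} must be in [0, {num_processes})"
        )
    return {
        "coordinator_address": f"{coordinator_host}:{coordinator_port}",
        "num_processes": num_processes,
        "process_id": process_id,
    }


def init_cluster(coordinator_host, coordinator_port, num_processes, process_id):
    """Initialize multi-process JAX; call once per process, before any computation."""
    import jax  # imported lazily so the helper above works without JAX installed

    jax.distributed.initialize(
        **distributed_init_kwargs(
            coordinator_host, coordinator_port, num_processes, process_id
        )
    )
    # Each process sees only its local devices plus its index in the cluster.
    return jax.process_index(), jax.local_device_count()
```

For two nodes with 8 GPUs each, that would mean num_processes=2, with process_id 0 on the coordinator node and 1 on the other, assuming one process per node.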
@hawkinsp, I may have encountered the same issue (unless there's a coding mistake I'm missing). Executing the code on a single device will run as expected. However, if n_devices > 1, I'm observing all devices will allocate memory but an inconsistent number will perform computation (as displayed via nvtop / nvidia-smi), leading to an indefinite hang.
Environment:
While the actual codebase is far more complex, I believe I've identified a simple repro:
import jax
import jax.numpy as jnp

def func(num_rows_per_device, num_devices, num_features):
    num_entries = num_rows_per_device * num_devices * num_features

    # often hangs
    def fn(x_):
        return jax.lax.cond(
            x_.sum() > num_entries,
            lambda: 0.,
            lambda: jnp.sum(x_**2)
        )

    # doesn't seem to hang
    # def fn(x_):
    #     return x_.sum()

    data = jnp.arange(num_entries).reshape(num_rows_per_device, num_devices, num_features).astype(jnp.float32)
    # lax.map iterates over the leading axis; pmap splits each slice across the devices
    return jax.lax.map(jax.pmap(fn), data)

print(jax.jit(func, static_argnames=["num_rows_per_device", "num_devices", "num_features"])(5, 4, 2))
Thoughts?
@ntenenz Can you open a new issue for this?
And for what it's worth, I can't reproduce a hang with the script you provided on a 4xT4 GPU VM.
Discussed in https://github.com/google/jax/discussions/10763