Open mathemakitten opened 1 year ago
Hi @mathemakitten, I stumbled on this issue while researching JAX. The latest EFA NCCL plugin supports multiple NCCL versions with a single build: https://github.com/aws/aws-ofi-nccl. If you check out the installation guide, it no longer requires a fixed NCCL version. It also tries to set FI_EFA_FORK_SAFE and NCCL_PROTO here. Unfortunately, you still have to set the environment variable FI_EFA_USE_DEVICE_RDMA manually. More documentation on the env vars. Hope this helps!
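For anyone who wants to set these variables from Python rather than the shell, here is a minimal sketch. The helper name with_efa_env and the value "1" for FI_EFA_USE_DEVICE_RDMA are assumptions for illustration (only FI_EFA_FORK_SAFE=1 is reported in this thread), so check the plugin's documentation for the values that fit your instance type.

```python
import os

# Assumed defaults for illustration only. FI_EFA_FORK_SAFE=1 comes from this
# thread; the value for FI_EFA_USE_DEVICE_RDMA is an assumption to verify.
EFA_ENV_DEFAULTS = {
    "FI_EFA_FORK_SAFE": "1",
    "FI_EFA_USE_DEVICE_RDMA": "1",
}


def with_efa_env(base=None):
    """Return a copy of `base` (default: os.environ) with EFA defaults added.

    Variables that are already set in `base` are left untouched, so an
    explicit setting from the shell or job launcher always wins.
    """
    env = dict(os.environ if base is None else base)
    for key, value in EFA_ENV_DEFAULTS.items():
        env.setdefault(key, value)
    return env
```

Calling os.environ.update(with_efa_env()) at the very top of the launch script, before jax is imported, should make the settings visible to NCCL and Libfabric when they initialize. The returned dict can also be passed as env= to subprocess.Popen when spawning worker processes.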
An update here: JAX no longer vendors a copy of NCCL. If you use the pip installation of JAX, we'll install NCCL via a pip package (nvidia-nccl-cu12). If the pip package isn't installed, then we'll use whatever NCCL library is in your LD_LIBRARY_PATH.
This means that it's easy to use an AWS-specific NCCL, if that's something you want to do: just make sure to remove the pip package version of NCCL and install the AWS version.
(We should document this.)
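Until that is documented, a quick way to see which of the two cases applies is to ask Python's package metadata whether the pip NCCL package is present. This is a sketch, not part of JAX; the helper name pip_nccl_version is hypothetical, and it only reports what is installed in the current environment.

```python
from importlib import metadata


def pip_nccl_version(dist_name="nvidia-nccl-cu12"):
    """Return the installed version of the pip NCCL package, or None.

    None means the distribution is not installed in this environment, in
    which case (per the comment above) the NCCL library found via
    LD_LIBRARY_PATH would be used instead.
    """
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None
```

If this returns a version string and you want the AWS build instead, uninstalling that pip package is the step described above.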
Hi! I wanted to surface a quirk of running jax on an AWS GPU cluster in case it's helpful: jax's vendored NCCL doesn't play well with AWS Libfabric EFA due to the way that jax starts processes (issue here). I know this is really an AWS-side issue, but it might be helpful to have it documented somewhere on the jax side as well, since it stalled me for two days and could help others debug similar issues. It's supposedly fixed in the newest version of EFA, but I still had to set FI_EFA_FORK_SAFE=1 to get anything working for multinode, multiprocess jax.
Relatedly, it'd be super useful to have a little utility function to surface what version of NCCL jax is running with, or maybe there's a way to do this and I'm just missing it. I'm working in envs with multiple versions of NCCL floating around and considered building jaxlib against a different version at some point, so it would be useful!
Feel free to close this if you decide against these suggestions. I thought it'd be helpful to document it for searchability in case others are running jax specifically on AWS clusters.
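As a stopgap for the version-reporting utility requested above, one option is to call NCCL's own ncclGetVersion through ctypes. This is a rough sketch under stated assumptions: the helper names are hypothetical, the library name libnccl.so.2 is an assumption about the installed soname, and the library the dynamic loader finds this way may not be the same one jax ends up loading (for example, when the pip package ships its own copy).

```python
import ctypes


def decode_nccl_version(code):
    """Split NCCL's integer version code into (major, minor, patch).

    NCCL encodes versions up to 2.8.x as major*1000 + minor*100 + patch,
    and later versions as major*10000 + minor*100 + patch.
    """
    if code < 10000:
        return code // 1000, (code % 1000) // 100, code % 100
    return code // 10000, (code % 10000) // 100, code % 100


def loaded_nccl_version(lib_name="libnccl.so.2"):
    """Return (major, minor, patch) for the NCCL the loader finds, or None."""
    try:
        lib = ctypes.CDLL(lib_name)
    except OSError:
        return None  # library not found on the loader's search path
    version = ctypes.c_int(0)
    # ncclGetVersion(int*) returns an ncclResult_t; 0 is ncclSuccess.
    status = lib.ncclGetVersion(ctypes.byref(version))
    if status != 0:
        return None
    return decode_nccl_version(version.value)
```

Passing an absolute path as lib_name makes it possible to compare several NCCL builds that coexist in one environment.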