jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

NCCL and compatibility with AWS EFA #14534

Open mathemakitten opened 1 year ago

mathemakitten commented 1 year ago

Hi! I wanted to surface a quirk of running jax on an AWS GPU cluster in case it's helpful: jax's vendored NCCL doesn't play well with AWS Libfabric EFA because of the way jax starts processes (issue here). I know this is really an AWS-side issue, but it might be worth documenting on the jax side as well, since it stalled me for two days and the note could help others debug similar problems. It's supposedly fixed in the newest version of EFA, but I still had to set FI_EFA_FORK_SAFE=1 to get anything working for multinode, multiprocess jax.
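For anyone landing here from a search, the workaround described above can be sketched roughly as follows. This is a minimal sketch, not an official recipe: the helper name `configure_efa_env` is hypothetical, and it assumes that setting the variable in the Python process before jax is imported is early enough. Exporting `FI_EFA_FORK_SAFE=1` in the job launcher or shell is the more conservative option.

```python
import os


def configure_efa_env(env=None):
    """Set FI_EFA_FORK_SAFE=1 unless a value was already chosen.

    `configure_efa_env` is a hypothetical helper, not part of jax.
    """
    env = os.environ if env is None else env
    # setdefault keeps any value the launcher or user exported explicitly.
    env.setdefault("FI_EFA_FORK_SAFE", "1")
    return env


# Assumption: this runs before jax (and therefore NCCL/Libfabric) is loaded.
configure_efa_env()

# Only after the environment is set:
# import jax
# jax.distributed.initialize()
```

Using `setdefault` rather than plain assignment means a value set by the cluster's launch script still wins.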

Relatedly, it'd be super useful to have a little utility function that reports which version of NCCL jax is running with, or maybe there's already a way to do this and I'm just missing it. I'm working in envs with multiple versions of NCCL floating around, and I considered building jaxlib against a different version at some point, so it would be useful!
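In the absence of a built-in helper, one way to approximate this is to ask the NCCL shared library directly through `ctypes`. The sketch below is an assumption-laden workaround, not a jax API: `nccl_version` and `decode_nccl_version` are hypothetical names, and the library name `libnccl.so.2` is assumed. It reports the NCCL that the dynamic loader resolves for this process, which is not guaranteed to be the same copy jaxlib uses.

```python
import ctypes


def decode_nccl_version(code):
    """Split NCCL's integer version code into (major, minor, patch).

    NCCL encodes versions up to 2.8 as major*1000 + minor*100 + patch,
    and later versions as major*10000 + minor*100 + patch.
    """
    if code < 10000:
        return code // 1000, (code % 1000) // 100, code % 100
    return code // 10000, (code % 10000) // 100, code % 100


def nccl_version(libname="libnccl.so.2"):
    """Return (major, minor, patch) for the NCCL found by the loader, or None.

    Hypothetical helper; `libname` is an assumed default.
    """
    try:
        lib = ctypes.CDLL(libname)
        get_version = lib.ncclGetVersion
    except (OSError, AttributeError):
        return None  # library not found, or symbol missing
    get_version.argtypes = [ctypes.POINTER(ctypes.c_int)]
    get_version.restype = ctypes.c_int
    code = ctypes.c_int(0)
    # ncclGetVersion returns ncclSuccess (0) on success.
    if get_version(ctypes.byref(code)) != 0:
        return None
    return decode_nccl_version(code.value)


print(nccl_version())
```

Passing a full path as `libname` makes it possible to compare several NCCL builds that coexist in one environment.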

Feel free to close this if you decide against these suggestions. I thought it'd be helpful to document it for searchability in case others are running jax specifically on AWS clusters.

roywei commented 1 year ago

Hi @mathemakitten, I stumbled on this issue while researching JAX. The latest EFA NCCL plugin (https://github.com/aws/aws-ofi-nccl) supports multiple NCCL versions with a single build; if you check out the installation guide, it no longer requires a fixed NCCL version. It also tries to set FI_EFA_FORK_SAFE and NCCL_PROTO here. Unfortunately, you still have to set the environment variable FI_EFA_USE_DEVICE_RDMA manually. More documentation on the env vars. Hope this helps!

hawkinsp commented 9 months ago

An update here: JAX no longer vendors a copy of NCCL. If you use the pip installation of JAX, we'll install NCCL via a pip package (nvidia-nccl-cu12). If the pip package isn't installed, then we'll use whatever NCCL library is in your LD_LIBRARY_PATH.

This means that it's easy to use an AWS-specific NCCL, if that's something you want to do: just make sure to remove the pip package version of NCCL and install the AWS version.
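To see which of the two sources described above is likely to apply in a given environment, a quick check along these lines may help. It is a rough sketch under stated assumptions: the helper names are hypothetical, the `libnccl.so*` file pattern is assumed, and the script only inspects the environment rather than asking jax which library it actually loaded.

```python
import glob
import os
from importlib import metadata


def pip_nccl_version(dist="nvidia-nccl-cu12"):
    """Return the installed version of the NCCL pip package, or None."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def nccl_on_library_path(path=None):
    """List libnccl.so* files found in the LD_LIBRARY_PATH directories."""
    if path is None:
        path = os.environ.get("LD_LIBRARY_PATH", "")
    hits = []
    for directory in path.split(os.pathsep):
        if directory:  # skip empty entries such as a trailing ':'
            hits.extend(sorted(glob.glob(os.path.join(directory, "libnccl.so*"))))
    return hits


print("pip NCCL package:", pip_nccl_version())
print("NCCL on LD_LIBRARY_PATH:", nccl_on_library_path())
```

If the first line prints a version, the pip package is still installed and would need to be removed before an AWS-specific NCCL on `LD_LIBRARY_PATH` is picked up.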

(We should document this.)