Description

Multi-node use cases for comm+GEMM overlap in NeMo started hanging at initialize_ub() after PR #901 was merged.
The root cause is an assumption in PR #901 that the user-provided tensor-parallel group is always equivalent to the intra-node MPI communicator used by the old MPI-based bootstrapping. Because of this assumption, the new torch.distributed-based bootstrapping could initialize Userbuffers with local rank/size information that does not span the entire physical node, causing the CUDA Multicast shareable handle exchange (over Unix domain sockets) to silently fail and hang.
This PR re-implements the equivalent of the old MPI-based bootstrapping logic inside initialize_ub() via torch.distributed collectives. The initialize_ub() API reverts to the old interface where the user provides only a tp_size instead of the TP group. The intra-node process group is constructed internally by matching hostnames across ranks on the same physical node.
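To make the new bootstrapping flow concrete, here is a minimal sketch (not the actual TransformerEngine implementation) of how an intra-node group can be built from hostname matching with torch.distributed collectives. The function name and return values are illustrative, and it assumes the default process group is already initialized:

```python
# Illustrative sketch only: build an intra-node process group by matching
# hostnames across ranks, assuming torch.distributed is already initialized.
import socket
import torch.distributed as dist

def intra_node_group_from_hostnames():
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Every rank shares its hostname so all ranks see the full node layout.
    hostnames = [None] * world_size
    dist.all_gather_object(hostnames, socket.gethostname())

    # new_group() is collective, so every rank creates every node's group in
    # the same (sorted) order and keeps only the one it belongs to.
    my_group, my_node_ranks = None, None
    for node in sorted(set(hostnames)):
        node_ranks = [r for r, name in enumerate(hostnames) if name == node]
        group = dist.new_group(ranks=node_ranks)
        if rank in node_ranks:
            my_group, my_node_ranks = group, node_ranks

    # Local rank/size now span the entire physical node, as Userbuffers expects.
    return my_group, my_node_ranks.index(rank), len(my_node_ranks)
```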
+@erhoo82 +@jbaczek for viz.
Type of change
[ ] Documentation change (change only to the documentation, either a fix or new content)
[x] Bug fix (non-breaking change which fixes an issue)
[ ] New feature (non-breaking change which adds functionality)
[ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
[ ] Infra/Build change
[ ] Code refactor
Changes
Please list the changes introduced in this PR:
initialize_ub() API reverted to the pre-PR #901 version that takes tp_size instead of tp_group (see the usage sketch after this list).
The intra-node process group is now constructed inside initialize_ub() based on hostname matching across global ranks.
NCCL naming convention removed from userbuffers/ipcsocket.<h/cpp> (this is not NCCL code and is not dependent on NCCL).
Improved error messages when communicating CUDA Multicast handles over Unix domain sockets -- no more silent failures/hangs.
examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py updated to support mixed data- and tensor-parallelism with model replication over either multiple TP groups on a single node, or one TP group per physical node.
Added README file to the comm+GEMM overlap example with information on hardware/software prerequisites.
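For illustration, a hedged usage sketch of the reverted initialize_ub() interface is shown below. Keyword arguments other than tp_size (the buffer shape, dtype, and FP8 flag) are assumptions based on earlier releases and may differ from the actual signature:

```python
# Hedged usage sketch of the reverted initialize_ub() interface; keyword
# arguments other than tp_size are assumptions based on earlier releases.
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

seq_len, batch_size, hidden_size = 2048, 2, 12288

# Only the tensor-parallel size is passed; the intra-node process group needed
# to bootstrap Userbuffers is now constructed internally via hostname matching.
te.initialize_ub(
    shape=[seq_len * batch_size, hidden_size],
    tp_size=8,
    use_fp8=False,
    dtype=torch.bfloat16,
)
```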