umiswing opened 4 months ago

I want to use TE's comm-gemm-overlap module for multi-node training, but the README says this module only supports a single node. Does TE have any plan for multi-node support? What effort would it take to customize this module for multi-node computation? Are there any resources I can refer to?
Hi @umiswing -- to clarify, the comm+GEMM overlap is currently only possible with single-node tensor-parallelism, but it does support multi-node data-parallelism. For reference, examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py initializes te.LayerNormMLP with a single-node tp_group, but replicates the model across different nodes via torch.nn.DistributedDataParallel. The example also supports replicating the model within a single node (i.e. the TP size does not need to be equal to the node size).
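As a rough illustration of that layout (not the actual example script), a sketch of the process-group setup is below. The te.initialize_ub signature, the te.LayerNormMLP ub_* flags, and the problem sizes are assumptions that vary across TE versions, so treat them as placeholders and check ln_mlp_with_overlap.py for the exact arguments.

```python
# Sketch: single-node tensor-parallelism + multi-node data-parallelism.
# Placeholder kwargs/sizes -- verify against ln_mlp_with_overlap.py for your TE version.
import os
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
world_rank, world_size = dist.get_rank(), dist.get_world_size()
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Tensor-parallel group spans only the GPUs inside one node.
tp_size = torch.cuda.device_count()
tp_group, _ = dist.new_subgroups(group_size=tp_size)

# Data-parallel group: ranks that hold the same TP shard on different nodes.
dp_group = None
for tp_rank in range(tp_size):
    ranks = list(range(tp_rank, world_size, tp_size))
    group = dist.new_group(ranks=ranks)
    if world_rank in ranks:
        dp_group = group

# Userbuffers workspaces are allocated once, before building overlapped layers.
seq_len, batch_size, hidden_size = 2048, 2, 12288  # placeholder sizes
te.initialize_ub([seq_len * batch_size, hidden_size], tp_size, dtype=torch.bfloat16)

model = te.LayerNormMLP(
    hidden_size, 4 * hidden_size,
    tp_group=tp_group, tp_size=tp_size,
    set_parallel_mode=True, sequence_parallel=True,
    ub_overlap_ag=True, ub_overlap_rs=True,  # overlap flags differ by TE version
).cuda()

# Replicate the TP model across nodes with DDP over the data-parallel group.
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[local_rank], process_group=dp_group
)
```

Launched with torchrun across several nodes, each node forms its own TP group while DDP averages gradients across nodes.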
If you're asking about multi-node tensor-parallelism, there is currently no way to do this in Transformer Engine. The existing communication backend for comm+GEMM overlap (Userbuffers) relies on CUDA Multicast handles that are only supported over a single-node NVLink domain of max 8 GPUs. There is an ongoing effort to add multi-node NVLink support to this backend (#815), but I do not have a timeline for when this will be ready.
Separately from this, we are also developing an alternative NVSHMEM-based communication backend that I hope to merge into TE/main within the next 2-3 months. This is ultimately meant to replace Userbuffers as the default backend and make comm+GEMM overlap possible without NVLink (though NVLink will always yield better performance). Since NVSHMEM itself already supports multi-node peer-to-peer comms, the new backend should also support multi-node tensor-parallelism out of the box.
Thanks @denera. I have another question: what are the pros of using TE's communication backend over NCCL? Is it because Userbuffers provides finer-grained options for the communication (e.g. explicitly setting the comm SM count and whether to use the copy engine, which is important for overlapping GEMM and comm) than NCCL does? I actually can't find corresponding options in NCCL.
This primarily has to do with the communication launch overhead in NCCL. Fine-grained overlap at the problem sizes we deal with in Transformer Engine requires frequent movement of small chunks of data, so the overhead in launching each one of these communication kernels adds up to a significant cost that erases the benefit we get from the overlap.
Userbuffers avoids this overhead by implementing the communication via CUDA shared memory access across GPUs in the same NVLink domain, but of course this comes with the limitation of restricting the data distribution to a single node of max 8 GPUs (max 32 GPUs when we merge multi-node NVLink support later this year).
@umiswing Another reason why TE uses Userbuffers for comm overlap is that UB supports strided reads/writes and atomic counters.
PS: To set the number of SMs used for communication in NCCL, you can set NCCL_MAX_CTAS/NCCL_MIN_CTAS (available since v2.17) or NCCL_MIN_NCHANNELS/NCCL_MAX_NCHANNELS.
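For reference, a minimal sketch of capping NCCL's SM usage through those environment variables; the values are arbitrary examples and must be set before NCCL is initialized.

```python
# Bound the number of CTAs (thread blocks, each occupying an SM) NCCL may use.
# Values are arbitrary examples; they must be set before the first NCCL call.
import os

os.environ["NCCL_MIN_CTAS"] = "2"    # NCCL >= 2.17
os.environ["NCCL_MAX_CTAS"] = "8"
# On older NCCL versions, bound the channel count instead.
os.environ["NCCL_MIN_NCHANNELS"] = "2"
os.environ["NCCL_MAX_NCHANNELS"] = "8"

import torch.distributed as dist
dist.init_process_group(backend="nccl")  # env vars take effect at NCCL init
```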
Thanks @denera @yaox12. How does Userbuffers affect GEMM performance? I know NCCL can slow down the GEMM when overlapping. I also see options in Userbuffers' source code to set the number of comm SMs and whether to use the copy engine -- are these for better overlap performance? Do I need to tune the comm-gemm-overlap options for better performance?
@umiswing The effect on GEMM performance depends on the particular overlap algorithm and its configuration.
For example, layers that are overlapped with a ring-exchange method should not impact GEMM performance, because they rely exclusively on point-to-point send/recv comms that can be executed by the Copy Engine without taking any SMs away from GEMM compute.
However, this is not always the most performant choice for every GEMM layer in every model. Some overlaps are more performant with a pipeline method that requires reserving some SMs for the comm kernels, and therefore inevitably impacts the performance of the GEMM kernel. Furthermore, the optimal number of SMs to reserve for these layers is not necessarily constant across layers.
The default configuration for all overlap layers in TE is informed partly by our past MLPerf submissions, but it still requires tuning for each specific application and the hardware you train on. It's unlikely that the defaults will be optimal for you out of the box.
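To make that concrete, here is a hedged sketch of overriding per-layer overlap settings through a ub_cfgs dictionary passed to initialize_ub. The layer keys and fields shown ("fc1_fprop", "method", "num_sm", "num_splits", "set_sm_margin") are assumptions about the Userbuffers config convention and may differ across TE versions, so verify them against your install.

```python
# Hedged sketch: override the overlap method / SM budget per GEMM layer.
# Layer keys and field names are assumptions -- check your TE version.
import torch
import transformer_engine.pytorch as te

ub_cfgs = {
    # Point-to-point ring exchange: runs on the Copy Engine, no SMs reserved.
    "fc1_fprop": {"method": "ring_exchange", "num_splits": 4},
    # Pipelined overlap: reserves SMs for the comm kernels, trading some
    # GEMM throughput for overlap.
    "fc2_fprop": {"method": "pipeline", "num_sm": 16, "num_splits": 4,
                  "set_sm_margin": True},
}

seq_len, batch_size, hidden_size, tp_size = 2048, 2, 12288, 8  # placeholders
te.initialize_ub(
    [seq_len * batch_size, hidden_size], tp_size,
    dtype=torch.bfloat16, ub_cfgs=ub_cfgs,
)
```

Sweeping the method choice and SM count per layer on your own model and hardware is typically how you find the best configuration.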