NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

Support for overlapping tensor-parallel collectives with matmuls in fprop? #737

Open cbcase opened 6 months ago

cbcase commented 6 months ago

Hi all -- we use TE for efficient fp8 inference, and I am looking at ways to reduce the communication overhead for multi-GPU tensor-parallel models. I have done my own experiments with chunking up matmuls and doing a pipelined matmul-allreduce, but the interference overhead is too high for that to yield any speedup. (This is even with CUDA graphs and careful tuning of NCCL's CTA count to get good CTA rasterization.)
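For concreteness, here is a minimal sketch of the kind of chunked, pipelined matmul + all-reduce I mean -- my own illustration, not the actual experiment. It assumes torch.distributed is initialized with the NCCL backend and that the default process group is the tensor-parallel group; the function name and chunk count are made up for the example.

```python
import torch
import torch.distributed as dist

def pipelined_matmul_allreduce(x: torch.Tensor, w: torch.Tensor,
                               num_chunks: int = 4) -> torch.Tensor:
    """Row-parallel linear: `w` is this rank's weight shard, so each chunk's
    partial output must be all-reduced across the tensor-parallel group."""
    outs, handles = [], []
    for xc in x.chunk(num_chunks, dim=0):
        yc = xc @ w                                         # GEMM for this chunk (partial sum on this rank)
        handles.append(dist.all_reduce(yc, async_op=True))  # reduce it on NCCL's internal stream
        outs.append(yc)                                     # next chunk's GEMM overlaps the reduce
    for h in handles:
        h.wait()          # main stream waits for each chunk's all-reduce to complete
    return torch.cat(outs, dim=0)
```

The overlap comes from chunk i's all-reduce running on NCCL's stream while chunk i+1's GEMM runs on the main stream; the NCCL kernels still compete with the GEMM for SMs, which is where the interference overhead above comes from.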

I see evidence in the code that the userbuffers library is used to implement a (presumably) much more efficient pipelined matmul+[collective], but it is wildly unclear to me whether this is supported, under development, etc. -- or how to even use what's there. Is there guidance / documentation on what you guys are supporting / planning to support for efficient pipelined overlap, and how to use it?

(Fwiw, as one data point, the bprop case is much less interesting, since the relative fraction of iteration time spent in the tensor-parallel collectives is so much lower.)

Thanks!

ptrendx commented 5 months ago

Hi Carl :-).

We are experimenting with multiple ways of doing the overlap: pipelined GEMM + communication similar to your experiments (using both NCCL and the userbuffers library), as well as an experimental cuBLAS feature where the GEMM kernel communicates the completion of tiles by flipping an atomic switch.
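(Editorial aside: a toy, CPU-only sketch of that last handshake -- purely an illustration of the idea, not cuBLAS or TE code. The GEMM "producer" bumps an atomic counter as output tiles finish, and the communication "consumer" ships each chunk of tiles as soon as it is covered, instead of waiting for the whole GEMM.)

```python
import threading
import time

TILES, TILES_PER_CHUNK = 16, 4   # illustrative sizes
done_tiles = 0
lock = threading.Lock()

def gemm_producer():
    global done_tiles
    for _ in range(TILES):
        time.sleep(0.01)          # stand-in for computing one output tile
        with lock:
            done_tiles += 1       # "flip the atomic switch" for this tile

def comm_consumer():
    sent = 0
    while sent < TILES:
        with lock:
            ready = done_tiles
        while sent + TILES_PER_CHUNK <= ready:
            print(f"communicating tiles {sent}..{sent + TILES_PER_CHUNK - 1}")
            sent += TILES_PER_CHUNK   # these tiles ship while later ones are still computing
        time.sleep(0.001)             # poll, as the communication side would on-device

producer = threading.Thread(target=gemm_producer)
producer.start()
comm_consumer()
producer.join()
```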

We are working on better exposure (and documentation) of those features, beyond what is currently exposed in Megatron Core/NeMo. One of the first steps in this direction is #760, which will enable more seamless usage of the UB library (currently it requires compiling TE with a special flag); that should provide lower overheads and possibly make the overlap viable for your use case.

cbcase commented 5 months ago

Hey! Thanks, that's all encouraging to hear. At a glance, the idea of having the GEMM kernel itself do the communication in the epilogue is very attractive, since it saves a lot on memory bandwidth, and in the common case of two-way model parallelism the switch reductions don't buy you anything anyway. Is that exposed, or still internal-only?
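(To make the memory-bandwidth point concrete, a crude back-of-the-envelope accounting -- my own numbers and assumptions, not measurements. A separate GEMM-then-all-reduce writes the partial output to HBM, then the collective reads it back and writes the reduced result; an epilogue that pushes finished tiles straight into the communication path skips that extra round trip.)

```python
# Rough HBM-traffic accounting for one row-parallel layer with 2-way TP.
# Sizes are illustrative; NVLink traffic and NCCL's internal buffer copies
# are ignored, so treat the ratio as approximate only.
tokens, hidden, bytes_per_elem = 4096, 8192, 2        # bf16 partial outputs (assumed)
out_bytes = tokens * hidden * bytes_per_elem          # partial output per rank

# Separate GEMM, then all-reduce over the full partial output:
separate = (
    out_bytes     # GEMM epilogue writes the partial output to HBM
    + out_bytes   # collective reads the local partial back from HBM
    + out_bytes   # collective writes the reduced output to HBM
)

# GEMM epilogue that communicates tiles as they complete: the intermediate
# partial never makes the extra round trip through HBM.
fused = out_bytes  # single write of the output

print(f"separate ~{separate / 2**20:.0f} MiB vs fused ~{fused / 2**20:.0f} MiB of HBM traffic per layer")
```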

Regardless, I'll keep watching this space. Happy to talk with anyone on the team directly, too, if you guys need any external test cases. (DM for email address, or PiotrB can find me in the PyTorch slack.)