NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] FP8 AllToAll #854

Closed · yaox12 closed 1 month ago

yaox12 commented 1 month ago

Description

Add an FP8AllToAll layer, which performs cast_to_fp8 -> all_to_all in FP8 -> cast_from_fp8. We see about a 5% end-to-end performance gain in Mixtral 8x7B and 8x22B training with parallelism configs where the all-to-all happens over inter-node connections such as IB/RoCE. A minimal sketch of this pattern is shown below.
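For illustration only, here is a minimal sketch of the cast_to_fp8 -> all_to_all -> cast_from_fp8 pattern, written against plain PyTorch rather than Transformer Engine's internals. It assumes PyTorch >= 2.1 (for `torch.float8_e4m3fn`), an initialized NCCL process group, equal splits along dim 0, and a single per-tensor scale synchronized via an amax all-reduce; the function name `fp8_all_to_all` is hypothetical, not a Transformer Engine API.

```python
import torch
import torch.distributed as dist


def fp8_all_to_all(x: torch.Tensor, group=None) -> torch.Tensor:
    """Cast to FP8, exchange with all_to_all, cast back (illustrative sketch)."""
    # 1) cast_to_fp8: derive a shared per-tensor scale from the global amax
    #    so every rank quantizes and dequantizes consistently.
    fp8_max = 448.0                                   # largest finite E4M3 value
    amax = x.abs().amax().reshape(1).clamp(min=1e-12)
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    scale = fp8_max / amax
    x_fp8 = (x * scale).to(torch.float8_e4m3fn)

    # 2) all_to_all in FP8: communicate raw bytes, halving the traffic
    #    relative to BF16/FP16 (assumes x.size(0) divides evenly by world size).
    out_fp8 = torch.empty_like(x_fp8)
    dist.all_to_all_single(out_fp8.view(torch.uint8),
                           x_fp8.view(torch.uint8),
                           group=group)

    # 3) cast_from_fp8: dequantize back to the input precision.
    return out_fp8.to(x.dtype) / scale
```

In Transformer Engine the casts and scale management go through its FP8 recipe and amax-history machinery; the sketch only captures the communication pattern that halves all-to-all traffic on inter-node links.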


yaox12 commented 1 month ago

> The implementation looks reasonable, although this functionality seems somewhat Mixtral-specific. Perhaps it would be better to have this class live inside NeMo?

This implementation is used to improve the performance of AllToAllTokenDispatcher in MCore, which is used not only in Mixtral but also in other MoE models, such as GPT-MoE, DBRX, Grok, and so on. But your suggestion is reasonable. It should be fine for it to live in MCore. I will discuss it with the other folks working on MoE development.