NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

Need all-reduce for norm weight gradients with sequence parallel #435

Open jspark1105 opened 1 year ago

jspark1105 commented 1 year ago

https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_linear.py#L461-L471

When we use sequence parallelism, do we need to all-reduce the norm weight gradients across the TP group after the code above?
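Why the all-reduce is needed can be checked numerically. With sequence parallelism the norm weight is replicated on every TP rank, but each rank only sees its own slice of the sequence. The weight gradient is a sum over tokens, so each rank ends up holding only a partial sum. The sketch below simulates this on a single process; `tp_size`, `seq_len` and the `ln_weight_grad` helper are hypothetical, and plain `torch.nn.functional.layer_norm` stands in for TE's fused LayerNorm.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = 8
tp_size = 2   # hypothetical tensor-parallel group size
seq_len = 6   # divisible by tp_size, as sequence parallelism requires

x = torch.randn(seq_len, hidden)
upstream = torch.randn(seq_len, hidden)  # stand-in for dL/d(norm output)


def ln_weight_grad(x_part: torch.Tensor, up_part: torch.Tensor) -> torch.Tensor:
    """Gradient of a (replicated) LayerNorm weight for the tokens one rank sees."""
    weight = torch.ones(hidden, requires_grad=True)
    out = F.layer_norm(x_part, (hidden,), weight, None)
    (out * up_part).sum().backward()
    return weight.grad


# Gradient a single device would compute over the whole sequence.
full_grad = ln_weight_grad(x, upstream)

# Sequence parallel: each simulated rank only sees seq_len / tp_size tokens.
shard_grads = [
    ln_weight_grad(xs, us)
    for xs, us in zip(x.chunk(tp_size, dim=0), upstream.chunk(tp_size, dim=0))
]

# A single rank's gradient is only a partial sum over its own tokens ...
print(torch.allclose(shard_grads[0], full_grad, atol=1e-6))
# ... while summing across ranks (what an all-reduce does) recovers the full gradient.
print(torch.allclose(torch.stack(shard_grads).sum(dim=0), full_grad, atol=1e-6))
```

LayerNorm normalizes each token independently, so splitting the sequence does not change the per-token terms, only which rank accumulates them.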

timmoon10 commented 1 year ago

Sorry for the late reply. Yes, that's correct. Currently we expect that this all-reduce happens outside of TE, which allows us to coalesce multiple all-reduces into a single NCCL call.

Megatron-LM: https://github.com/NVIDIA/Megatron-LM/blob/52f13005148afa47a6f37b082083fa2c6675ae3e/megatron/optimizer/optimizer.py#L244

NeMo: https://github.com/NVIDIA/NeMo/blob/a9fb58bcee7bff0e50a621d05ec3a9b5eb5f584c/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L696
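For reference, here is a minimal sketch of the kind of out-of-TE all-reduce described above, loosely in the spirit of the Megatron-LM code linked above. It is not TE or Megatron-LM API. Assumptions: the norm parameters are tagged with a `sequence_parallel = True` attribute (a Megatron-LM convention; check how your TE version marks them), all selected gradients share one dtype, and `tp_group` is the tensor-parallel process group you created. The function name is hypothetical, and `torch._utils._flatten_dense_tensors` / `_unflatten_dense_tensors` are private PyTorch helpers.

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def allreduce_sequence_parallel_grads(model: torch.nn.Module, tp_group=None) -> None:
    """Hypothetical helper: sum grads of sequence-parallel params across the TP group.

    Assumes such params carry `sequence_parallel = True` and that all their
    grads share one dtype (group by dtype otherwise).
    """
    grads = [
        p.grad
        for p in model.parameters()
        if p.requires_grad
        and p.grad is not None
        and getattr(p, "sequence_parallel", False)
    ]
    if not grads:
        return
    # Pack every gradient into one contiguous buffer so a single collective
    # (one NCCL call on GPU) replaces one all-reduce per tensor.
    coalesced = _flatten_dense_tensors(grads)
    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
    # Copy the reduced values back into the original .grad tensors.
    for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        grad.copy_(synced)
```

Under these assumptions it would be called once per step, after `loss.backward()` and before `optimizer.step()`. Coalescing is the point of doing this outside TE: the caller sees every norm weight in the model at once and can reduce them all in one call.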

SuhitK commented 4 months ago

Hi @timmoon10, is this still valid? Do we still need to handle the all-reduce outside of TE?