A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
te.Linear currently supports tensor-parallel (TP) communication overlap only with parallel_mode="row", where it overlaps the reduce-scatter in the forward pass and the all-gather with dgrad in the backward pass.
This PR adds new options that enable the mirrored pattern when parallel_mode="column": all-gather overlap in the forward pass, and reduce-scatter overlap with dgrad in the backward pass.
Fixes #1312
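For reviewers less familiar with the column-parallel data flow, here is a minimal single-process sketch of the two collectives involved. It assumes sequence-parallel activations (input rows sharded across ranks) and a weight split by output columns. All helper names (`all_gather_rows`, `reduce_scatter_rows`, `column_parallel_forward`, `column_parallel_dgrad`) are hypothetical and are not Transformer Engine APIs. The sketch only checks that the communication pattern reproduces the unsharded result; it does not model the overlap itself, which is about hiding these collectives behind the GEMMs.

```python
# Single-process simulation of the collectives in a column-parallel linear layer
# with sequence-parallel activations. Illustrative only; not Transformer Engine code.

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(r) for r in zip(*m)]

def split_rows(m, n):
    chunk = len(m) // n
    return [m[r * chunk:(r + 1) * chunk] for r in range(n)]

def split_cols(m, n):
    width = len(m[0]) // n
    return [[row[r * width:(r + 1) * width] for row in m] for r in range(n)]

def concat_cols(shards):
    return [sum((s[i] for s in shards), []) for i in range(len(shards[0]))]

def all_gather_rows(shards):
    """Every rank ends up with all row shards concatenated in rank order."""
    return [row for shard in shards for row in shard]

def reduce_scatter_rows(partials, world_size):
    """Sum the per-rank partials elementwise, then hand rank r the r-th row chunk."""
    rows, cols = len(partials[0]), len(partials[0][0])
    total = [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]
    return split_rows(total, world_size)

def column_parallel_forward(x_shards, w_col_shards):
    # Forward: all-gather the sequence-sharded input, then one GEMM per rank.
    x_full = all_gather_rows(x_shards)
    return [matmul(x_full, w) for w in w_col_shards]

def column_parallel_dgrad(dy_shards, w_col_shards):
    # Backward: each rank's dgrad GEMM yields a full-sequence partial sum,
    # which is reduce-scattered back to the sequence-sharded layout.
    partials = [matmul(dy, transpose(w)) for dy, w in zip(dy_shards, w_col_shards)]
    return reduce_scatter_rows(partials, len(w_col_shards))

WORLD_SIZE = 2
X = [[1, 2], [3, 4], [5, 6], [7, 8]]          # (sequence=4, hidden_in=2)
W = [[1, 0, 2, 1], [0, 1, 1, 3]]              # (hidden_in=2, hidden_out=4)
DY = [[1, 0, 0, 1], [0, 1, 1, 0], [2, 0, 1, 0], [0, 2, 0, 1]]

x_shards = split_rows(X, WORLD_SIZE)
w_shards = split_cols(W, WORLD_SIZE)
y_shards = column_parallel_forward(x_shards, w_shards)
dx_shards = column_parallel_dgrad(split_cols(DY, WORLD_SIZE), w_shards)
```

Row-parallel mode uses the same two collectives in the opposite order (reduce-scatter on the forward output, all-gather on the incoming gradient before dgrad), which is why the existing overlap path could not be reused as-is for column-parallel layers.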
Type of change
[ ] Documentation change (change only to the documentation, either a fix or a new content)
[ ] Bug fix (non-breaking change which fixes an issue)
[x] New feature (non-breaking change which adds functionality)
[ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
Checklist: