A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
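Description
In the DiT of PaddleMIX, Linear(parallel_mode='column') is not paired with a Linear(parallel_mode='row'). A column-parallel layer leaves each rank holding only a slice of the output features; a paired row-parallel layer would normally consume that slice directly and combine the partial results with an all-reduce. Without that partner, the sharded output has to be reassembled. To use the Transformer Engine backend in DiT, the output of te.Linear(parallel_mode='column') must therefore be all-gathered in the forward pass and reduce-scattered in the backward pass.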
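To illustrate the forward-pass requirement, here is a minimal single-process sketch. It assumes the ranks are simulated as list entries instead of real GPUs, stores the weight as out_features x in_features rows for readability, and uses hypothetical helpers (matvec, split_output_features, all_gather, column_parallel_forward) that are not part of the Transformer Engine or Paddle API. It checks that gathering the per-rank output slices reproduces the output of the unsharded layer.

```python
# Single-process simulation: each list entry stands in for one tensor-parallel rank.
# All helper names are hypothetical; this is not the Transformer Engine or Paddle API.


def matvec(w, x):
    """Return w @ x, with w stored as a list of output-feature rows."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def split_output_features(w, world_size):
    """Column-parallel split: rank r owns a contiguous block of output features."""
    n = len(w) // world_size
    return [w[r * n:(r + 1) * n] for r in range(world_size)]


def all_gather(shards):
    """Every rank receives the concatenation of all ranks' shards, in rank order."""
    full = [v for shard in shards for v in shard]
    return [list(full) for _ in shards]


def column_parallel_forward(w, x, world_size):
    # Each rank computes only its own slice of the output features.
    local_out = [matvec(w_r, x) for w_r in split_output_features(w, world_size)]
    # No row-parallel layer follows, so gather the slices into the full output on every rank.
    return all_gather(local_out)


w = [[1, 0], [0, 1], [1, 1], [2, -1]]  # 4 output features, 2 input features
x = [3, 4]
reference = matvec(w, x)
gathered = column_parallel_forward(w, x, world_size=2)
print(reference)  # [3, 4, 7, 2]
print(gathered)   # [[3, 4, 7, 2], [3, 4, 7, 2]]
```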
Fixes # (issue)
Type of change
[ ] Documentation change (change only to the documentation, either a fix or a new content)
[x] Bug fix (non-breaking change which fixes an issue)
[ ] New feature (non-breaking change which adds functionality)
[ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
[ ] Infra/Build change
[ ] Code refactor
Changes
Please list the changes introduced in this PR:
Add an all-gather of the output to the forward pass, and a reduce-scatter to the backward pass, of te.Linear(parallel_mode='column')
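The reason a reduce-scatter belongs in the backward pass is that reduce-scatter is the adjoint of all-gather: for rank-local inputs x_r and upstream gradients g_r, the sum over ranks of <AllGather(x)_r, g_r> equals the sum over ranks of <x_r, ReduceScatter(g)_r>, so the gradient with respect to each rank's local slice is that rank's piece of the reduce-scattered gradient. The sketch below checks this identity numerically. It simulates ranks as list entries in one process, and its helper names (all_gather, reduce_scatter, dot) are hypothetical, not the Transformer Engine or Paddle API.

```python
# Single-process simulation: each list entry stands in for one tensor-parallel rank.
# Helper names are hypothetical; this is not the Transformer Engine or Paddle API.


def all_gather(shards):
    """Forward: every rank receives the concatenation of all shards, in rank order."""
    full = [v for shard in shards for v in shard]
    return [list(full) for _ in shards]


def reduce_scatter(full_grads):
    """Backward: sum the full-length gradients across ranks, then give rank r its slice."""
    world_size = len(full_grads)
    n = len(full_grads[0]) // world_size
    summed = [sum(vals) for vals in zip(*full_grads)]
    return [summed[r * n:(r + 1) * n] for r in range(world_size)]


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


x_shards = [[1, 2], [3, 4]]            # each rank's local output slice
grads = [[1, 0, -1, 2], [0, 5, 1, 1]]  # each rank's gradient w.r.t. the gathered output

# Adjoint identity: <AllGather(x), g> summed over ranks == <x, ReduceScatter(g)> summed over ranks.
lhs = sum(dot(y_r, g_r) for y_r, g_r in zip(all_gather(x_shards), grads))
rhs = sum(dot(x_r, d_r) for x_r, d_r in zip(x_shards, reduce_scatter(grads)))
print(reduce_scatter(grads))  # [[1, 5], [0, 3]]
print(lhs, rhs)               # 23 23
```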
Checklist: