This PR moves Userbuffers and comm+GEMM overlap algorithms from TE/PyTorch to TE/common with refactored interfaces to remove the PyTorch dependency.
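To illustrate the kind of interface change this implies, below is a minimal sketch of the dependency-removal pattern. The names (`TensorView`, `DType`, `CommGemmOverlapBase`) are hypothetical and not the actual Transformer Engine API; the point is that the common-library interface consumes a plain, framework-agnostic descriptor of device memory, so its headers no longer need to include anything from PyTorch.

```cpp
// Illustrative sketch only: these names are NOT the real Transformer Engine
// interface. They show how an overlap class in TE/common can avoid any
// torch::Tensor in its signatures.
#include <cstddef>
#include <vector>

namespace te_common_sketch {

enum class DType { kFloat32, kBFloat16, kFloat8E4M3 };

// Framework-agnostic view of a device buffer. A PyTorch wrapper can fill this
// from tensor.data_ptr() and tensor.sizes(); other frameworks can do the same
// from their own buffer types.
struct TensorView {
  void* data = nullptr;
  std::vector<size_t> shape;
  DType dtype = DType::kBFloat16;
};

// The comm+GEMM overlap base class now only sees TensorView, never torch::Tensor.
class CommGemmOverlapBase {
 public:
  virtual ~CommGemmOverlapBase() = default;
  virtual void bulk_overlap(const TensorView& A, const TensorView& B,
                            TensorView& D) = 0;
};

}  // namespace te_common_sketch
```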
Type of change
[ ] Documentation change (change only to the documentation, either a fix or new content)
[ ] Bug fix (non-breaking change which fixes an issue)
[ ] New feature (non-breaking change which adds functionality)
[ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
[ ] Infra/Build change
[x] Code refactor
Changes
transformer_engine/pytorch/csrc/userbuffers moved to transformer_engine/common/comm_gemm_overlap/userbuffers.
transformer_engine/pytorch/csrc/comm_gemm_overlap.h split into transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h and transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp and refactored to remove torch::Tensor dependency.
Added new TE/PyTorch wrappers around the refactored comm+GEMM overlap algorithms (see the wrapper sketch after this list).
Expanded unit tests to cover all overlap algorithms including atomic GEMM overlaps (tested as AG+RS pairs).
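As a rough sketch of what the framework-side wrapper item above involves (again hypothetical: `to_view` and the types reused from the previous sketch are illustrative, not the real TE/PyTorch binding code), the wrapper's main job is to translate a torch::Tensor into the framework-agnostic descriptor before calling into TE/common:

```cpp
// Hypothetical sketch of a PyTorch-side adapter. TensorView/DType are the
// illustrative types from the previous sketch, not the actual TE interface.
#include <torch/extension.h>

#include "te_common_sketch.h"  // hypothetical header declaring TensorView/DType

namespace te_pytorch_sketch {

// Convert a torch::Tensor into the framework-agnostic view consumed by the
// common-library overlap class; no torch types cross into TE/common.
te_common_sketch::TensorView to_view(const at::Tensor& t) {
  te_common_sketch::TensorView view;
  view.data = t.data_ptr();
  view.shape.assign(t.sizes().begin(), t.sizes().end());
  view.dtype = (t.scalar_type() == at::kBFloat16)
                   ? te_common_sketch::DType::kBFloat16
                   : te_common_sketch::DType::kFloat32;
  return view;
}

}  // namespace te_pytorch_sketch
```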