A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
While profiling CPU overheads in te.Linear with FP8 support enabled, I find that a non-trivial amount of CPU time is spent in cast-transpose functions.
Much of this runtime is spent indexing into FP8 scaling factors. It takes ~5 us on CPU to create each PyTorch tensor, and this can become significant when GEMM kernels take ~100 us on GPU.
This PR makes backward-compatible changes to some of the PyTorch extensions to avoid calling the select op, thus avoiding some tensor creation. It adds optional arguments for tensor offsets so that we can do pointer arithmetic instead of relying on PyTorch. With these changes, I see a 1.05x speedup in the forward pass of very small te.Linear modules.
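The interface change can be sketched as follows. The function and argument names below are illustrative stand-ins, not the actual Transformer Engine extension signatures: the old-style binding forces the Python caller to slice a per-tensor scale out of the global FP8 scale buffer (one select op and tensor creation per call), while the new-style binding accepts the full buffer plus an integer offset, so the indexing reduces to pointer arithmetic on the C++ side.

```python
# Hypothetical sketch of the interface change; names are illustrative,
# not the real Transformer Engine extension signatures.

def cast_transpose(inp, scale):
    """Old-style binding: `scale` is a scalar the caller produced by
    indexing the global scale buffer on the Python side (a select op
    plus a tensor creation, ~5 us of CPU time per call)."""
    return [x * scale for x in inp]

def cast_transpose_with_offset(inp, scale_buf, scale_offset=0):
    """New-style binding: takes the whole scale buffer plus an integer
    offset, so the extension can do the equivalent of
    `scale_ptr + scale_offset` instead of relying on PyTorch to
    materialize a sliced tensor."""
    return [x * scale_buf[scale_offset] for x in inp]

scales = [1.0, 0.5, 2.0]  # stand-in for the FP8 scale buffer
old = cast_transpose([2.0, 4.0], scales[1])                       # Python-side slice
new = cast_transpose_with_offset([2.0, 4.0], scales, scale_offset=1)  # offset arg
assert old == new == [1.0, 2.0]  # same result, one fewer tensor creation
```

Defaulting the offset to 0 is what keeps the change backward compatible: existing callers that pass a pre-sliced scale continue to work unchanged.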
Type of change
[ ] Documentation change (change only to the documentation, either a fix or new content)
[ ] Bug fix (non-breaking change which fixes an issue)
[x] New feature (non-breaking change which adds functionality)
[ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
Changes
Add FP8 scaling factor offset args to cast-transpose PyTorch extensions
Add FP8 scaling factor offset args to LayerNorm and RMSNorm PyTorch extensions