NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] Avoid select op in PyTorch extensions #865

Closed · timmoon10 closed this 4 weeks ago

timmoon10 commented 1 month ago

Description

While profiling CPU overheads in te.Linear with FP8 support enabled, I find that a non-trivial amount of CPU time is spent in cast-transpose functions:

[profiler screenshot: CPU time breakdown showing the cast-transpose functions]

Much of this runtime is spent indexing into FP8 scaling factors. It takes ~5 us of CPU time to create each PyTorch tensor, which becomes significant when the GEMM kernels themselves take only ~100 us on GPU.
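For context, here is a minimal sketch of the costly pattern (hypothetical code, not the actual TransformerEngine extension; function and variable names are illustrative):

```cpp
#include <torch/torch.h>

// Hypothetical illustration: each select() call constructs a new at::Tensor
// object on the CPU, even though it is only a view and launches no GPU work.
// At ~5 us per tensor, this adds up when the surrounding GEMM kernels take
// only ~100 us on the GPU.
void consume_scales_via_select(const at::Tensor& scale_inv, int64_t num_tensors) {
  for (int64_t i = 0; i < num_tensors; ++i) {
    at::Tensor s = scale_inv.select(0, i);  // aten::select -> new tensor per call
    // ... s would be passed to a kernel launcher here ...
    (void)s;
  }
}
```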

This PR makes backward-compatible changes to some of the PyTorch extensions to avoid calling the select op, and thus to avoid some tensor creation. It adds optional arguments for tensor offsets so that we can do pointer arithmetic instead of relying on PyTorch indexing. With these changes, I see a 1.05x speedup in the forward pass of very small te.Linear modules.
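A rough sketch of the offset-based approach described above (the signature and names are illustrative assumptions, not the real extension API): the extension receives the full scaling-factor tensor plus an integer offset, and resolves the address with pointer arithmetic instead of a select.

```cpp
#include <torch/torch.h>

// Hypothetical sketch, not the actual TransformerEngine extension signature.
// The offset defaults to 0, so existing callers that pass a pre-selected
// scalar tensor keep working (backward compatible).
void cast_transpose_with_offset(const at::Tensor& input,
                                const at::Tensor& scale_inv,
                                int64_t scale_offset = 0) {
  // Pointer arithmetic replaces scale_inv.select(0, scale_offset),
  // so no intermediate at::Tensor is created on the hot path.
  const float* scale_ptr = scale_inv.data_ptr<float>() + scale_offset;
  // ... launch the cast-transpose kernel with input and scale_ptr ...
  (void)input;
  (void)scale_ptr;
}
```

On the Python side, the caller would then pass the whole scaling-factor tensor and an integer index rather than indexing with `scale_inv[i]`, which is what triggered the select op.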


timmoon10 commented 1 month ago

/te-ci pytorch
