NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[Feature Request] Integration of DiT components into TransformerEngine. #900

Open · okotaku opened this issue 3 weeks ago

okotaku commented 3 weeks ago

References

Two academic papers are cited to support the request:

- Scalable Diffusion Models with Transformers (DiT): https://arxiv.org/abs/2212.09748
- PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis: https://arxiv.org/abs/2310.00426

Current Support

TransformerEngine does not yet support DiT-style blocks; its existing `TransformerLayer` implements a standard Transformer block, as in the quickstart-style sketch below.
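
For context, here is a minimal sketch of what TransformerEngine exposes today, adapted from the user-guide quickstart; the shapes and hyperparameters are illustrative only:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# A standard Transformer block whose GEMMs can run in FP8 on Hopper/Ada.
# Its LayerNorms always carry learnable affine parameters, and there is
# no hook for timestep-conditioned modulation.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
).to(dtype=torch.float16).cuda()

# Default input layout is [seq, batch, hidden].
x = torch.randn(128, 4, 1024, dtype=torch.float16, device="cuda")

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)
```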

Differences Noted

Two concrete differences from TransformerEngine's current `TransformerLayer` are highlighted: DiT uses LayerNorm with `elementwise_affine=False` (no learnable scale/bias), and each Transformer layer applies a timestep-conditioned (adaLN) scale/shift/gate modulation, as sketched below.
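
For reference, a minimal PyTorch sketch of the DiT block these differences come from, following the public DiT reference implementation (adaLN-Zero). `nn.MultiheadAttention` stands in for DiT's attention module; this is not TransformerEngine API:

```python
import torch
import torch.nn as nn

def modulate(x, shift, scale):
    # Timestep-conditioned shift/scale applied after a parameter-free LayerNorm.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class DiTBlockSketch(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        # Difference 1: LayerNorm without learnable affine parameters.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, int(hidden_size * mlp_ratio)),
            nn.GELU(approximate="tanh"),
            nn.Linear(int(hidden_size * mlp_ratio), hidden_size),
        )
        # Difference 2: six modulation tensors (shift/scale/gate for attention
        # and MLP) regressed from the timestep/class embedding c.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size),
        )

    def forward(self, x, c):
        # x: [batch, tokens, hidden]; c: [batch, hidden] conditioning embedding.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(c).chunk(6, dim=-1)
        h = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.attn(h, h, h, need_weights=False)[0]
        h = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(h)
        return x
```

Supporting this in TransformerEngine would presumably require a LayerNorm variant without affine parameters and a way to inject the per-layer shift/scale/gate tensors into the fused layer.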
