NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

How about torch.compile in TransformerEngine? #1241

Open south-ocean opened 1 month ago

south-ocean commented 1 month ago

In PyTorch, we know that torch.compile can bring significant speedups, and TransformerEngine also improves performance through strategies such as Transformer fusion optimizations. Does Transformer Engine support torch.compile as well? Is there any documentation on whether using torch.compile in TE mode yields better gains than in non-TE mode? Do you have any suggestions for using torch.compile with TransformerEngine?

In llama2, we found that torch.compile gives good gains on rmsnorm and swiglu, but in TE it is not possible to apply torch.compile directly to rmsnorm and swiglu. Is there a good way to do this?
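For reference, here is a minimal sketch of the non-TE pattern being described: plain-PyTorch RMSNorm and SwiGLU modules wrapped in torch.compile so Inductor can fuse their elementwise ops. The module definitions are illustrative stand-ins, not the actual llama2 code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # normalize by root-mean-square, then apply the learned scale
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

class SwiGLU(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_up = nn.Linear(dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        # silu(gate(x)) * up(x), projected back down
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

device, dtype = "cuda", torch.bfloat16
norm = torch.compile(RMSNorm(4096).to(device, dtype))
mlp = torch.compile(SwiGLU(4096, 11008).to(device, dtype))
x = torch.randn(2, 128, 4096, device=device, dtype=dtype)
y = mlp(norm(x))  # Inductor fuses the elementwise ops inside each compiled module
```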

timmoon10 commented 1 month ago

We have used torch.compile to fuse some operations like bias+GeLU in LayerNormMLP (see bias_gelu_fused_). However, we have not yet done serious work applying torch.compile to FP8 kernels since we're not sure how well they can accommodate the extra logic for FP8 scaling factors and absmax reductions. It's something we've kept in mind though, especially as a means to work around the CPU overheads of running PyTorch in eager mode.
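As a rough illustration of that bias+GeLU pattern, a standalone torch.compile'd helper might look like the sketch below. This is illustrative only; see transformer_engine/pytorch/jit.py for TE's actual bias_gelu_fused_ implementation.

```python
import torch

@torch.compile
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # tanh-approximation GeLU applied to (inp + bias); Inductor can emit a
    # single elementwise kernel for the whole expression.
    x = inp + bias
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

x = torch.randn(8, 1024, 4096)
b = torch.randn(4096)
y = bias_gelu_fused(x, b)
```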

For the moment we manually identify fusion opportunities and incorporate them into our modules, e.g. LayerNormLinear might call a LayerNorm kernel that outputs in FP8. For more flexibility, we are experimenting with a modular operation-based API that can automatically identify some of these fusion opportunities. I believe Lightning Thunder has also been working on automatic kernel fusion with TE.
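A minimal usage sketch of that manual-fusion approach: te.LayerNormLinear replaces a separate LayerNorm + Linear pair, so the normalization output can feed the following GEMM directly (in FP8 under fp8_autocast on supported GPUs). The arguments shown are minimal and may differ slightly between TE versions.

```python
import torch
import transformer_engine.pytorch as te

# Fused layernorm + linear in one TE module.
fused = te.LayerNormLinear(4096, 4096, eps=1e-5).cuda()
x = torch.randn(128, 4096, device="cuda")

# FP8 execution requires an FP8-capable GPU (Hopper/Ada).
with te.fp8_autocast(enabled=True):
    y = fused(x)
```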

south-ocean commented 1 month ago

Yeah, but right now I am running in bfloat16. By adding torch.compile to rmsnorm and swiglu for llama2-7b in legacy (non-TE) mode, I get more benefit than with TE: legacy mode with torch.compile brings about a 10% performance improvement over TE. So can I combine the benefits of TE and torch.compile? I think that could be even faster.

MaciejBalaNV commented 1 month ago

@timmoon10 Please also consider the following: it's quite popular to apply torch.compile to the entire model and fuse all inefficient tensor operations. Moving all such operations into small compilable functions, so that we don't have to compile the entire model, is often impractical and hard to maintain. However, TE modules currently do not work with torch.compile, and a graph break is introduced at every TE module call. In most cases this nullifies any possible performance benefit from torch.compile. Registering TE modules and custom kernels with torch.compile so that they do not introduce graph breaks would be a huge improvement. The FlashAttention repository did this recently: https://github.com/Dao-AILab/flash-attention/pull/1139
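For context, the general registration technique being referred to looks roughly like the sketch below: an opaque kernel is exposed as a torch.library custom op (PyTorch ≥ 2.4) with a fake/meta implementation, so Dynamo can trace through it instead of falling back to a graph break. The op name "te_ext::rmsnorm_fwd" and the kernel body are hypothetical placeholders, not TE or flash-attention code.

```python
import torch
from torch.library import custom_op

@custom_op("te_ext::rmsnorm_fwd", mutates_args=())
def rmsnorm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Stand-in for a call into an opaque compiled kernel.
    rms = x.pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()
    return x * rms * weight

@rmsnorm_fwd.register_fake
def _(x, weight, eps):
    # Shape/dtype-only implementation lets Dynamo/Inductor trace the op.
    return torch.empty_like(x)

@torch.compile(fullgraph=True)  # fullgraph=True would fail if the op broke the graph
def block(x, w):
    return torch.nn.functional.silu(rmsnorm_fwd(x, w, 1e-6))

y = block(torch.randn(4, 1024), torch.ones(1024))
```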

south-ocean commented 1 month ago

@timmoon10 Yeah, I have now found that for llama2-7b the performance of TE does not exceed that of non-TE + torch.compile. Apart from FP8 support, the flash-attention functions called in the TE and non-TE paths are the same, and the linear layers both go through cuBLASLt. So for the remaining parts, even though TE does some fusion, can TE's benefits exceed the improvement brought by torch.compile? Do you have any suggestions for pushing performance further?
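For what it's worth, a hedged sketch of how the two paths could be compared with CUDA-event timing is below. make_te_block / make_eager_block are placeholders for the user's own llama2-7b layer constructors, not TE or llama APIs.

```python
import torch

def benchmark(fn, x, iters=50, warmup=10):
    # Warm up (includes compilation for torch.compile'd callables), then time
    # the steady state with CUDA events.
    for _ in range(warmup):
        fn(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per iteration

# Example usage (placeholders):
# te_block = make_te_block().cuda()
# compiled_block = torch.compile(make_eager_block().cuda())
# x = torch.randn(1, 4096, 4096, device="cuda", dtype=torch.bfloat16)
# print("TE:", benchmark(te_block, x), "ms")
# print("eager + torch.compile:", benchmark(compiled_block, x), "ms")
```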