NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

how to use FusedRMSNorm? #948

Open EthanChen1234 opened 1 week ago

EthanChen1234 commented 1 week ago

Hi, TE is really great work.

How can I use FusedRMSNorm in TE? I'm looking for something like Apex's FusedRMSNorm:

https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm.py#L329

timmoon10 commented 5 days ago

To use RMSNorm by itself, you can simply construct a te.RMSNorm module:

import torch
import transformer_engine.pytorch as te

# TE module
layer = te.RMSNorm(128)

# Synthetic data
x = torch.randn(128, 128).cuda()

# Forward and backward pass
y = layer(x)
y.sum().backward()

If you know the RMSNorm is going to be followed by a linear operation, it may be worth using the te.LayerNormLinear or te.LayerNormMLP modules:

layer = te.LayerNormLinear(128, 128, normalization="RMSNorm")
y = layer(x)
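
The same pattern applies to te.LayerNormMLP for the feed-forward block. As a minimal sketch (the 512 FFN hidden size below is just an arbitrary example, not anything prescribed by TE):

# RMSNorm fused into the MLP block (hidden size 128, FFN hidden size 512)
mlp = te.LayerNormMLP(128, 512, normalization="RMSNorm")
y = mlp(x)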

This allows for some kernel fusions when running with FP8, e.g. fusing the RMSNorm with an FP8 cast.
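
As a rough sketch of how that FP8 path is typically exercised (this assumes an FP8-capable GPU such as Hopper or Ada, and the recipe settings are just illustrative defaults):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Fused RMSNorm + linear module
layer = te.LayerNormLinear(128, 128, normalization="RMSNorm")

# Synthetic data
x = torch.randn(128, 128).cuda()

# Delayed-scaling FP8 recipe (E4M3 forward, E5M2 backward with HYBRID)
fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)

# Under fp8_autocast, the RMSNorm output is cast to FP8 for the
# following GEMM, which is where the kernel fusion applies
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)

# Backward pass runs outside the autocast context, as in the TE docs
y.sum().backward()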