Open EthanChen1234 opened 1 week ago
To use RMSNorm by itself, you can simply construct a te.RMSNorm module:

```python
import torch
import transformer_engine.pytorch as te

# TE module
layer = te.RMSNorm(128)

# Synthetic data
x = torch.randn(128, 128).cuda()

# Forward and backward pass
y = layer(x)
y.sum().backward()
```
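For reference, RMSNorm normalizes each row by its root-mean-square over the feature dimension and applies a learnable per-feature scale — unlike LayerNorm, there is no mean subtraction and no bias. A minimal pure-PyTorch sketch of the computation (a CPU reference, not TE's fused kernel; the `rms_norm` helper name is our own):

```python
import torch

def rms_norm(x, weight, eps=1e-5):
    # Root-mean-square over the last dimension; no mean subtraction,
    # which is the key difference from LayerNorm.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight

x = torch.randn(4, 8)
w = torch.ones(8)  # learnable scale, initialized to ones
y = rms_norm(x, w)
```

With the scale at its default of ones, each output row has root-mean-square approximately 1.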
If you know it's going to be followed by a linear operation, it may be worth using the te.LayerNormLinear or te.LayerNormMLP modules instead:

```python
layer = te.LayerNormLinear(128, 128, normalization="RMSNorm")
y = layer(x)
```
This allows for some kernel fusions when running with FP8, e.g. fusing the RMSNorm with an FP8 cast.
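A sketch of that FP8 path, assuming an FP8-capable GPU (e.g. Hopper); `te.fp8_autocast` and the `DelayedScaling` recipe are TE's standard FP8 entry points, and inside the autocast region the RMSNorm output can be cast to FP8 as part of a fused kernel:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# RMSNorm + Linear in one module
layer = te.LayerNormLinear(128, 128, normalization="RMSNorm")
x = torch.randn(128, 128).cuda()

# Run the forward pass with FP8 compute enabled
fp8_recipe = recipe.DelayedScaling()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)
y.sum().backward()
```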
Hi, TE is really great work.
How can I use a fused RMSNorm in TE, like Apex's FusedRMSNorm?
https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm.py#L329