NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

MLP without LayerNorm #817

Open sriniiyer opened 2 months ago

sriniiyer commented 2 months ago

Is there currently a way to use MLP without applying the LayerNorm? What would be the best way to implement this? Thanks!

timmoon10 commented 3 weeks ago

The simplest solution would be to manually construct an MLP out of multiple te.Linear layers, but this won't be able to perform all of the kernel fusions in te.LayerNormMLP; a rough sketch is below.
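
A minimal sketch of such an unfused MLP, with hidden/FFN sizes and a GELU activation chosen purely for illustration:

import torch
import transformer_engine.pytorch as te

hidden_size, ffn_size = 1024, 4096  # assumed sizes, FP8-friendly multiples of 16

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Two te.Linear projections with no preceding LayerNorm.
        self.fc1 = te.Linear(hidden_size, ffn_size, bias=True)
        self.fc2 = te.Linear(ffn_size, hidden_size, bias=True)

    def forward(self, x):
        # The activation runs as a separate kernel here; te.LayerNormMLP
        # would fuse it with the surrounding GEMMs.
        return self.fc2(torch.nn.functional.gelu(self.fc1(x)))

mlp = MLP().cuda()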

Long-term, this kind of customization is the purpose of the operation-based API being developed in https://github.com/NVIDIA/TransformerEngine/pull/707:

mlp = te.Sequential(
    te.ops.Linear(...),
    te.ops.GeLU(),
    te.ops.Linear(...),
)
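
In the meantime, the te.Linear-based MLP sketched earlier can already be run with FP8 GEMMs using the existing te.fp8_autocast context and DelayedScaling recipe (batch size assumed for illustration):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Forward pass with FP8 GEMMs via delayed scaling.
with te.fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
    y = mlp(torch.randn(16, 1024, device="cuda"))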