NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[JAX] Made order of gated act consistent in all branches #902

Closed phu0ngng closed 3 months ago

phu0ngng commented 3 months ago

Description

This PR changes the order of the gated activation call in LayerNormMLP so that it is consistent across all conditional branches. This order matters for activation checkpointing: a mismatched order between branches can cause drops in training accuracy for Llama. A minimal sketch of the issue follows below.
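To make the issue concrete, here is a minimal, hypothetical JAX sketch (not TransformerEngine code) showing why the gate/value ordering of a gated activation must agree across code paths. The helper name `gated_gelu` and the `gate_first` flag are illustrative assumptions, not names from this PR.

```python
# Hypothetical sketch: two orderings of a gated activation are not equivalent,
# so conditional branches (e.g. the checkpointed / rematerialized path vs. the
# plain forward path) must agree on which half is gated.
import jax
import jax.numpy as jnp


def gated_gelu(x, gate_first: bool):
    # Split the MLP projection output into two halves along the last axis.
    a, b = jnp.split(x, 2, axis=-1)
    # The two orderings apply GELU to different halves, so they differ.
    return jax.nn.gelu(a) * b if gate_first else a * jax.nn.gelu(b)


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 8))

out_a = gated_gelu(x, gate_first=True)
out_b = gated_gelu(x, gate_first=False)

# False in general: if one branch recomputes activations with the other
# ordering during checkpointing, the recomputed values will not match the
# forward pass, silently degrading training accuracy.
print(jnp.allclose(out_a, out_b))
```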


phu0ngng commented 3 months ago

/te-ci jax