A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
This PR changes the order of the gated activation call in LayerNormMLP so that it is consistent across all conditional branches. The order matters for activation checkpointing: a mis-ordering can cause drops in training accuracy for LLaMA.
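A minimal sketch of why this matters, assuming a SwiGLU-style gated MLP and PyTorch activation checkpointing; the `swiglu` helper and the `use_fp8` branch flag are hypothetical illustrations, not Transformer Engine's actual code:

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def swiglu(x: torch.Tensor) -> torch.Tensor:
    # Gated activation: split the fc1 output into a gate half and a value half.
    gate, value = x.chunk(2, dim=-1)
    return F.silu(gate) * value

def mlp_forward(x, fc1_weight, fc2_weight, use_fp8: bool = False):
    h = F.linear(x, fc1_weight)
    if use_fp8:
        # Hypothetical low-precision branch: the gated activation must be
        # applied at the same point as in the default branch below. If the
        # branches order these ops differently, the forward recomputed during
        # checkpointing diverges from the original forward and gradients are
        # silently wrong.
        h = swiglu(h)
        h = h.to(torch.float16).to(h.dtype)  # stand-in for a low-precision cast
    else:
        h = swiglu(h)
    return F.linear(h, fc2_weight)

# With activation checkpointing the forward pass is re-run during backward,
# so any branch-dependent reordering changes the recomputed activations.
x = torch.randn(4, 32, requires_grad=True)
fc1 = torch.randn(64, 32, requires_grad=True)  # 2x hidden, for the gate/value split
fc2 = torch.randn(32, 32, requires_grad=True)
out = checkpoint(mlp_forward, x, fc1, fc2, use_reentrant=False)
out.sum().backward()
```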
Type of change
[ ] Documentation change (change only to the documentation, either a fix or new content)
[x] Bug fix (non-breaking change which fixes an issue)
[ ] New feature (non-breaking change which adds functionality)
[ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
Checklist: