NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

Selective Activation Checkpointing with LayerNormMLP #623

Open denizokt opened 7 months ago

denizokt commented 7 months ago

Hi all,

I was wondering whether it is possible to do selective activation checkpointing with LayerNormMLP, where only FFN1 is recomputed and not FFN2, so that the ffn1_out and gelu_out activations (the ones with the largest memory footprint) do not have to be saved.

This has been done in OPT (https://github.com/facebookresearch/metaseq/blob/f7ffa5fd61cf90f498a36d365c13dd7f1a912ff7/metaseq/modules/sequence_parallel_transformer_layer.py#L250C20-L250C33), so I wonder whether it is possible in TransformerEngine as well, because it would be awesome to use it with FP8!
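To make the idea concrete, here is a rough plain-PyTorch sketch of what I mean (not TransformerEngine code; the `F.linear`/`F.gelu` layout and the erf-based GELU derivative are my assumptions): only the MLP input and the weights are saved, and `ffn1_out`/`gelu_out` are rebuilt during backward.

```python
import torch
import torch.nn.functional as F


class SelectiveCheckpointMLP(torch.autograd.Function):
    """Saves only the MLP input and weights; ffn1_out and gelu_out are
    recomputed in backward instead of being kept in memory."""

    @staticmethod
    def forward(ctx, inp, w1, b1, w2, b2):
        ffn1_out = F.linear(inp, w1, b1)   # FFN1 (recomputed in backward)
        gelu_out = F.gelu(ffn1_out)        # GELU (recomputed in backward)
        out = F.linear(gelu_out, w2, b2)   # FFN2 (never recomputed)
        ctx.save_for_backward(inp, w1, b1, w2, b2)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        inp, w1, b1, w2, b2 = ctx.saved_tensors
        # Recompute the large intermediates from the saved input.
        ffn1_out = F.linear(inp, w1, b1)
        gelu_out = F.gelu(ffn1_out)
        # FFN2 backward.
        grad_gelu = grad_out @ w2
        grad_w2 = grad_out.flatten(0, -2).t() @ gelu_out.flatten(0, -2)
        grad_b2 = grad_out.flatten(0, -2).sum(0)
        # GELU backward (derivative of the exact erf-based GELU: Phi(x) + x*phi(x)).
        cdf = 0.5 * (1.0 + torch.erf(ffn1_out * 0.7071067811865476))
        pdf = torch.exp(-0.5 * ffn1_out * ffn1_out) * 0.3989422804014327
        grad_ffn1 = grad_gelu * (cdf + ffn1_out * pdf)
        # FFN1 backward.
        grad_inp = grad_ffn1 @ w1
        grad_w1 = grad_ffn1.flatten(0, -2).t() @ inp.flatten(0, -2)
        grad_b1 = grad_ffn1.flatten(0, -2).sum(0)
        return grad_inp, grad_w1, grad_b1, grad_w2, grad_b2
```

A module would then call `SelectiveCheckpointMLP.apply(ln_out, fc1.weight, fc1.bias, fc2.weight, fc2.bias)` (hypothetical `fc1`/`fc2` names) instead of the usual `fc2(gelu(fc1(ln_out)))`, trading one extra FFN1 GEMM plus GELU in backward for not storing the two largest activations.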

Thank you!

ptrendx commented 3 months ago

@sudhakarsingh27 This is basically the same as what we discussed about improving the checkpoint logic to allow for early stopping.