NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
Apache License 2.0

Selective Activation Checkpointing with LayerNormMLP #623

Open denizokt opened 7 months ago

denizokt commented 7 months ago

Hi all,

I was wondering whether it is possible to do selective activation checkpointing with LayerNormMLP, where only FFN1 is recomputed in the backward pass and not FFN2, so that the ffn1_out and gelu_out activations (the largest activations in memory) do not have to be saved.

This has been done for OPT in metaseq: https://github.com/facebookresearch/metaseq/blob/f7ffa5fd61cf90f498a36d365c13dd7f1a912ff7/metaseq/modules/sequence_parallel_transformer_layer.py#L250C20-L250C33. I wonder if the same is possible in TransformerEngine, because it would be awesome to use it with FP8!

Thank you!

ptrendx commented 3 months ago

@sudhakarsingh27 This is basically the same as what we discussed about improving the checkpoint logic to allow for early stopping.
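
For readers following the thread, the request can be sketched in plain PyTorch. This is only an illustration under stated assumptions, not TransformerEngine's API or the metaseq implementation: the class name `SelectiveRecomputeMLP` is hypothetical, LayerNorm and FP8 are left out, and the weights follow the `torch.nn.functional.linear` layout. The forward pass saves only the input and weights. The backward pass recomputes FFN1 and the GELU, then stops early: FFN2's forward is never re-run, because its gradients need only the recomputed `gelu_out` and the incoming gradient.

```python
import torch
import torch.nn.functional as F


class SelectiveRecomputeMLP(torch.autograd.Function):
    """Hypothetical sketch: FFN1 -> GELU -> FFN2 that does not store
    ffn1_out or gelu_out for the backward pass."""

    @staticmethod
    def forward(ctx, x, w1, b1, w2, b2):
        # Autograd is disabled inside forward, so no graph is recorded here.
        ffn1_out = F.linear(x, w1, b1)
        gelu_out = F.gelu(ffn1_out)
        out = F.linear(gelu_out, w2, b2)
        # Save only the input and weights; the large intermediates are dropped.
        ctx.save_for_backward(x, w1, b1, w2)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, w1, b1, w2 = ctx.saved_tensors
        # Recompute FFN1 and GELU with autograd enabled. The recompute
        # stops here: FFN2's forward output is not needed for any gradient.
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            w1_ = w1.detach().requires_grad_(True)
            b1_ = b1.detach().requires_grad_(True)
            gelu_out = F.gelu(F.linear(x_, w1_, b1_))

        # FFN2 gradients, computed by hand from the recomputed gelu_out.
        grad_2d = grad_out.reshape(-1, grad_out.shape[-1])
        gelu_2d = gelu_out.detach().reshape(-1, gelu_out.shape[-1])
        grad_w2 = grad_2d.t() @ gelu_2d
        grad_b2 = grad_2d.sum(dim=0)
        grad_gelu_out = grad_out @ w2

        # Backpropagate through GELU and FFN1 using the recomputed graph.
        grad_x, grad_w1, grad_b1 = torch.autograd.grad(
            gelu_out, (x_, w1_, b1_), grad_gelu_out
        )
        return grad_x, grad_w1, grad_b1, grad_w2, grad_b2
```

It would be called as `SelectiveRecomputeMLP.apply(x, w1, b1, w2, b2)`. Wrapping FFN1 and the GELU in `torch.utils.checkpoint` alone would not give the same saving, since the FFN2 linear layer would still store `gelu_out` as its input; that is why the sketch computes the FFN2 gradients manually.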