NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] Store module extra state in tensor #1335

Open timmoon10 opened 1 week ago

timmoon10 commented 1 week ago

Description

TE modules implement get_extra_state/set_extra_state in order to include FP8 state in checkpoints. We currently pickle the FP8 state and store it in an io.BytesIO object, but this is problematic because PyTorch makes no guarantees about how extra state is handled if it is not a torch.Tensor. This has resulted in problems with ONNX export and Hugging Face Transformers. For illustration, here is a minimal sketch of that BytesIO pattern (FP8Module and the contents of its fp8_meta dict are hypothetical stand-ins, not TE's actual code):
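```python
import io
import pickle

import torch


class FP8Module(torch.nn.Module):
    """Hypothetical stand-in for a TE module that carries FP8 state."""

    def __init__(self):
        super().__init__()
        # Illustrative FP8 metadata; the real state tracks scaling factors,
        # amax history, and recipe configuration.
        self.fp8_meta = {"scale": 1.0, "amax_history": [0.0]}

    def get_extra_state(self) -> io.BytesIO:
        # Problematic: the extra state is an opaque Python object. PyTorch
        # only guarantees checkpoint/export handling for torch.Tensor.
        state = io.BytesIO()
        pickle.dump(self.fp8_meta, state)
        state.seek(0)
        return state

    def set_extra_state(self, state: io.BytesIO) -> None:
        state.seek(0)
        self.fp8_meta = pickle.load(state)
```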

https://github.com/NVIDIA/TransformerEngine/pull/363 changed the extra state format from torch.Tensor to io.BytesIO in order to reduce the overhead of GPU-CPU memory transfers. This PR restores the original torch.Tensor format, but performs the memory transfers asynchronously to reduce that overhead. It is similar to the approach used in the operation-based API (https://github.com/NVIDIA/TransformerEngine/pull/1063). The change should be backward compatible, and I've been able to load existing checkpoints. The attention docs mention extra state, but I don't think this PR affects it. A minimal sketch of the restored format, under assumed names (pack_extra_state/unpack_extra_state are illustrative helpers, not this PR's actual functions): the pickled state is stored as a uint8 torch.Tensor, and GPU-resident metadata is copied to CPU with non-blocking transfers into pinned memory so the copy can overlap with other GPU work:
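```python
import pickle

import torch


def pack_extra_state(fp8_meta: dict) -> torch.Tensor:
    """Serialize FP8 state into a uint8 tensor for checkpointing."""
    cpu_meta = {}
    needs_sync = False
    for key, value in fp8_meta.items():
        if isinstance(value, torch.Tensor) and value.is_cuda:
            # Launch a non-blocking GPU->CPU copy into pinned memory so the
            # transfer overlaps with other GPU work.
            buf = torch.empty(value.shape, dtype=value.dtype, pin_memory=True)
            buf.copy_(value, non_blocking=True)
            cpu_meta[key] = buf
            needs_sync = True
        else:
            cpu_meta[key] = value
    if needs_sync:
        # Synchronize only when the bytes are actually needed for pickling.
        torch.cuda.synchronize()
    data = pickle.dumps(cpu_meta)
    return torch.frombuffer(bytearray(data), dtype=torch.uint8)


def unpack_extra_state(state: torch.Tensor) -> dict:
    """Inverse of pack_extra_state, reading the tensor from a checkpoint."""
    return pickle.loads(state.cpu().numpy().tobytes())
```

Because the extra state is now a plain torch.Tensor, PyTorch can treat it like any other checkpoint entry, which is what makes it safe for ONNX export and third-party checkpoint tooling.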

Fixes https://github.com/NVIDIA/TransformerEngine/issues/1317.

Type of change

Changes

Checklist:

timmoon10 commented 1 week ago

/te-ci pytorch L0 L1