A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
Description
TE modules implement `get_extra_state`/`set_extra_state` in order to include FP8 state in checkpoints. We currently pickle the FP8 state and store it in an `io.BytesIO` object, but this is problematic because PyTorch makes no guarantees about extra state that is not a `torch.Tensor`. This has resulted in problems with ONNX export and Hugging Face Transformers.

https://github.com/NVIDIA/TransformerEngine/pull/363 changed the extra state from a `torch.Tensor` to an `io.BytesIO` object in order to reduce the overhead of GPU-CPU memory transfers. This PR restores the original `torch.Tensor` format, but performs the memory transfers asynchronously to reduce overhead. It's similar to the approach used in the operation-based API (https://github.com/NVIDIA/TransformerEngine/pull/1063). It should be backward compatible and I've been able to load existing checkpoints. The attention docs mention extra state, but I don't think this PR affects them.

Fixes https://github.com/NVIDIA/TransformerEngine/issues/1317.
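To illustrate the idea, here is a minimal sketch of a module that serializes its extra state as a byte `torch.Tensor` while still accepting legacy `io.BytesIO` checkpoints. This is a hypothetical standalone example, not TE's actual implementation; the `FP8StateModule` name and the `fp8_meta` dict are invented for illustration.

```python
import io
import pickle

import torch


class FP8StateModule(torch.nn.Module):
    """Sketch: store pickled extra state as a uint8 tensor (hypothetical)."""

    def __init__(self):
        super().__init__()
        # Placeholder for FP8 scaling metadata (invented for this example).
        self.fp8_meta = {"scale": 1.0, "amax_history": [0.0]}

    def get_extra_state(self) -> torch.Tensor:
        # Serialize the metadata and wrap the bytes in a tensor, so that
        # checkpoint consumers (e.g. ONNX export, Hugging Face Transformers)
        # see a torch.Tensor rather than an arbitrary pickled object.
        buf = bytearray(pickle.dumps(self.fp8_meta))
        return torch.frombuffer(buf, dtype=torch.uint8)

    def set_extra_state(self, state) -> None:
        # Accept both formats for backward compatibility: old io.BytesIO
        # checkpoints and new byte-tensor ones.
        if isinstance(state, io.BytesIO):
            state.seek(0)
            self.fp8_meta = pickle.load(state)
        elif isinstance(state, torch.Tensor):
            self.fp8_meta = pickle.loads(state.numpy().tobytes())
```

Because `set_extra_state` dispatches on the type of the loaded object, checkpoints written in either format round-trip correctly.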
Type of change
Changes

- Store the FP8 extra state in a `torch.Tensor` instead of an `io.BytesIO` object
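The asynchronous-transfer idea mentioned above can be sketched as follows: copy the GPU tensor into pinned (page-locked) host memory with a non-blocking copy, so the device-to-host transfer overlaps with other work instead of stalling the stream. This is a hypothetical helper under assumed conditions, not TE's actual code.

```python
import torch


def async_gpu_to_cpu(src: torch.Tensor) -> torch.Tensor:
    """Copy a tensor to host memory without blocking the current stream.

    Hypothetical helper for illustration; the caller must synchronize the
    stream before reading the result (e.g. when the checkpoint is written).
    """
    if not src.is_cuda:
        # Already on the host: just return an independent copy.
        return src.clone()
    # A pinned destination allows a true asynchronous DMA transfer;
    # copying into pageable memory would force a synchronization.
    dst = torch.empty_like(src, device="cpu", pin_memory=True)
    dst.copy_(src, non_blocking=True)
    return dst
```

Deferring the synchronization until the extra state is actually serialized is what keeps the `torch.Tensor` format from reintroducing the transfer overhead that motivated PR 363.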
Checklist: