NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] Disabling TorchDynamo for TE activation checkpoint wrapper #894

Closed · denera closed this 3 weeks ago

denera commented 4 weeks ago

Description

Using te.distributed.checkpoint under torch.compile causes errors with user-supplied context functions and breaks torch.amp.autocast() support in our checkpointing. The native PyTorch checkpoint wrapper avoids these issues by decorating itself with @torch._disable_dynamo, so TorchDynamo skips tracing it. This PR applies the same decorator to our own checkpoint wrapper.
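For reference, the pattern is roughly the following. This is a minimal sketch rather than the actual TE source; the wrapper name and body here are illustrative:

```python
import torch
from torch.utils.checkpoint import checkpoint as _torch_checkpoint

# Sketch: decorating the activation-checkpoint wrapper so TorchDynamo
# skips tracing it. PyTorch's own checkpoint wrapper carries the same
# decorator, which is what this PR mirrors.
@torch._disable_dynamo
def checkpoint(function, *args, **kwargs):
    # With Dynamo disabled here, user context functions and the ambient
    # torch.amp.autocast() state are handled eagerly instead of being
    # traced into a compiled graph.
    return _torch_checkpoint(function, *args, use_reentrant=False, **kwargs)
```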

Testing the fix also surfaced deprecation warnings from the PyTorch autocast API. This PR incorporates the minor changes required to address those deprecations as well.
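For context, and assuming the warnings refer to the torch.cuda.amp → torch.amp move in recent PyTorch releases (the PR does not show the exact call sites), the migration looks roughly like this:

```python
import torch

x = torch.randn(8, 8, device="cuda")

# Deprecated spelling; emits a FutureWarning on recent PyTorch:
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    y = x @ x

# Replacement: the unified API with the device type passed explicitly.
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    y = x @ x
```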

Fixes #890

Type of change

Changes

Checklist:

denera commented 3 weeks ago

/te-ci pytorch