Description
`te.distributed.checkpoint` + `torch.compile` causes errors with user context functions and `torch.amp.autocast()` support in our checkpointing. The native PyTorch checkpoint wrapper avoids these issues by applying `@torch._disable_dynamo` to its checkpoint wrapper. This PR applies the same decorator to our own checkpoint.
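As a minimal sketch (not the actual TE source, whose signature and body are elided here), the fix amounts to decorating the checkpoint entry point:

```python
import torch

# Sketch only: the real te.distributed.checkpoint() takes more parameters
# and contains the recompute logic. @torch._disable_dynamo makes TorchDynamo
# skip tracing this function, so torch.compile falls back to eager execution
# inside it and user context_fn / torch.amp.autocast() handling keeps working.
@torch._disable_dynamo
def checkpoint(function, *args, **kwargs):
    ...
```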
Deprecation warnings from the PyTorch autocast API were also discovered in the course of testing the fix. This PR also incorporates the minor changes required to address those deprecations.
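For reference, the deprecations follow this pattern on recent PyTorch (2.4+); the exact calls changed in `_get_active_autocast_contexts()` may differ:

```python
import torch

# Deprecated spellings (emit FutureWarning on recent PyTorch):
ctx_old = torch.cuda.amp.autocast(dtype=torch.bfloat16)
dtype_old = torch.get_autocast_gpu_dtype()

# Current replacements take an explicit device_type argument:
ctx_new = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
dtype_new = torch.get_autocast_dtype("cuda")
```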
Fixes #890
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or new content)
- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
Changes
- Added the `@torch._disable_dynamo` decorator to `te.distributed.checkpoint()`.
- Fixed PyTorch deprecation warnings for the autocast API used in `te.distributed._get_active_autocast_contexts()`. A usage sketch of the fixed scenario follows the list.
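For illustration, a hypothetical end-to-end use of the fixed path; module sizes, tensor shapes, and the choice of `te.Linear` are made up, and a CUDA device is assumed:

```python
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024)

def forward(x):
    # Activation checkpointing inside a compiled region: with this PR,
    # Dynamo skips tracing the checkpoint wrapper instead of erroring out.
    return te.distributed.checkpoint(layer, x)

compiled = torch.compile(forward)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = compiled(torch.randn(8, 1024, device="cuda", requires_grad=True))
out.sum().backward()
```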