NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Support scaled optimizer state in distributed Adam optimizer #1771

Closed timmoon10 closed 9 months ago

timmoon10 commented 10 months ago

This PR adds basic support for scaled optimizer state, as discussed in the MS-AMP paper. The idea is that per-tensor scaling factors combined with FP16/FP8 optimizer state result in lower memory usage than FP32 optimizer state, with no degradation in convergence. This implementation is not quite the same as the MS-AMP FP8 optimizer: it only uses FP16 optimizer state, and it uses per-parameter-fragment scaling factors rather than per-parameter scaling factors. It is a preliminary implementation, and its performance could be improved with custom kernels (e.g. a kernel to compute the scaling factors, or a fused kernel that combines the FP16-FP32 casts with the Adam step).
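
To make the idea concrete, here is a minimal sketch of per-fragment scaled FP16 optimizer state in plain PyTorch, assuming a simple quantize/dequantize split around the optimizer step. This is not the apex `DistributedFusedAdam` code path, and the helper names (`quantize_fragment`, `dequantize_fragment`) are hypothetical.

```python
# Sketch: store an FP32 optimizer-state fragment as FP16 plus a per-fragment scale.
# Hypothetical helpers for illustration only; not the apex implementation.
import torch

FP16_MAX = torch.finfo(torch.float16).max


def quantize_fragment(frag_fp32: torch.Tensor, margin: float = 0.5):
    """Scale a fragment so its largest value fits comfortably in FP16, then cast."""
    amax = frag_fp32.abs().max()
    scale = (margin * FP16_MAX) / amax if amax > 0 else torch.tensor(1.0)
    return (frag_fp32 * scale).to(torch.float16), scale


def dequantize_fragment(frag_fp16: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Recover an FP32 view of the fragment before applying the Adam update."""
    return frag_fp16.to(torch.float32) / scale


# Example: the first-moment (exp_avg) state for one parameter fragment.
exp_avg_fp32 = torch.randn(1024) * 1e-3
exp_avg_fp16, exp_avg_scale = quantize_fragment(exp_avg_fp32)

# 2 bytes per element instead of 4, at the cost of one scale per fragment.
recovered = dequantize_fragment(exp_avg_fp16, exp_avg_scale)
print(torch.allclose(exp_avg_fp32, recovered, rtol=1e-3, atol=1e-6))
```

In this sketch the scaling factor is recomputed from the fragment's absolute maximum at each quantization, which is the part the PR notes could benefit from a custom kernel.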

In the process of debugging, I've also made some other performance optimizations and bugfixes: