NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Return distributed optimizer checkpoint on all ranks #1719

Closed timmoon10 closed 10 months ago

timmoon10 commented 10 months ago

This PR addresses an issue introduced by the NeMo distributed checkpoint format (see https://github.com/NVIDIA/NeMo/pull/7116 and https://github.com/NVIDIA/NeMo/pull/7281). When loading a distributed checkpoint, we first create a state dict, replace its values with checkpoint data, and then load the state dict. The distributed optimizer's current state-dict function gathers all the data on rank 0 and returns None on the other ranks, so non-root ranks were hitting errors. This PR changes the state-dict behavior to all-gather the data and return identical state dicts on all ranks.
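
For illustration, here is a minimal sketch of the behavioral change, not the actual apex `DistributedFusedAdam` code. It assumes an initialized `torch.distributed` process group and uses a hypothetical `local_shard` tensor to stand in for a rank's portion of the optimizer state; the point is the switch from `gather` to rank 0 (which leaves `None` on non-root ranks) to `all_gather` (which lets every rank build the same state dict).

```python
import torch
import torch.distributed as dist


def state_dict_gather_to_root(local_shard: torch.Tensor):
    """Old behavior: only rank 0 receives the full state; other ranks return None."""
    world_size = dist.get_world_size()
    if dist.get_rank() == 0:
        gathered = [torch.empty_like(local_shard) for _ in range(world_size)]
        dist.gather(local_shard, gather_list=gathered, dst=0)
        return {"state": torch.cat(gathered)}
    dist.gather(local_shard, dst=0)
    return None  # non-root ranks have no state dict to load checkpoint data into


def state_dict_all_gather(local_shard: torch.Tensor):
    """New behavior: every rank all-gathers the shards and returns an identical dict."""
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(gathered, local_shard)
    return {"state": torch.cat(gathered)}
```

With the all-gather variant, the checkpoint-loading flow described above (create a state dict, overwrite its values with checkpoint data, load it) can run identically on every rank instead of special-casing rank 0.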