This PR addresses an issue introduced in the NeMo distributed checkpoint format (see https://github.com/NVIDIA/NeMo/pull/7116 and https://github.com/NVIDIA/NeMo/pull/7281). When loading a distributed checkpoint, we first create a state dict, replace its values with the checkpoint data, and then load the state dict. The current distributed optimizer state dict function gathers all the data on rank 0 and returns `None` on the other ranks, so the load was failing on non-root ranks. This PR changes the state dict behavior to all-gather the data so that identical state dicts are returned on all ranks.
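To illustrate the intended change, here is a minimal, hypothetical sketch (not the actual NeMo/Megatron implementation) contrasting a gather-to-rank-0 state dict with an all-gather version; `local_shard` stands in for whatever per-rank optimizer state the real distributed optimizer holds:

```python
import torch.distributed as dist

def state_dict_gather_to_root(local_shard):
    """Old behavior: only rank 0 receives the full state; other ranks get None."""
    world_size = dist.get_world_size()
    gathered = [None] * world_size if dist.get_rank() == 0 else None
    dist.gather_object(local_shard, gathered, dst=0)
    if dist.get_rank() != 0:
        return None  # non-root ranks have no state dict to populate and load
    return {rank: shard for rank, shard in enumerate(gathered)}

def state_dict_all_gather(local_shard):
    """New behavior: every rank receives the same, complete state dict."""
    world_size = dist.get_world_size()
    gathered = [None] * world_size
    dist.all_gather_object(gathered, local_shard)
    return {rank: shard for rank, shard in enumerate(gathered)}
```

With the all-gather version, every rank can build the state dict, overwrite it with checkpoint data, and load it, which is what the distributed checkpoint flow expects.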