NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Scale optimizer state with updated distributed size #1707

Closed by jayakrishnaanvesh 11 months ago

jayakrishnaanvesh commented 11 months ago

Relevant method (apex/contrib/optimizers/distributed_fused_adam.py):
def load_state_dict(self, state_dict: dict) -> None:
https://github.com/NVIDIA/apex/blob/master/apex/contrib/optimizers/distributed_fused_adam.py#L2317

If the distributed world size changes between saving a checkpoint and loading it, we need a feature that re-shards the optimizer state according to the new distributed size, so that load_state_dict loads the correct portion of the state on each distributed rank.
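A minimal sketch of the requested resharding follows. It assumes, hypothetically, that each rank's optimizer state is a contiguous shard of one flat buffer of total_numel elements, zero-padded at the end so the world size divides it evenly. The helpers reshard_flat_state and load_rank_state are illustrative names, not part of apex, and the real DistributedFusedAdam state dict layout (buckets, parameter fragments) is more involved. Plain Python lists stand in for tensors.

```python
from typing import Dict, List


def reshard_flat_state(
    old_shards: List[List[float]], total_numel: int, new_world_size: int
) -> List[List[float]]:
    """Re-split an evenly sharded flat buffer for a new world size (hypothetical layout)."""
    # Reassemble the full buffer from the old per-rank shards and drop the
    # padding that was appended so the old world size divided it evenly.
    full = [x for shard in old_shards for x in shard][:total_numel]
    # Ceil division: every new rank gets an equally sized shard.
    shard_size = -(-total_numel // new_world_size)
    # Pad with zeros so the buffer splits evenly across the new ranks.
    padded = full + [0.0] * (shard_size * new_world_size - total_numel)
    return [padded[r * shard_size:(r + 1) * shard_size] for r in range(new_world_size)]


def load_rank_state(
    old_rank_states: List[Dict[str, List[float]]],
    total_numel: int,
    new_world_size: int,
    rank: int,
) -> Dict[str, List[float]]:
    """Build the state slice that `rank` should load under the new world size."""
    new_state: Dict[str, List[float]] = {}
    # Keys such as "exp_avg" / "exp_avg_sq" are assumed to be sharded the same way.
    for key in old_rank_states[0]:
        shards = [state[key] for state in old_rank_states]
        new_state[key] = reshard_flat_state(shards, total_numel, new_world_size)[rank]
    return new_state


if __name__ == "__main__":
    # 5 elements saved with world size 2 (shards of 3, last one padded).
    saved = [
        {"exp_avg": [0.0, 1.0, 2.0], "exp_avg_sq": [10.0, 11.0, 12.0]},
        {"exp_avg": [3.0, 4.0, 0.0], "exp_avg_sq": [13.0, 14.0, 0.0]},
    ]
    # Reload with world size 3: rank 1 gets elements 2..3 of each buffer.
    print(load_rank_state(saved, total_numel=5, new_world_size=3, rank=1))
```

In this sketch only per-element state is split; per-optimizer scalars such as a step counter would instead be copied unchanged to every rank.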