Closed jayakrishnaanvesh closed 11 months ago
def load_state_dict(self, state_dict: dict) -> None: https://github.com/NVIDIA/apex/blob/master/apex/contrib/optimizers/distributed_fused_adam.py#L2317
If the distributed size (world size) changes between saving and loading a checkpoint, we need a feature that re-shards the optimizer states based on the new distributed size and loads the state dict appropriately on each distributed rank.
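To make the request concrete, here is a minimal sketch of the re-sharding step in plain Python. It assumes each optimizer state (for example `exp_avg`) is a flat buffer split into contiguous, rank-ordered slices. That layout is an assumption for illustration and not necessarily how `DistributedFusedAdam` stores its state. `shard_bounds` and `reshard_state` are hypothetical helpers, not part of apex.

```python
from typing import Dict, List, Tuple


def shard_bounds(total: int, world_size: int, rank: int) -> Tuple[int, int]:
    """Return the [start, end) slice of a flat buffer owned by `rank`.

    Assumes contiguous shards of ceil(total / world_size) elements;
    trailing ranks may receive a shorter (or empty) slice.
    """
    per_rank = -(-total // world_size)  # ceiling division
    start = min(rank * per_rank, total)
    end = min(start + per_rank, total)
    return start, end


def reshard_state(
    old_shards: List[Dict[str, List[float]]],
    new_world_size: int,
    new_rank: int,
) -> Dict[str, List[float]]:
    """Rebuild the state slice for `new_rank` from shards saved at another world size.

    `old_shards[i]` is the (hypothetical) state dict saved by old rank i,
    mapping a state name to that rank's slice of the flat buffer.
    """
    result: Dict[str, List[float]] = {}
    for key in old_shards[0]:
        # Reassemble the full flat buffer in old-rank order.
        full = [value for shard in old_shards for value in shard[key]]
        # Cut out the slice that the new rank owns.
        start, end = shard_bounds(len(full), new_world_size, new_rank)
        result[key] = full[start:end]
    return result


# Usage: a 10-element state saved on 4 ranks, reloaded on 2 ranks.
saved = [
    {"exp_avg": [0.0, 1.0, 2.0]},
    {"exp_avg": [3.0, 4.0, 5.0]},
    {"exp_avg": [6.0, 7.0, 8.0]},
    {"exp_avg": [9.0]},
]
rank0 = reshard_state(saved, new_world_size=2, new_rank=0)
rank1 = reshard_state(saved, new_world_size=2, new_rank=1)
```

A real implementation would avoid materializing the full buffer on every rank, for instance by reading only the old shards that overlap the new rank's slice, and would operate on tensors instead of lists.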