This PR moves the attributes of `DistributedFusedAdam` to the correct device when loading a v1 state dict.

After loading a v1 state dict with the default checkpoint_io in NeMo, tensors in `DistributedFusedAdam.["buckets"]` remain on the CPU. NeMo relies on `_optimizer_to_device` from pytorch_lightning to move the optimizer state to the target CUDA device, but this does not work for `DistributedFusedAdam`: tensors such as `DistributedFusedAdam.["buckets"][0].param_remainders_shard` are not moved to the correct device when the v1 format is used. This PR fixes that.