NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Move to the correct device for v1 state dict #1783

Closed acphile closed 7 months ago

acphile commented 8 months ago

This PR moves the attributes of DistributedFusedAdam to the correct device when a v1 state dict is loaded.

After a v1 state dict is loaded, the tensors in DistributedFusedAdam.["buckets"] are on the CPU when NeMo's default checkpoint_io is used. NeMo tries to move the optimizer state to the target CUDA device through _optimizer_to_device in pytorch_lightning, but this does not work for DistributedFusedAdam: with the v1 format, tensors such as DistributedFusedAdam.["buckets"][0].param_remainders_shard are not moved to the correct device.

This PR aims to fix it.
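To illustrate the kind of problem described above, here is a minimal sketch, not the actual change in this PR. It assumes the buckets are plain Python objects whose tensors live in attributes, so a traversal that only descends into dicts and lists would skip them. `FakeTensor`, `Bucket`, and `move_to_device` are hypothetical names introduced for illustration. `FakeTensor` stands in for `torch.Tensor` so the sketch runs without PyTorch or a GPU; with PyTorch the check would be `isinstance(obj, torch.Tensor)`.

```python
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (assumption, not a real API)."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like torch.Tensor.to, return a new object instead of mutating.
        return FakeTensor(device)


class Bucket:
    """Hypothetical bucket that keeps a tensor as a plain attribute."""

    def __init__(self, param_remainders_shard: Any) -> None:
        self.param_remainders_shard = param_remainders_shard


def move_to_device(obj: Any, device: str) -> Any:
    """Recursively move tensors found in dicts, lists, tuples and attributes."""
    if isinstance(obj, FakeTensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(v, device) for v in obj)
    if hasattr(obj, "__dict__"):
        # Custom objects: replace each attribute in place, then return the object.
        for name, value in list(vars(obj).items()):
            setattr(obj, name, move_to_device(value, device))
        return obj
    # Ints, strings, None, etc. are returned unchanged.
    return obj


# Usage: a state laid out like the one described above, loaded on the CPU.
state = {"buckets": [Bucket(FakeTensor("cpu"))], "step": 1}
state = move_to_device(state, "cuda:0")
```

The point of the sketch is the `hasattr(obj, "__dict__")` branch: without it, the tensor held by the bucket object would stay on the CPU while everything stored directly in dicts and lists moved.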

ericharper commented 7 months ago

@timmoon10, can you review this one?

acphile commented 7 months ago

Thank you! Could it be merged?