NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Distributed optimizer support for multiple dtypes #1721

Closed: timmoon10 closed this 10 months ago

timmoon10 commented 10 months ago

This PR adds logic so that parameters can be configured with different dtypes for the grad reduce-scatters and the param all-gathers. I have two NeMo use cases in mind:

This also includes changes from https://github.com/NVIDIA/apex/pull/1719, which returns the state dict on all ranks rather than only on rank 0. We can either merge that PR first and rebase this one, or merge this PR and close https://github.com/NVIDIA/apex/pull/1719.
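To make the intent concrete, here is a minimal sketch of the data flow this kind of configuration enables, written against plain `torch.distributed` collectives rather than the apex `DistributedFusedAdam` API. The function name, the toy SGD update, and the default dtype choices are illustrative assumptions, not part of this PR; the point is only that the gradient reduce-scatter and the parameter all-gather can run in different dtypes.

```python
# Illustrative sketch only (not the apex API): reduce-scatter grads in one
# dtype and all-gather updated params in another. Assumes param.numel() is
# divisible by the world size and that the backend supports ReduceOp.AVG.
import torch
import torch.distributed as dist


def mixed_dtype_step(param: torch.Tensor,
                     grad: torch.Tensor,
                     grad_sync_dtype: torch.dtype = torch.bfloat16,
                     param_sync_dtype: torch.dtype = torch.float16,
                     lr: float = 1e-3) -> None:
    """Toy sharded step: each rank owns one shard of the parameter."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    shard_numel = param.numel() // world_size

    # 1) Reduce-scatter gradients in grad_sync_dtype.
    grad_buf = grad.detach().to(grad_sync_dtype).flatten()
    grad_shard = torch.empty(shard_numel, dtype=grad_sync_dtype,
                             device=grad.device)
    dist.reduce_scatter_tensor(grad_shard, grad_buf, op=dist.ReduceOp.AVG)

    # 2) Update the locally owned parameter shard in fp32 (toy SGD update).
    param_flat = param.data.flatten()
    local_slice = param_flat[rank * shard_numel:(rank + 1) * shard_numel]
    updated = local_slice.float() - lr * grad_shard.float()

    # 3) All-gather updated parameter shards in param_sync_dtype.
    param_shard = updated.to(param_sync_dtype)
    gathered = torch.empty(param.numel(), dtype=param_sync_dtype,
                           device=param.device)
    dist.all_gather_into_tensor(gathered, param_shard)
    param.data.copy_(gathered.view_as(param))
```

Decoupling the two dtypes is what makes the use cases above possible: gradients can be reduced in a wider dtype for accumulation accuracy while parameters are broadcast in a narrower dtype to cut all-gather bandwidth.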