NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Distributed optimizer infrastructure for FP8 parameters #1723

Closed timmoon10 closed 1 year ago

timmoon10 commented 1 year ago

This PR does some refactoring that will enable distributed optimizer support for FP8 parameters in NeMo. It adds the option to perform parameter all-gathers in integer dtypes, and adds two member functions, _check_params_shard_dtypes and _param_copy_fragments, to handle casting into and out of the all-gather buffer. For now, these functions either perform a direct cast for floating-point dtypes or copy the most significant bytes for other dtypes. I plan to override these functions in the NeMo derived class so that it casts to FP8, performs the all-gather in UINT8, and unpacks the result into a custom FP8 tensor class. A minimal sketch of the general byte-level all-gather pattern follows.
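The snippet below is only an illustrative sketch of the byte-level all-gather pattern described above, not the apex implementation: it assumes a recent PyTorch with `torch.distributed.all_gather_into_tensor`, and the function name `allgather_params_as_uint8` is hypothetical. It shows how a parameter shard can be reinterpreted as raw bytes, gathered in an integer dtype, and viewed back in the original dtype; a derived class could instead unpack the uint8 buffer into an FP8 tensor class.

```python
import torch
import torch.distributed as dist


def allgather_params_as_uint8(local_shard: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch (not the apex API): all-gather a parameter shard
    through a uint8 buffer, then reinterpret the bytes on the receiving side."""
    world_size = dist.get_world_size()

    # Reinterpret the local shard's storage as raw bytes so the collective
    # runs on an integer dtype.
    local_bytes = local_shard.contiguous().view(torch.uint8)

    # Allocate the gather buffer in uint8 and run the all-gather.
    gathered_bytes = torch.empty(
        world_size * local_bytes.numel(),
        dtype=torch.uint8,
        device=local_shard.device,
    )
    dist.all_gather_into_tensor(gathered_bytes, local_bytes)

    # Reinterpret the gathered bytes back into the original dtype.
    # (A NeMo-style override would instead unpack into an FP8 tensor class.)
    return gathered_bytes.view(local_shard.dtype)
```

In this sketch the unpacking step is just a dtype view; the point of the new hooks in this PR is that a derived class can replace both the packing (cast to FP8 bytes) and the unpacking (construct an FP8 tensor) without touching the collective itself.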

This PR depends on https://github.com/NVIDIA/apex/pull/1719 and https://github.com/NVIDIA/apex/pull/1721.