This PR refactors the distributed optimizer to enable support for FP8 parameters in NeMo. It adds the option to do parameter all-gathers in integer dtypes and adds two member functions, `_check_params_shard_dtypes` and `_param_copy_fragments`, to handle casting into and out of the all-gather buffer. For now, these functions either do a direct cast for floating-point dtypes or copy the most significant bytes for other dtypes. I plan to override these functions in the NeMo derived class so that it casts to FP8, performs the all-gather in UINT8, and unpacks into a custom FP8 tensor class.
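To illustrate the "copy the most significant bytes" fallback, here is a minimal pure-Python sketch. It is not the code in this PR: the function names `pack_most_significant_bytes`, `unpack_most_significant_bytes`, and `fake_all_gather` are hypothetical, the real optimizer operates on torch tensors and a real collective, and the byte layout it uses may differ. The sketch assumes a float32 parameter and a narrower integer all-gather dtype.

```python
import struct


def pack_most_significant_bytes(value: float, nbytes: int) -> bytes:
    """Keep only the `nbytes` most significant bytes of a float32 value.

    Hypothetical helper: stands in for copying a param shard into an
    integer-typed all-gather buffer that is narrower than the param dtype.
    """
    if not 1 <= nbytes <= 4:
        raise ValueError("nbytes must be between 1 and 4 for float32")
    # Big-endian packing puts the sign/exponent byte first, so slicing
    # from the front keeps the most significant bytes.
    return struct.pack(">f", value)[:nbytes]


def unpack_most_significant_bytes(payload: bytes) -> float:
    """Rebuild a float32 by zero-filling the dropped low-order bytes."""
    padded = payload + b"\x00" * (4 - len(payload))
    return struct.unpack(">f", padded)[0]


def fake_all_gather(shards):
    """Stand-in for a collective: every rank ends up with all byte shards."""
    return b"".join(shards)


# Two simulated ranks, each owning one float32 value, gathered as 2-byte ints.
local_values = [1.5, 3.14159]
gathered = fake_all_gather(pack_most_significant_bytes(v, 2) for v in local_values)
restored = [
    unpack_most_significant_bytes(gathered[i:i + 2])
    for i in range(0, len(gathered), 2)
]
```

Keeping the top two bytes of a float32 is the same as truncating to bfloat16, so `1.5` survives exactly while `3.14159` loses its low mantissa bits. An FP8 override would replace this byte copy with an actual cast to FP8 before the UINT8 all-gather.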
This PR depends on https://github.com/NVIDIA/apex/pull/1719 and https://github.com/NVIDIA/apex/pull/1721.