NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch

Distributed optimizer support for contiguous param buffer with FP8 params #1749

Closed: timmoon10 closed 10 months ago

timmoon10 commented 10 months ago

https://github.com/NVIDIA/apex/pull/1723 added distopt infrastructure to support FP8 parameters in NeMo, but I found a bug with contiguous_param_buffer=True. In the non-FP8 case, the local shards of the updated params are views into the contiguous buffer: the Adam kernel writes its output directly into the buffer, we do in-place all-gathers, and the params are ready for fprop. The FP8 case, however, needs a temporary buffer since the Adam kernel doesn't support FP8 output. The Adam kernel writes to a temporary FP32 buffer, and we then cast to FP8 into the contiguous param buffer.
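A minimal sketch of the two update paths described above, using plain PyTorch rather than apex's actual distributed-optimizer internals. The function names `update_shard_non_fp8`, `update_shard_fp8`, and the `adam_step` callable are hypothetical stand-ins; `torch.float8_e4m3fn` stands in for the FP8 parameter representation (the real NeMo path uses Transformer Engine FP8 tensors with scaling factors), and the in-place all-gather is shown with `torch.distributed.all_gather_into_tensor`.

```python
import torch
import torch.distributed as dist


def update_shard_non_fp8(param_buffer, shard_slice, adam_step):
    """Non-FP8 path: the local shard is a view into the contiguous
    param buffer, so the Adam kernel writes its output in place and
    the in-place all-gather fills in the other ranks' shards."""
    local_shard = param_buffer[shard_slice]      # view into the buffer, not a copy
    adam_step(out=local_shard)                   # hypothetical kernel wrapper, writes update in place
    dist.all_gather_into_tensor(param_buffer, local_shard)


def update_shard_fp8(fp8_param_buffer, shard_slice, adam_step):
    """FP8 path: the Adam kernel cannot write FP8, so it outputs to a
    temporary FP32 workspace and the result is cast into the FP8
    contiguous buffer before the all-gather."""
    local_shard = fp8_param_buffer[shard_slice]  # FP8 view into the contiguous buffer
    tmp_fp32 = torch.empty(local_shard.shape, dtype=torch.float32,
                           device=local_shard.device)
    adam_step(out=tmp_fp32)                      # kernel writes the FP32 update to the workspace
    local_shard.copy_(tmp_fp32)                  # cast FP32 -> FP8 into the param buffer
    # Schematic: a real implementation may gather the FP8 buffer as raw uint8 bytes.
    dist.all_gather_into_tensor(fp8_param_buffer, local_shard)
```

The only difference between the two paths is where the kernel output lands: directly in the gather buffer for non-FP8 params, versus an FP32 staging buffer followed by a cast for FP8 params.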

timmoon10 commented 10 months ago

👍 Done