https://github.com/NVIDIA/apex/pull/1723 added distributed optimizer (distopt) infrastructure to support FP8 parameters in NeMo, but I found a bug when `contiguous_param_buffer=True`. In the non-FP8 case, the local shards of the updated params are views into the contiguous buffer: the Adam kernel writes its output directly into the buffer, we do in-place all-gathers, and the params are ready for fprop. The FP8 case cannot write directly into the buffer because the Adam kernel doesn't support FP8 outputs, so it should use a temporary buffer instead: the Adam kernel outputs to a temporary FP32 buffer, and we then cast the result to FP8 in the contiguous param buffer.