This PR allows parameters to be configured with different dtypes for the grad reduce-scatters and the param all-gathers. I have two NeMo use-cases in mind:
- For GPT, most grads can be reduced in BF16, but embedding grads need to be reduced in FP32 to avoid learning issues.
- For FP8 support, weight matrices can be stored in FP8 while most other parameters (e.g. biases, layernorm params, embeddings) are in BF16. We would like to handle both FP8 and BF16 param all-gathers in the same optimizer.
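A minimal sketch of the idea behind both use-cases: assign each parameter a (grad-sync dtype, param-sync dtype) pair, then bucket parameters that share a pair so each reduce-scatter or all-gather only has to handle a single dtype. The names `SyncDtypes`, `choose_sync_dtypes`, and `group_by_sync_dtypes` are hypothetical and are not the apex API. Dtypes are plain strings to keep the sketch dependency-free, and the name-based policy is an assumption for illustration only.

```python
from collections import defaultdict
from typing import Dict, List, NamedTuple


class SyncDtypes(NamedTuple):
    """Hypothetical per-parameter communication dtypes."""
    grad_sync: str   # dtype used for the grad reduce-scatter
    param_sync: str  # dtype used for the param all-gather


def choose_sync_dtypes(name: str, is_fp8_weight: bool) -> SyncDtypes:
    """Hypothetical policy combining the two use-cases above."""
    # Use-case 1: embedding grads are reduced in FP32, everything else in BF16.
    grad_sync = "fp32" if "embedding" in name else "bf16"
    # Use-case 2: FP8 weight matrices are all-gathered in FP8, the rest in BF16.
    param_sync = "fp8" if is_fp8_weight else "bf16"
    return SyncDtypes(grad_sync, param_sync)


def group_by_sync_dtypes(params: Dict[str, bool]) -> Dict[SyncDtypes, List[str]]:
    """Bucket parameter names so every collective sees exactly one dtype pair.

    `params` maps a parameter name to whether it is stored as an FP8 weight.
    """
    buckets: Dict[SyncDtypes, List[str]] = defaultdict(list)
    for name, is_fp8_weight in params.items():
        buckets[choose_sync_dtypes(name, is_fp8_weight)].append(name)
    return dict(buckets)


if __name__ == "__main__":
    example = {
        "embedding.weight": False,
        "layer0.linear.weight": True,
        "layer0.linear.bias": False,
        "layer0.layernorm.weight": False,
    }
    for dtypes, names in group_by_sync_dtypes(example).items():
        print(dtypes, names)
```

Grouping by the dtype pair keeps the optimizer logic uniform: mixed dtypes turn into a few homogeneous buckets rather than per-parameter special cases.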
This also includes the changes from https://github.com/NVIDIA/apex/pull/1719, which returns the optimizer state dict on all ranks rather than only on rank 0. We can either merge that PR first and rebase this one, or merge this PR and close https://github.com/NVIDIA/apex/pull/1719.