Why?
The FusedAdamSWA interface was loosely typed and error-prone
The training critical path of FusedAdamSWA (i.e., its step function) could incur an unnecessary GPU-host sync when grad_clip_scale was set to anything other than a CUDA tensor
FusedAdamSWA had no unit tests
What?
Encapsulated the FusedAdamSWA math types and internal numerical type in Python enumerations to improve type robustness and readability (see the first sketch below)
Accepted grad_clip_scale as either a tensor or a number; for the latter we move it to the GPU in a non-blocking manner to eliminate a GPU-host sync (see the second sketch below)
Added a unit test to guarantee numerical correctness and demonstrate usage (see the test sketch below)
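
A minimal sketch of the enumeration approach, assuming two enums (one for the optimizer math variant, one for the internal compute precision); the class names, member names, and values here are illustrative, not the actual FusedAdamSWA definitions:

```python
from enum import Enum


class MathType(Enum):
    # Hypothetical members: which optimizer math the fused kernel runs.
    ADAM = 0
    ADAMW = 1


class InternalDType(Enum):
    # Hypothetical members: the precision used for internal optimizer state.
    FP32 = 0
    FP16 = 1
    BF16 = 2
```

Passing an enum member instead of a bare string or int lets wrong values fail loudly at construction time rather than deep inside the fused kernel.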
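A sketch of how the grad_clip_scale normalization might look; the helper name normalize_grad_clip_scale and its signature are assumptions. Note that a host-to-device copy only proceeds without blocking the host when the source tensor is pinned, hence the pin_memory() call:

```python
import torch
from numbers import Number


def normalize_grad_clip_scale(grad_clip_scale, device):
    # Hypothetical helper: accept either a tensor or a plain number.
    if isinstance(grad_clip_scale, torch.Tensor):
        return grad_clip_scale  # assumed to already live on the target device
    assert isinstance(grad_clip_scale, Number)
    # Wrap the scalar in a pinned CPU tensor and copy it to the GPU
    # asynchronously, so the optimizer step never waits on the transfer.
    cpu_scale = torch.tensor(float(grad_clip_scale)).pin_memory()
    return cpu_scale.to(device, non_blocking=True)
```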
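A sketch of the kind of numerical-correctness test the PR adds; the FusedAdamSWA import path, constructor, and step signature are assumptions, and the SWA-averaging check is omitted for brevity. The fused step is compared against a reference torch.optim.Adam step on identical inputs:

```python
import torch

# Hypothetical import path for the optimizer under test.
from fused_adam_swa import FusedAdamSWA


def test_fused_adam_swa_matches_reference():
    torch.manual_seed(0)
    param = torch.randn(64, device="cuda", requires_grad=True)
    ref_param = param.detach().clone().requires_grad_(True)

    fused = FusedAdamSWA([param], lr=1e-3)  # assumed constructor signature
    ref = torch.optim.Adam([ref_param], lr=1e-3)

    grad = torch.randn_like(param)
    param.grad = grad.clone()
    ref_param.grad = grad.clone()

    fused.step()
    ref.step()

    # Fused and reference updates should agree to floating-point tolerance.
    torch.testing.assert_close(param, ref_param, rtol=1e-6, atol=1e-6)
```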