[Open] cbcase opened this issue 9 months ago
@crcrpar, do you remember why we skip `master_weights` for `bfloat16`?
I'm unsure, but I vaguely remember there wasn't any master-weights usage in fused Adam to begin with, so it feels more like fallout from when the fp16/fp32 handling was added.
If you look at https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_adam.py#L184, the FusedAdam code doesn't allocate master weights when the parameter dtype is bfloat16, even if you set `master_weights=True` (and subsequently no master weights are passed to the kernel at https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_adam.py#L235). Is there a specific reason for this, or is it simply an oversight?
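For reference, the pattern being described is roughly this (a hypothetical sketch, not the actual apex source; the function and variable names are invented): fp32 master copies are built only for fp16 parameters, so bf16 parameters fall through even when `master_weights=True`:

```python
import torch

def sort_params_hypothetical(params, master_weights=True):
    """Hypothetical illustration of the bucketing described above -- not apex code."""
    fp16_params, fp32_from_fp16, bf16_params, fp32_params = [], [], [], []
    for p in params:
        if p.dtype == torch.float16:
            fp16_params.append(p)
            if master_weights:
                # fp32 master copy kept alongside the fp16 weight
                fp32_from_fp16.append(p.detach().clone().float())
        elif p.dtype == torch.bfloat16:
            # no master copy is allocated on this branch, regardless of
            # master_weights -- this is the behavior the issue asks about
            bf16_params.append(p)
        elif p.dtype == torch.float32:
            fp32_params.append(p)
    return fp16_params, fp32_from_fp16, bf16_params, fp32_params
```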
(Seeing as bfloat16 has fewer mantissa bits than fp16 (7 vs. 10 explicit bits), fp32 master weights are even more important for bfloat16 -- though, really, they are a necessity for both.)
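To make the mantissa point concrete, here is a small standalone snippet (plain PyTorch, not apex) showing how a typically sized optimizer step can be rounded away entirely when the weight lives in bf16, while fp16 and fp32 still resolve it near 1.0:

```python
import torch

update = 1e-3  # a plausibly sized optimizer step

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    w = torch.tensor(1.0, dtype=dtype)
    # cast the step to the weight's dtype before adding, as happens when the
    # update is applied directly to a low-precision weight (no master copy)
    stepped = w + torch.tensor(update, dtype=dtype)
    print(dtype, float(stepped))
# torch.float16  -> ~1.00098 (coarse, but the step survives)
# torch.bfloat16 -> 1.0      (the step is rounded away entirely)
# torch.float32  -> ~1.001   (what an fp32 master copy preserves)
```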