NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

FusedAdam doesn't allocate master weights for bfloat16 #1728

Open cbcase opened 9 months ago

cbcase commented 9 months ago

If you look at https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_adam.py#L184, the FusedAdam code does not allocate master weights when the parameter dtype is bfloat16, even if you set master_weights=True, and consequently no master weights are passed to the kernel at https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_adam.py#L235.
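
For anyone skimming, here is a minimal sketch of the per-dtype bucketing behavior described above (plain PyTorch, not the actual apex code; `bucket_params_by_dtype` and its return values are hypothetical names): fp16 parameters get an fp32 master copy when master_weights=True, while bfloat16 parameters fall through without one.

```python
import torch

def bucket_params_by_dtype(params, master_weights=True):
    # Hypothetical illustration of the behavior reported in this issue,
    # not the actual FusedAdam implementation.
    fp16_params, fp16_masters = [], []
    bf16_params, fp32_params = [], []
    for p in params:
        if p.dtype == torch.float16:
            fp16_params.append(p)
            if master_weights:
                # fp16 gets an fp32 master copy...
                fp16_masters.append(p.detach().clone().float())
        elif p.dtype == torch.bfloat16:
            # ...but bf16 does not, even with master_weights=True.
            bf16_params.append(p)
        elif p.dtype == torch.float32:
            fp32_params.append(p)
        else:
            raise RuntimeError(f"unsupported dtype {p.dtype}")
    return fp16_params, fp16_masters, bf16_params, fp32_params
```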

Is there a specific reason for this, or is it simply an oversight?

(Since bfloat16 has fewer mantissa bits than fp16, fp32 master weights are even more important for bfloat16 than for fp16; really, they are a necessity for both.)
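
To make the mantissa point concrete, a standalone example (plain PyTorch, independent of apex) showing a typical small update being rounded away when accumulated directly in bfloat16 but retained in an fp32 master copy:

```python
import torch

w = torch.tensor(1.0, dtype=torch.bfloat16)
update = torch.tensor(1e-3, dtype=torch.bfloat16)

# bf16 has only ~8 bits of significand precision, so 1.0 + 1e-3 rounds
# straight back to 1.0 and the update is silently lost:
print(w + update)       # tensor(1., dtype=torch.bfloat16)

# with an fp32 master copy the update is retained and can keep
# accumulating across steps before being cast back to bf16:
master = w.float()
master += update.float()
print(master)           # tensor(1.0010)
```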

Aidyn-A commented 8 months ago

@crcrpar, do you remember why we skip master_weights for bfloat16?

crcrpar commented 8 months ago

> @crcrpar, do you remember why we skip master_weights for bfloat16?

I'm not sure, but I vaguely remember that fused Adam originally had no master-weights usage at all, so this feels more like fallout from when the fp16-to-fp32 master-weight handling was added.