NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch

Use master weights for bfloat16 FusedAdam when master_weights=True #1731

Open · cbcase opened this pull request 9 months ago

cbcase commented 9 months ago

As mentioned in #1728, the FusedAdam optimizer ignores master_weights=True for bfloat16 parameters. This PR fixes that oversight. I have confirmed that the new behavior matches a "by hand" master-weights implementation (manually maintaining an fp32 copy of each bf16 parameter and hand-copying gradients and updated weights between them) with vanilla torch.optim.AdamW stepping the fp32 copies.
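
For reference, a minimal sketch of what such a by-hand baseline could look like (the model, shapes, and hyperparameters below are illustrative, not taken from the actual test):

```python
import torch

# bf16 model whose parameters are shadowed by fp32 "master" copies
model = torch.nn.Linear(16, 16).to(device="cuda", dtype=torch.bfloat16)

# fp32 master copies; the optimizer only ever sees these
master_params = [p.detach().clone().float() for p in model.parameters()]

# Vanilla AdamW runs entirely in fp32 on the master copies
opt = torch.optim.AdamW(master_params, lr=1e-3)

def optimizer_step():
    # Hand-copy gradients: bf16 grads -> fp32 master grads
    for p, mp in zip(model.parameters(), master_params):
        mp.grad = p.grad.detach().float()
    opt.step()
    opt.zero_grad(set_to_none=True)
    # Hand-copy updated weights: fp32 master -> bf16 model parameters
    with torch.no_grad():
        for p, mp in zip(model.parameters(), master_params):
            p.copy_(mp)
```

With this change, FusedAdam with master_weights=True should maintain the equivalent fp32 copies internally for bf16 parameters, so the manual copying above is no longer needed.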

cbcase commented 8 months ago

Ping @minitu, looks like you added this support originally -- could you take a look? Thanks

minitu commented 8 months ago

LGTM. We only looked at adding master weights for FP16 AMP at the time of the original PR. @crcrpar Could you review this as well?