NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Tkurth/mplamb fixed #1684

Closed · azrael417 closed this 1 year ago

azrael417 commented 1 year ago

This PR fixes the mixed-precision LAMB optimizer: param_groups is not set up until the optimizer's base-class init has been called, so I swapped the order of obtaining the device info and the super-module init. I tested it and it works. This is critical for MLPerf HPC, so please review and merge ASAP.
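
For illustration, here is a minimal sketch of the init-order issue described above; the class body and attribute names are hypothetical, not apex's actual `FusedMixedPrecisionLamb` code:

```python
import torch

class MixedPrecisionLambSketch(torch.optim.Optimizer):
    """Illustrative sketch of the bug: self.param_groups only exists
    after torch.optim.Optimizer.__init__ has run."""

    def __init__(self, params, lr=1e-3):
        defaults = dict(lr=lr)

        # BUG (before the fix): self.param_groups is not set yet,
        # so this line would raise AttributeError:
        # device = self.param_groups[0]['params'][0].device

        # FIX: call the base-class init first, then query the device.
        super().__init__(params, defaults)
        device = self.param_groups[0]['params'][0].device

        # Hypothetical per-device state that motivated the device lookup.
        self._overflow_buf = torch.zeros(1, dtype=torch.int, device=device)
```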