NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch

AdamW implementation does not truly decouple learning rate and weight decay #1849

Open leenachennuru opened 1 month ago

leenachennuru commented 1 month ago

Describe the bug

The AdamW implementation (see here) does not truly decouple the weight decay and learning rate parameters, in contrast to the AdamW paper. This coupling often complicates hyperparameter tuning, since tuning the learning rate also changes the effective WD used to train the model.

The implementation computes the updates as

$w_{t} = (1 - \eta_{\text{effective}} \lambda)\, w_{t-1} - \eta_{\text{effective}} \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$

where $\eta_{\text{effective}} = \eta_t \, \eta_{\text{max}}$, with $\eta_t$ denoting the scheduler multiplier and $\eta_{\text{max}}$ the max/base LR.

This clearly couples LR and WD and is not in line with the paper, which proposes to compute the updates as

$w_{t} = (1 - \eta_t \lambda)\, w_{t-1} - \eta_t \, \eta_{\text{max}} \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$

For easier and more intuitive tuning, it would be useful to enable the completely decoupled version of AdamW via the simple fix $\lambda \leftarrow (\eta_{\text{effective}} / \eta_{\text{max}})\, \lambda$, with updates $w_{t} = (1 - \lambda)\, w_{t-1} - \eta_{\text{effective}} \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$.
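
For illustration, here is a minimal scalar sketch of the two update rules (hypothetical helper names; `m_hat`/`v_hat` stand in for the bias-corrected moments). This is not the apex code, only a sketch of the fix described above:

```python
# Sketch contrasting the current LR-coupled decay with the fully decoupled
# variant proposed above. lr_max is the base LR, lr_sched the scheduler
# multiplier eta_t, wd the user-specified lambda.

def coupled_adamw_step(w, m_hat, v_hat, lr_max, lr_sched, wd, eps=1e-8):
    """Current behaviour: effective decay scales with lr_sched * lr_max."""
    lr_eff = lr_sched * lr_max
    return (1.0 - lr_eff * wd) * w - lr_eff * m_hat / (v_hat ** 0.5 + eps)

def decoupled_adamw_step(w, m_hat, v_hat, lr_max, lr_sched, wd, eps=1e-8):
    """Proposed fix: rescale wd by lr_eff / lr_max so that the decay follows
    only the scheduler (eta_t), not the base learning rate."""
    lr_eff = lr_sched * lr_max
    wd_eff = (lr_eff / lr_max) * wd          # == lr_sched * wd
    return (1.0 - wd_eff) * w - lr_eff * m_hat / (v_hat ** 0.5 + eps)

if __name__ == "__main__":
    w, m_hat, v_hat = 1.0, 0.1, 0.01
    for lr_max in (1e-3, 1e-2):
        # Coupled: changing lr_max changes the effective decay.
        # Decoupled: the decay term stays the same across lr_max.
        print(lr_max,
              coupled_adamw_step(w, m_hat, v_hat, lr_max, 1.0, 0.1),
              decoupled_adamw_step(w, m_hat, v_hat, lr_max, 1.0, 0.1))
```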

Note: This bug also exists in the AdamW implementations in PyTorch and Optax, and has already been highlighted a few times across different papers, libraries, and blogs. More links below for reference.

  1. Mosaic ML Library
  2. Optimi
  3. Paper: How to set AdamW's weight decay as you scale model and dataset size
  4. Fabian Schaipp's blog

timmoon10 commented 1 month ago

For better or for worse, I think "AdamW" now refers to the LR-coupled version. In addition to PyTorch and JAX, I see this formulation in Keras (and therefore TensorFlow), PaddlePaddle, and MXNet. If we implement an LR-decoupled variant, we should give it a new name or make it an opt-in option so we don't confuse users.
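
Until such an option exists, one user-side workaround is possible with the stock `torch.optim.AdamW` (not apex's fused optimizers): pre-divide the desired decay by the base LR. A sketch, assuming a fixed, nonzero base LR and a standard scheduler that rescales `group['lr']` multiplicatively:

```python
# Because torch.optim.AdamW applies p *= (1 - lr_t * weight_decay), passing
# weight_decay = lambda_ / lr_max makes the applied decay (lr_t / lr_max) * lambda_,
# i.e. it follows only the scheduler multiplier, not the base learning rate.
import torch

def decoupled_adamw(params, lr_max, lambda_, **kwargs):
    """Hypothetical helper: AdamW whose effective decay is independent of
    lr_max (assumes lr_max > 0)."""
    return torch.optim.AdamW(params, lr=lr_max,
                             weight_decay=lambda_ / lr_max, **kwargs)

model = torch.nn.Linear(4, 4)
opt = decoupled_adamw(model.parameters(), lr_max=1e-3, lambda_=0.1)
```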

There has been a lot of discussion in other frameworks:

It seems PyTorch deliberately made the decision to use the LR-coupled variant, and that's percolated to the entire ecosystem.

leenachennuru commented 1 month ago

Allowing the user to invoke the fully decoupled version via either option (an opt-in flag or a new name) would be helpful. A couple more references on the potential utility of independent WD are below.

  1. Small-scale proxies for large-scale Transformer training instabilities
  2. A Large-Scale Exploration of µ-Transfer
  3. u-μP: The Unit-Scaled Maximal Update Parametrization