leenachennuru opened 1 month ago
For better or for worse, I think "AdamW" now refers to the LR-coupled version. In addition to PyTorch and JAX, I see this formulation in Keras (and therefore TensorFlow), PaddlePaddle, and MXNet. If we implement an LR-decoupled variant, we should give it a new name or make it an opt-in option so we don't confuse users.
There has been a lot of discussion in other frameworks:
It seems PyTorch deliberately made the decision to use the LR-coupled variant, and that's percolated to the entire ecosystem.
Allowing the user to invoke the fully decoupled version via either option (opt-in or another name) would be helpful. A couple more references on the potential utility of independent WD are below.
Describe the bug
The AdamW implementation (see here) does not truly decouple the weight decay and learning rate parameters in line with the AdamW paper. This coupling often complicates HP tuning, since tuning the learning rate also changes the effective WD used to train the model (for example, halving the base LR also halves the decay applied per step).
The implementation computes the updates as
$w_{t} = (1 - \eta_{\text{effective}} \lambda)\, w_{t-1} - \eta_{\text{effective}}\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$
where $\eta_{\text{effective}} = \eta_t\, \eta_{\text{max}}$, with $\eta_t$ denoting the scheduler multiplier and $\eta_{\text{max}}$ the max/base LR.
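To make the coupling concrete, here is a minimal NumPy sketch of that update (hypothetical function and argument names, not the library's actual code). The schedule-scaled learning rate multiplies both the Adam step and the decay term:

```python
import numpy as np

def adamw_coupled_step(w, m_hat, v_hat, eta_t, eta_max, weight_decay, eps=1e-8):
    """LR-coupled AdamW step: eta_t is the schedule multiplier, eta_max the max/base LR."""
    eta_eff = eta_t * eta_max
    adam_step = m_hat / (np.sqrt(v_hat) + eps)
    # The decay term is scaled by eta_eff = eta_t * eta_max, so retuning the
    # base LR also changes the effective weight decay applied per step.
    return (1.0 - eta_eff * weight_decay) * w - eta_eff * adam_step
```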
This clearly couples the LR and WD and is not in line with the paper, which proposes to compute the update as
$w_{t} = (1 - \eta_t \lambda)\, w_{t-1} - \eta_t\, \eta_{\text{max}}\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$
For easier and more intuitive tuning, it would be useful to enable the completely decoupled version of AdamW via a simple fix: rescale the decay as $\lambda \leftarrow (\eta_{\text{effective}} / \eta_{\text{max}})\, \lambda$ and compute the update as $w_{t} = (1 - \lambda)\, w_{t-1} - \eta_{\text{effective}}\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$.
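A minimal sketch of that fix, continuing the hypothetical NumPy example above: rescaling $\lambda$ by $\eta_{\text{effective}} / \eta_{\text{max}} = \eta_t$ makes the decay follow the schedule while staying independent of the base LR, recovering the paper's update.

```python
import numpy as np

def adamw_decoupled_step(w, m_hat, v_hat, eta_t, eta_max, weight_decay, eps=1e-8):
    """Decoupled AdamW step: the decay depends on the schedule only, not on eta_max."""
    eta_eff = eta_t * eta_max
    lam = (eta_eff / eta_max) * weight_decay  # == eta_t * weight_decay
    adam_step = m_hat / (np.sqrt(v_hat) + eps)
    # Equivalent to (1 - eta_t * lambda) * w - eta_t * eta_max * adam_step,
    # i.e. the update proposed in the AdamW paper.
    return (1.0 - lam) * w - eta_eff * adam_step
```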
Note: This bug also exists in the AdamW implementations in PyTorch and Optax and has already been highlighted a few times across different papers, libraries, and blogs. More links below for reference.