jettify / pytorch-optimizer

torch-optimizer -- a collection of optimizers for PyTorch
Apache License 2.0

Spotted a bug in the SGDW implementation (weight decay part) #454

Open Leiay opened 2 years ago

Leiay commented 2 years ago

Hi,

I was using the SGDW implementation in this repo, and I wonder whether something is wrong with this line:

https://github.com/jettify/pytorch-optimizer/blob/910b414565427f0a66e20040475e7e4385e066a5/torch_optimizer/sgdw.py#L121

Let the weight decay be $\lambda$ and the learning rate be $\mu_t$. If I understand it correctly, this line of code applies weight decay as $$\theta_t \leftarrow \tilde{\theta}_t - \lambda \mu_t$$ where (following the notation in the paper)

$$\tilde{\theta}_t \leftarrow \theta_{t-1} - m_t$$
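In torch terms, if I read the line correctly, its effect is equivalent to the following (my reading of the update, not the repo's literal code):

```python
# Effect of the linked line, per the equation above: subtract the
# scalar lr * weight_decay from every entry, after the momentum step.
p.data.add_(-group["lr"] * weight_decay)
```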

But it should be

$$
\begin{aligned}
\theta_{t-1} &\leftarrow \theta_{t-1} \cdot (1 - \lambda \mu_t) \\
\theta_t &\leftarrow \theta_{t-1} - m_t
\end{aligned}
$$

as in the paper:

[screenshot of the SGDW algorithm (decoupled weight decay) from the paper]
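A minimal sketch of what I would expect the step to look like instead (assuming the usual `torch.optim.SGD`-style step body; `p`, `d_p`, and `group` are the conventional names, not necessarily the repo's):

```python
# Decoupled weight decay: shrink theta_{t-1} *before* the momentum step.
if weight_decay != 0:
    p.data.mul_(1 - group["lr"] * weight_decay)

# Momentum step: theta_t <- theta_{t-1} - m_t (d_p holds the momentum buffer).
p.data.add_(d_p, alpha=-group["lr"])
```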

This results in poorer training performance compared to SGD with the same set of optimization hyper-parameters.
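For a quick way to see the discrepancy, here is a hypothetical one-step check (zero gradient, so the whole update should come from weight decay alone; the expected values in the comments follow the equations above, not any run I am claiming):

```python
import torch
import torch_optimizer as optim

# Single parameter with value 2.0 and zero gradient: the update is decay-only.
p = torch.nn.Parameter(torch.full((3,), 2.0))
opt = optim.SGDW([p], lr=0.1, momentum=0.9, weight_decay=0.5)

p.grad = torch.zeros_like(p)
opt.step()

# Decoupled decay predicts 2.0 * (1 - 0.1 * 0.5) = 1.90 per entry;
# subtracting the constant lambda * mu would instead give 2.0 - 0.05 = 1.95.
print(p.data)
```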

Thanks!

Regards, Liu