Closed nestordemeure closed 2 years ago

Hello,

While translating your optimizer to Flax (here), I noticed that you are using traditional weight decay, where you add the weight decay to the gradient (here in your implementation), rather than an AdamW-style weight decay (which, I believe, is now the default for most optimizers), where you would subtract the weight decay times the learning rate just before returning the parameters.

Is there a particular reason for that decision?
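For concreteness, here is a minimal sketch of the two forms being contrasted (the function and variable names are illustrative, not MADGRAD's actual code):

```python
import torch

def traditional_decay(grad: torch.Tensor, param: torch.Tensor, wd: float) -> torch.Tensor:
    # Traditional (L2-regularization-style) weight decay: fold wd * param into
    # the gradient, so it flows through the optimizer's statistics and preconditioning.
    return grad.add(param, alpha=wd)

def adamw_style_decay(param: torch.Tensor, lr: float, wd: float) -> None:
    # AdamW-style decoupled decay: shrink the parameters directly, scaled by
    # the learning rate, independently of the gradient statistics.
    param.mul_(1 - lr * wd)  # equivalent to param -= lr * wd * param
```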
There is no particular reason why we currently use the traditional weight decay form. I haven't experimented with AdamW-style decay with MADGRAD yet. If you find it works on your problem, let me know!
I tried the following form on a personal dataset (regression on tabular data):
```python
updated_parameters -= (1. - beta) * power(learning_rate, 2./3.) * weight_decay * param
```
And got better results than what I had with the default weight decay method (but that might be due to parameter tuning, or it might not generalize to other datasets).
The multiplication by (1. - beta) * power(learning_rate, 2./3.), rather than learning_rate, is an effort to use the actual step size rather than the raw learning rate (lr / cbrt(lr) = lr^(2/3)). With that scaling, the weight decay I had previously tuned for Adam seemed to work best, which is practical.
Yes, that's a good idea in terms of weighting by the learning rate. I had considered that; however, if you use a changing learning rate over time it will result in odd behavior. E.g., if you decrease the learning rate 10x midway through training, as is common for ImageNet, the decay term won't decrease 10x in practice. But it works for scaling the initial learning rate.
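To make the scaling issue concrete: a 10x learning-rate drop only changes the lr^(2/3) factor by 10^(2/3) ≈ 4.64x.

```python
# A 10x learning-rate decrease (0.1 -> 0.01) scales lr**(2/3) by only ~4.64x:
print(0.1 ** (2 / 3) / 0.01 ** (2 / 3))  # 10**(2/3) ~= 4.6416, not 10
```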
A friend just tested both the default weight decay and an AdamW-style weight decay for image classification. He found that with the default weight decay he got no improvement (even with low values), whereas he got his best test score so far with the AdamW-style weight decay.
Overall it seems worth using.
I will look into adding the AdamW-style weight decay as an option, thanks for the discussion and results!
```python
updated_parameters -= (1. - beta) * power(learning_rate, 2./3.) * weight_decay * param
```
@nestordemeure Would you mind helping me out with trying to implement this in the PyTorch version here? I believe I've got everything in place, but what would beta be here? So far, I couldn't find an appropriate equivalent variable in the original MADGRAD implementation.
beta is momentum in this implementation (here). I called it beta in my own code to stay consistent with the usual naming scheme.
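Concretely, a sketch of how the proposal could be written inside the PyTorch step() (assuming the parameter group carries momentum, lr, and weight_decay as in MADGRAD's constructor; this is an illustration, not committed code):

```python
# Inside step(), per parameter p in a group (sketch):
momentum = group["momentum"]   # plays the role of `beta` above
lr = group["lr"]
decay = group["weight_decay"]

if decay != 0:
    # Decoupled decay scaled by the effective step size (1 - momentum) * lr**(2/3),
    # i.e. p -= (1 - momentum) * lr**(2/3) * decay * p.
    p.data.mul_(1 - (1 - momentum) * lr ** (2 / 3) * decay)
```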
Perfect, thanks for the quick reply!
I'm going to add AdamW-style weight decay to the implementation this week, as it seems popular based on the comments here.
Any updates on this? :)
I'm looking into this now; it's not actually clear what the correct way is to do decoupled weight decay within a dual averaging framework. I don't want to commit code until I'm sure it's correct.
I'm currently testing adding an update similar to:
```python
p.data.div_(lr ** (2 / 3) * (k + 1) * decay + 1)
```
after the p.data update, with line 119 removed. This is an explicit type of weight decay, slightly different from AdamW but better suited to the dual averaging framework. I need to run some experiments and make sure it works on the standard test problems before I commit the code. It needs to handle a changing learning rate during optimization, so I'm actually using an accumulating sum of learning rates.
It's in branch decoupleddecay if you want to try it out.
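For anyone reading along, a rough sketch of what that accumulated form could look like (the name lamb_sum is hypothetical; check the decoupleddecay branch for the real code):

```python
# Accumulate the per-step effective scale so the decay tracks lr schedules.
lamb = lr ** (2 / 3)
state["lamb_sum"] = state.get("lamb_sum", 0.0) + lamb

# Applied after the p.data update; with a constant lr this reduces to
# dividing by lr**(2/3) * (k + 1) * decay + 1, as in the snippet above.
p.data.div_(state["lamb_sum"] * decay + 1)
```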
I've switched back to the simplest form, essentially @russelldc's suggestion but without the lr^(2/3) power. The 2/3 power gives the correct scaling at the beginning of training; however, after later learning-rate decreases it will scale the decay the wrong way. It's best to adjust your decay beforehand using the 2/3 correction rather than have it in the code.
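A sketch of that simplest form, assuming the same names as the snippet above (see the repository for the committed version):

```python
# Plain decoupled decay scaled by the raw lr, applied to the parameters
# instead of being added to the gradient: p -= lr * decay * p.
if decay != 0:
    p.data.mul_(1 - lr * decay)
```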
This is perhaps a silly question, but why is decouple_decay not added to the defaults dict?
That would probably be a better way to do it, I'll make that change when I have the time this week.
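For reference, the standard torch.optim pattern would look something like this sketch (assuming MADGRAD's usual constructor arguments):

```python
defaults = dict(
    lr=lr,
    eps=eps,
    momentum=momentum,
    weight_decay=weight_decay,
    decouple_decay=decouple_decay,  # per-group and serialized with state dicts
)
super().__init__(params, defaults)
```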