facebookresearch / schedule_free

Schedule-Free Optimization in PyTorch
Apache License 2.0

Where should beta2 bias correction be incorporated? #38

Open jondeuce opened 1 month ago

jondeuce commented 1 month ago

There are, of course, several more-or-less equivalent ways to perform beta2 bias correction. The paper describes the standard method, which differs slightly from the implementations in this repository, so I wanted to enumerate the options and ask whether, in the authors' opinion, there is a reason to prefer one over the others:

  1. The standard method described in the paper: divide the EMA of the second moments by 1 - beta2^t. Equivalently, the bias-corrected EMA can be updated directly with an iteration-dependent beta2_t = beta2 * (1 - beta2^(t-1)) / (1 - beta2^t).

  2. In this library, a sqrt(1 - beta2^t) correction is absorbed into the learning rate. This differs slightly from (1), because the factor then also multiplies the weight-decay term and eps effectively becomes eps / sqrt(1 - beta2^t). It is approximately equivalent to (1) if lambda == 0, exactly equivalent if lambda == eps == 0, and can be interpreted as "weight decay warmup" when lambda > 0. It also affects the lr_max and weight_sum state variables, so the ckp1 schedule is slightly different:

https://github.com/facebookresearch/schedule_free/blob/38109d045a29b8463d496589b622d1771a1cd5ac/schedulefree/adamw_schedulefree.py#L126-L137

  3. Here the sqrt(1 - beta2^t) factor is absorbed into the learning rate, but only after ckp1 is computed, i.e. the same as (2) but with ckp1 matching the standard method (1):

https://github.com/facebookresearch/schedule_free/blob/38109d045a29b8463d496589b622d1771a1cd5ac/schedulefree/algoperf/external_tuning/schedule_free_adamw/submission.py#L80-L95
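A minimal scalar sketch of the options above, written in plain Python rather than the library's PyTorch code. Everything in it is an assumption for illustration: a single scalar parameter `p`, an AdamW-style step `lr * (g / (sqrt(v_hat) + eps) + lam * p)` with weight decay `lam`, and, for `ckp1`, a hypothetical averaging weight `lr_max ** power` normalized by a running `weight_sum` (the linked code may weight iterates differently). It checks that the iteration-dependent `beta2_t` reproduces the divided EMA, that folding `sqrt(1 - beta2^t)` into the learning rate gives the same step when `eps == lam == 0` but shrinks the weight-decay term otherwise, and how computing `ckp1` from the corrected versus uncorrected learning rate changes the averaging coefficients.

```python
import math


def ema_divided(grads, beta2):
    """Option (1), first form: EMA of g^2 divided by 1 - beta2^t."""
    v, out = 0.0, []
    for t, g in enumerate(grads, start=1):
        v = beta2 * v + (1 - beta2) * g * g
        out.append(v / (1 - beta2 ** t))
    return out


def ema_iteration_dependent(grads, beta2):
    """Option (1), second form: iteration-dependent beta2_t, no division."""
    v_hat, out = 0.0, []
    for t, g in enumerate(grads, start=1):
        beta2_t = beta2 * (1 - beta2 ** (t - 1)) / (1 - beta2 ** t)
        v_hat = beta2_t * v_hat + (1 - beta2_t) * g * g
        out.append(v_hat)
    return out


def step_standard(g, v, p, t, lr, beta2, eps, lam):
    """Option (1): bias-correct v, then take an AdamW-style step (assumed form)."""
    v_hat = v / (1 - beta2 ** t)
    return lr * (g / (math.sqrt(v_hat) + eps) + lam * p)


def step_absorbed(g, v, p, t, lr, beta2, eps, lam):
    """Option (2) and the third variant: fold sqrt(1 - beta2^t) into lr.
    This also rescales the eps term and the weight-decay term lam * p."""
    lr_t = lr * math.sqrt(1 - beta2 ** t)
    return lr_t * (g / (math.sqrt(v) + eps) + lam * p)


def ckp1_sequence(lrs, power=2.0):
    """Hypothetical averaging coefficients ckp1 = weight / weight_sum,
    with weight = lr_max ** power (an assumed weighting; the real code may differ)."""
    lr_max, weight_sum, out = 0.0, 0.0, []
    for lr in lrs:
        lr_max = max(lr_max, lr)
        weight = lr_max ** power
        weight_sum += weight
        out.append(weight / weight_sum)
    return out


beta2, lr, warmup = 0.999, 1e-3, 10
grads = [0.5, -1.2, 0.3, 2.0, -0.7]

# Both forms of option (1) agree up to floating-point rounding.
a, b = ema_divided(grads, beta2), ema_iteration_dependent(grads, beta2)
print(all(math.isclose(x, y, rel_tol=1e-9) for x, y in zip(a, b)))

# Linear warmup schedule, then ckp1 with and without the sqrt(1 - beta2^t) factor.
sched = [lr * min(1.0, t / warmup) for t in range(1, 21)]
corrected = [s * math.sqrt(1 - beta2 ** t) for t, s in enumerate(sched, start=1)]
ckp1_opt2 = ckp1_sequence(corrected)  # option (2): ckp1 from the corrected lr
ckp1_opt3 = ckp1_sequence(sched)      # third variant: ckp1 computed before correction
```

Under these assumptions, option (2) and the third variant take the same parameter step and differ only in the learning rate used for `ckp1`: both coefficient sequences start at 1, then diverge while `beta2**t` is still non-negligible.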

Intuitively, I would prefer the standard method (1), and it probably doesn't matter much in practice, since beta2^t decays to zero exponentially fast (though with the default beta2 = 0.999 this can still take a few thousand iterations). But perhaps you have found otherwise?
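To put a number on "a few thousand iterations": the smallest t with beta2^t below a tolerance tol is the first integer strictly greater than log(tol) / log(beta2). A small sketch; the helper name is hypothetical.

```python
import math


def steps_until_negligible(beta2, tol):
    """Hypothetical helper: smallest integer t with beta2**t < tol,
    i.e. the first integer strictly above log(tol) / log(beta2)."""
    return math.floor(math.log(tol) / math.log(beta2)) + 1


for tol in (0.1, 0.01, 0.001):
    print(tol, steps_until_negligible(0.999, tol))
```

With the default beta2 = 0.999, beta2^t falls below 0.01 only after about 4,600 steps, in line with the "few thousand iterations" estimate.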

adefazio commented 1 month ago

This is a really interesting technical point. I've also seen it implemented several different ways in different libraries. My feeling is that none of the forms is likely to make much difference in practice, although the differing weight-decay behavior at the very beginning of training could conceivably send the optimizer to different parts of the parameter space. I am open to changing the behavior if there is any empirical evidence that one form performs differently from the others. I should definitely note in the paper that the implemented form differs slightly from the described form, or perhaps change the pseudo-code so that they match.

jondeuce commented 1 month ago

Okay, great, that all aligns with my thinking. I don't have any empirical evidence either way; I switched my implementation to the standard method described in the paper and saw no discernible difference. My linear warmup periods are typically several thousand iterations, though, so the beta2^t effects are probably not very important: they decay away before the learning rate gets large anyway.
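One rough way to check this reasoning, as a sketch under stated assumptions: under option (2) the weight-decay term is scaled by sqrt(1 - beta2^t) relative to option (1), so the gap between the two effective weight-decay step sizes, as a fraction of the peak lr * lambda, is warm(t) * (1 - sqrt(1 - beta2^t)) for a warmup factor warm(t). The code assumes a plain linear warmup, ignores the eps difference, and takes the worst case over training; the function name is hypothetical.

```python
import math


def max_weight_decay_gap(warmup_steps, beta2, total_steps):
    """Hypothetical helper: worst-case gap, as a fraction of peak lr * lambda,
    between the weight-decay multipliers of option (1) and option (2).

    Assumes a linear warmup lr_t = peak_lr * min(1, t / warmup_steps)
    and ignores the eps term.
    """
    worst = 0.0
    for t in range(1, total_steps + 1):
        warm = min(1.0, t / warmup_steps) if warmup_steps > 0 else 1.0
        # Option (1) uses warm; option (2) uses warm * sqrt(1 - beta2^t).
        gap = warm * (1.0 - math.sqrt(1.0 - beta2 ** t))
        worst = max(worst, gap)
    return worst


print(max_weight_decay_gap(0, 0.999, 10_000))      # no warmup
print(max_weight_decay_gap(5_000, 0.999, 10_000))  # 5,000-step linear warmup
```

Without warmup the gap is nearly the full peak value at t = 1, since sqrt(1 - 0.999) is about 0.03; with a 5,000-step linear warmup it stays below 5% of the peak, which is consistent with the beta2^t transient having mostly passed by the time the learning rate is large.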