This is a really interesting technical point. I've seen it implemented several different ways in different libraries also. My feeling is that it's unlikely to make much of a difference in practice for any of the forms. Although the differing behavior with weight decay at the very beginning could conceivably result in ending up in different parts of the parameter space. I am open to changing the behavior if there is any empirical evidence that one form is different than the other. I should definitely note in the paper that the implemented form is slightly different than the described form, or perhaps change the pseudo-code so they match.
Okay great, that all aligns with my thinking. I don't have any empirical evidence in either direction; I switched my implementation to use the standard method as described in the paper and saw no discernible difference. My linear warmup periods are typically several thousand iterations, though, so probably `beta2^t` effects are not so important since they decay away before the learning rate is very large anyway.
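For a rough sense of the timescale, here's a quick back-of-the-envelope check (the thresholds are arbitrary; `beta2 = 0.999` is just the default):

```python
import math

# How many steps until beta2^t decays below a given threshold?
beta2 = 0.999
for threshold in (0.5, 0.1, 0.01):
    t = math.log(threshold) / math.log(beta2)
    print(f"beta2^t < {threshold} after ~{t:,.0f} steps")
# beta2^t < 0.5 after ~693 steps
# beta2^t < 0.1 after ~2,301 steps
# beta2^t < 0.01 after ~4,603 steps
```

So the transient is mostly gone a couple thousand steps in, i.e. within a several-thousand-step warmup.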
There are of course several equivalent-ish ways that `beta2` bias correction can be performed. The paper describes the standard method, which is slightly different than the various implementation(s) here, so I wanted to enumerate the options and see if there's a reason to prefer one way or another in the authors' opinion:

1. The standard method described in the paper: divide the EMA of the second moments by `1 - beta2^t`. This can also be implemented using an iteration-dependent `beta2` via `beta2_t = beta2 * (1 - beta2^(t-1)) / (1 - beta2^t)`.
2. In this library, a `sqrt(1 - beta2^t)` correction is absorbed into the learning rate. This is a bit different than above, though it is approximately equivalent if `lambda == 0`, exactly equivalent if `lambda == eps == 0`, and can be interpreted as "weight decay warmup" when `lambda > 0`. This also impacts the `lr_max` and `weight_sum` state variables, and therefore the `ckp1` schedule is slightly different: https://github.com/facebookresearch/schedule_free/blob/38109d045a29b8463d496589b622d1771a1cd5ac/schedulefree/adamw_schedulefree.py#L126-L137
3. The `sqrt(1 - beta2^t)` factor is absorbed into the learning rate, but only after computing `ckp1`, i.e. the same as (2) but with `ckp1` matching the standard method (1): https://github.com/facebookresearch/schedule_free/blob/38109d045a29b8463d496589b622d1771a1cd5ac/schedulefree/algoperf/external_tuning/schedule_free_adamw/submission.py#L80-L95
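To make the difference concrete, here is a minimal scalar sketch of (1) versus (2)/(3). This is my own paraphrase, not the library's code; `v_t` is the usual EMA of squared gradients, the `beta1`/first-moment handling is omitted, and the tiny `v_t` in the example is chosen deliberately so that `eps` matters:

```python
import math

def per_step_scale(v_t, t, lr, beta2=0.999, eps=1e-8):
    """Per-unit-gradient step size under the two placements of the beta2
    bias correction, where v_t = beta2 * v_{t-1} + (1 - beta2) * g_t**2."""
    bc2 = 1.0 - beta2 ** t

    # (1) Standard: debias the second-moment estimate, leave lr alone.
    standard = lr / (math.sqrt(v_t / bc2) + eps)

    # (2)/(3): leave v_t biased and fold sqrt(bc2) into the learning rate.
    # Expanding: lr * sqrt(bc2) / (sqrt(v_t) + eps)
    #          = lr / (sqrt(v_t / bc2) + eps / sqrt(bc2)),
    # i.e. form (1) with eps inflated by 1/sqrt(bc2) at early steps,
    # so the two agree exactly only when eps == 0.
    absorbed = lr * math.sqrt(bc2) / (math.sqrt(v_t) + eps)

    return standard, absorbed

print(per_step_scale(v_t=1e-16, t=1, lr=1e-3))       # differ by ~2x early on
print(per_step_scale(v_t=1e-16, t=10_000, lr=1e-3))  # essentially identical
```

With `lambda > 0`, the same absorbed factor multiplies the decoupled weight decay term, which is where the "weight decay warmup" interpretation of (2) comes from.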
Intuitively I would prefer to use the standard method (1), and probably it doesn't really matter which in practice since `beta2^t` approaches zero exponentially quickly (though this can still take a few thousand iterations with default `beta2 = 0.999`). But maybe you have found otherwise?
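For reference, here is also a toy version of the `ckp1` recurrence to illustrate the (2)-vs-(3) distinction. The `lr_max`/`weight`/`weight_sum` updates are a simplified paraphrase of the linked code; the constant base `lr` and `weight_lr_power == 2` are assumptions for illustration, not necessarily the library's exact defaults:

```python
import math

def ckp1_schedule(num_steps, base_lr=1e-3, beta2=0.999, absorb_first=True):
    """Form (2): fold sqrt(1 - beta2^t) into lr before updating lr_max and
    weight_sum. Form (3): compute ckp1 from the uncorrected lr."""
    lr_max = weight_sum = 0.0
    ckp1s = []
    for t in range(1, num_steps + 1):
        bc2 = 1.0 - beta2 ** t
        lr = base_lr * math.sqrt(bc2) if absorb_first else base_lr
        lr_max = max(lr_max, lr)
        weight = lr_max ** 2  # assuming weight_lr_power == 2
        weight_sum += weight
        ckp1s.append(weight / weight_sum)
    return ckp1s

print(ckp1_schedule(3, absorb_first=True))   # form (2): [1.0, ~0.667, ~0.500]
print(ckp1_schedule(3, absorb_first=False))  # form (3): [1.0, 0.5, ~0.333]
```

The early `ckp1` values differ slightly, which is exactly the part of (2) that option (3) reverts to the standard schedule.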