konstmish / prodigy

The Prodigy optimizer and its variants for training neural networks.
MIT License

Possible to marry Prodigy and AdamW? #11

Open · askerlee opened 8 months ago

askerlee commented 8 months ago

I've been using Prodigy for a few days and honestly I'm very impressed by its performance. In particular, I can set a large learning rate (lr=1, d_coef=10) without blowing up the gradients. However, the final stage of training seems suboptimal and hard to tune. I need a scheduler to reduce the final LR to 0.1, but I'm not sure whether this reduces the effectiveness of Prodigy. After training gets going, it seems Prodigy and AdamW become somewhat complementary in their advantages/limitations. I wonder whether AdamW could make use of the statistics estimated by Prodigy, and if so, whether it would be advantageous to continue training with AdamW instead?

askerlee commented 8 months ago

I implemented a very rough version of this idea. The pseudocode:

# Initialize a Prodigy instance and an AdamW instance with the same params,
# and pass the AdamW instance into the Prodigy wrapper.
class ProdigyAdamW(torch.optim.Optimizer):
    def __init__(self, params, adamw_optimizer=None, prodigy_weight=0.5, ...):
        ......
        # Reduce the LR to leave room for the joint update with AdamW.
        defaults = dict(lr=lr * prodigy_weight, betas=betas, beta3=beta3,
                        eps=eps, weight_decay=weight_decay,
                        d=d0, d0=d0, d_max=d0,
                        d_numerator=0.0, d_coef=d_coef,
                        k=0, growth_rate=growth_rate,
                        use_bias_correction=use_bias_correction,
                        decouple=decouple, safeguard_warmup=safeguard_warmup,
                        fsdp_in_use=fsdp_in_use)
        super().__init__(params, defaults)
        # Keep a handle on the AdamW instance so step() can drive both optimizers.
        self.adamw_optimizer = adamw_optimizer

    def step(self, closure=None):
        loss = self.prodigy_step(closure)
        # NOTE: weight decay is not handled properly: the Prodigy step modifies p.grad,
        # so AdamW sees the altered gradients. A better approach would be to save p.grad
        # before the Prodigy update and restore it afterwards, before the AdamW update.
        self.adamw_optimizer.step()
        return loss

Since it's a joint update by Prodigy and AdamW, I halve the LRs of both Prodigy and AdamW. I'm still evaluating this hybrid optimizer to see if it brings any benefits 😺
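
A minimal usage sketch of how the two optimizers are wired together, assuming the rough ProdigyAdamW wrapper above (the model, LR values, and weight decay are placeholders, and any constructor arguments beyond the pseudocode's signature are assumptions):

import torch
from torch.optim import AdamW

model = torch.nn.Linear(16, 16)                        # placeholder model
# AdamW gets half of its usual LR so the combined update isn't doubled.
adamw = AdamW(model.parameters(), lr=5e-5, weight_decay=1e-2)
# Inside the wrapper, Prodigy's lr is likewise scaled by prodigy_weight.
opt = ProdigyAdamW(model.parameters(), adamw_optimizer=adamw,
                   prodigy_weight=0.5, lr=1.0, d_coef=10)

loss = model(torch.randn(8, 16)).pow(2).mean()         # dummy loss
loss.backward()
opt.step()                 # Prodigy step followed by the AdamW step
opt.zero_grad()            # both optimizers share the same parameter tensors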

askerlee commented 8 months ago

An update: I've tried the hybrid optimizer on my method and it seems no better than using Prodigy alone. The best Prodigy setting among my experiments: for the first half of training, set LR=1. For the second half, use PolynomialLR with power=1 (i.e., the LR decreases linearly) and total_iters=second_half_steps * 1.1, so that the final LR is about 0.09 (if total_iters=second_half_steps, the final LR would be 0, which seems suboptimal). This schedule can be implemented with a SequentialLR on top of a ConstantLR and a PolynomialLR.
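
A minimal sketch of that schedule (the toy model and step counts are placeholders; Prodigy is imported from the prodigyopt package):

import torch
from prodigyopt import Prodigy
from torch.optim.lr_scheduler import ConstantLR, PolynomialLR, SequentialLR

model = torch.nn.Linear(16, 16)                        # placeholder model
total_steps = 2000                                     # placeholder step count
first_half = total_steps // 2
second_half = total_steps - first_half

optimizer = Prodigy(model.parameters(), lr=1.0)
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        # LR multiplier stays at 1 for the first half.
        ConstantLR(optimizer, factor=1.0, total_iters=first_half),
        # Linear decay (power=1); total_iters is 10% longer than the second half,
        # so training ends at a multiplier of 1 - 1/1.1 ~= 0.09 instead of 0.
        PolynomialLR(optimizer, total_iters=int(second_half * 1.1), power=1.0),
    ],
    milestones=[first_half],
)

for step in range(total_steps):
    loss = model(torch.randn(8, 16)).pow(2).mean()     # dummy loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()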

askerlee commented 8 months ago

Update 2:

I found that using CyclicLR with two cycles (the solid red line) yields more stable and slightly better results than using only one decay cycle (the dashed red line). My training method consists of multiple losses/regularizations. It seems a high LR is mostly beneficial to the reconstruction loss (= subject fidelity), while a low LR is beneficial to the other regularization terms. Therefore, cycling between the two balances these two types of losses.

The dashed red line is the schedule above, which sometimes leads to better subject fidelity but sometimes is much worse. I guess this is because the LR stays low in the later stage, so the learned reconstruction ability (subject fidelity) is sometimes partly forgotten. Running two cycles may give the model another chance to relearn the concept.

[image: LR schedule plot; solid red = two-cycle CyclicLR, dashed red = single decay cycle]
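
For reference, a rough sketch of setting up a two-cycle schedule with CyclicLR on top of Prodigy (the LR bounds, step counts, and toy model are illustrative rather than the exact values behind the plot; note that cycle_momentum must be disabled because Prodigy, like Adam, has no momentum parameter group):

import torch
from prodigyopt import Prodigy
from torch.optim.lr_scheduler import CyclicLR

model = torch.nn.Linear(16, 16)                        # placeholder model
total_steps = 2000                                     # placeholder step count
cycle_len = total_steps // 2                           # two full cycles over the run

optimizer = Prodigy(model.parameters(), lr=1.0)
scheduler = CyclicLR(
    optimizer,
    base_lr=0.1, max_lr=1.0,           # with Prodigy, "lr" is a multiplier on the estimated d
    step_size_up=cycle_len // 2,
    step_size_down=cycle_len - cycle_len // 2,
    mode="triangular",
    cycle_momentum=False,              # required for optimizers without a 'momentum' group
)

for step in range(total_steps):
    loss = model(torch.randn(8, 16)).pow(2).mean()     # dummy loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()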

adefazio commented 8 months ago

Thanks for the details! I haven't tried cyclic schedules much; I usually use a simple linear decay, similar to the dashed line but without the offset.

askerlee commented 8 months ago

@adefazio Hope you find it helpful! By the offset, do you mean the final LR = 0.09, or the initial 500 steps of LR=1? The initial 500 steps let Prodigy better estimate D (Prodigy likes a smoothly varying LR, and a constant LR is even better), so I'd like to call them "warm-up steps" 😄 (the opposite of AdamW, whose warm-up steps use small LRs so it can estimate the gradient statistics more accurately without going too far in a wrong direction).

adefazio commented 8 months ago

@askerlee The initial 500 steps. That explanation makes sense though, they could help with estimating the learning rate.

rafstahelin commented 5 months ago

> @adefazio Hope you find it helpful! By the offset, do you mean the final LR = 0.09, or the initial 500 steps of LR=1? The initial 500 steps let Prodigy better estimate D (Prodigy likes a smoothly varying LR, and a constant LR is even better), so I'd like to call them "warm-up steps" 😄 (the opposite of AdamW, whose warm-up steps use small LRs so it can estimate the gradient statistics more accurately without going too far in a wrong direction).

Could you share a JSON of your training args? My dataset of photographic humans is roughly 30 images for the subject and 500 for the style. I did some initial testing of Prodigy but would love to have another go at it with your template.

JDvorak commented 5 months ago

Along similar lines, there's a schedule-free AdamW at https://github.com/facebookresearch/schedule_free that seems interesting and comes at this from another angle.
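
In case it's useful, a minimal usage sketch based on that repo's README (the toy model and lr are placeholders); the main difference from a regular AdamW loop is that the optimizer itself has train/eval modes:

import torch
import schedulefree                                    # pip install schedulefree

model = torch.nn.Linear(16, 16)                        # placeholder model
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

optimizer.train()                                      # put the optimizer (not just the model) in train mode
for step in range(100):
    loss = model(torch.randn(8, 16)).pow(2).mean()     # dummy loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

optimizer.eval()                                       # switch before evaluation or checkpointing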