jettify / pytorch-optimizer

torch-optimizer -- collection of optimizers for Pytorch
Apache License 2.0

Example of how to use scheduler #404

Open brando90 opened 2 years ago

brando90 commented 2 years ago

I'd like to use an Adafactor scheduler like the one in the Hugging Face code (their code does not work for CNNs, though).

Questions:
a) How do I use schedulers with pytorch-optimizer?
b) Can we add an example to the README?
c) Does Adafactor have its own scheduler here?

related:

brando90 commented 2 years ago

(By the way, might it be useful to enable GitHub's Discussions feature for this repo?)

brando90 commented 2 years ago

Would this work?

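```python
from torch.optim.lr_scheduler import LambdaLR


class AdafactorSchedulerUU(LambdaLR):
    """
    Proxy scheduler adapted from Hugging Face's ``AdafactorSchedule``.

    Meant for logging: it reports the step size Adafactor computes internally
    (``relative_step=True``) and writes that value into ``group["lr"]``,
    which Adafactor ignores in that mode.

    TODO: check that this works with torch_optimizer's Adafactor (the
    Hugging Face original targets the transformers Adafactor).
    """

    def __init__(self, optimizer, initial_lr=0.0):
        def lr_lambda(_):
            return initial_lr

        # LambdaLR builds base_lrs from each group's "initial_lr".
        for group in optimizer.param_groups:
            group["initial_lr"] = initial_lr

        super().__init__(optimizer, lr_lambda)

        # Drop the temporary key once base_lrs has been recorded.
        for group in optimizer.param_groups:
            del group["initial_lr"]

    def get_lr(self):
        opt = self.optimizer
        # Ask the optimizer for the step size of the first parameter in each
        # group (relies on the private Adafactor._get_lr). Assumes that once a
        # gradient exists, optimizer.step() has already populated the state.
        lrs = [
            opt._get_lr(group, opt.state[group["params"][0]])
            for group in opt.param_groups
            if group["params"][0].grad is not None
        ]
        if len(lrs) == 0:
            lrs = self.base_lrs  # if called before stepping
        return lrs
```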
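What `opt._get_lr(...)` returns, and therefore what `get_lr` above reports, can be modeled without torch. This is a pure-Python sketch, assuming torch_optimizer's `Adafactor._get_lr` follows the fairseq/Hugging Face formula: a relative step size of min(1e-2, 1/sqrt(t)) at step t (or min(1e-6 * t, 1/sqrt(t)) with `warmup_init`), multiplied by max(eps2[1], RMS of the parameter) when `scale_parameter=True`. The function name `adafactor_lr` and its keyword arguments are hypothetical.

```python
import math


def adafactor_lr(step, lr=None, relative_step=True, warmup_init=False,
                 scale_parameter=True, param_rms=1.0, eps2_scale=1e-3):
    """Hypothetical model of Adafactor's per-parameter step size (assumed formula)."""
    rel_step_sz = lr  # used as-is only when relative_step=False
    if relative_step:
        # Relative step: constant 1e-2 (or linear warmup) capped by 1/sqrt(step).
        min_step = 1e-6 * step if warmup_init else 1e-2
        rel_step_sz = min(min_step, 1.0 / math.sqrt(step))
    # Optionally scale by the parameter's RMS, floored at eps2[1].
    param_scale = max(eps2_scale, param_rms) if scale_parameter else 1.0
    return param_scale * rel_step_sz


print(adafactor_lr(1))       # 1e-2 for the first 10,000 steps
print(adafactor_lr(40_000))  # then decays as 1/sqrt(step)
```

Because the proxy scheduler writes the reported value back into `group["lr"]` on every `scheduler.step()`, it is only harmless when `relative_step=True`, since that mode ignores `group["lr"]`. With `relative_step=False` and `scale_parameter=True`, the value written back already includes the parameter scale, so the next optimizer step would scale it a second time.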