Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Turning SWA on makes scheduler lr change to epoch, instead of batch [with colab ex] #6963

Closed felipemello1 closed 3 years ago

felipemello1 commented 3 years ago

🐛 Bug

My optimizer/scheduler code is shown below. If my trainer has stochastic_weight_avg=True, my learning rate follows the green curve in the plot below and I get this warning:

/home/felipe/anaconda3/envs/ML_38_new/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:68: UserWarning: Swapping lr_scheduler <torch.optim.lr_scheduler.OneCycleLR object at 0x7fa445c76ee0> for <torch.optim.swa_utils.SWALR object at 0x7fa445e2d190> warnings.warn(*args, **kwargs)

If stochastic_weight_avg=False, I get the expected learning rate schedule (pink).

[Figure: learning rate curves over training steps, with stochastic_weight_avg=True (green) vs. stochastic_weight_avg=False (pink)]

It seems to me that when stochastic_weight_avg=True there is some conflict over whether the scheduler should be updated per batch or per epoch:

def configure_optimizers(self):
        params = list(self.named_parameters())
        def is_backbone(n): return 'encoder' in n

        grouped_parameters = [
            {'params': [p for n, p in params if is_backbone(n)], 'lr': cfg.max_lr/cfg.encoder_lr_frac},
            {'params': [p for n, p in params if not is_backbone(n)], 'lr': cfg.max_lr},
        ]

        optimizer = MADGRAD(grouped_parameters, lr=cfg.max_lr, weight_decay=cfg.wd)

        if cfg.scheduler_type == 'onecycle':
            scheduler_fn = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                                max_lr=[cfg.max_lr/cfg.encoder_lr_frac, cfg.max_lr],
                                                                epochs=self.epochs,
                                                                steps_per_epoch = len(self.loaders_dict['train']))

        # Ask Lightning to step the scheduler every batch ('step'), not every epoch.
        scheduler = {'scheduler': scheduler_fn,
                     'name': cfg.scheduler_type,
                     'frequency': 1,
                     'interval': 'step',
                     'monitor': 'train_loss'}

        return [optimizer], [scheduler]
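
For reference, one way to see which schedule is actually being applied during training (not necessarily how the plot above was produced) is Lightning's LearningRateMonitor callback. A minimal sketch, assuming a model and datamodule are defined elsewhere:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor

# Log the effective learning rate on every optimizer step, so the swap from
# OneCycleLR to SWALR described in the warning above becomes visible.
lr_monitor = LearningRateMonitor(logging_interval='step')

trainer = Trainer(
    max_epochs=10,
    stochastic_weight_avg=True,   # same flag as in the report above
    callbacks=[lr_monitor],
)
# trainer.fit(model, datamodule=dm)  # 'model' and 'dm' are placeholders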

To Reproduce

https://gist.github.com/fmellomascarenhas/7e53efbacaafd8769088d58574e73cd5

carmocca commented 3 years ago

Hi!

Currently, the SWA callback only supports interval=='epoch'. I'll update the docs and the warning to make this clearer.
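
For context, SWALR from torch.optim.swa_utils is designed to be stepped once per epoch, which is why the swapped-in scheduler stops following the per-step OneCycleLR curve. A minimal plain-PyTorch sketch (adapted from the torch.optim.swa_utils documentation, with a toy model and data purely for illustration):

import torch
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

# Toy model, data and optimizer; only the placement of the scheduler steps matters.
model = torch.nn.Linear(10, 1)
loader = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(8)]
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

swa_model = AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
swa_scheduler = SWALR(optimizer, swa_lr=0.01)
swa_start = 15

for epoch in range(20):
    for x, y in loader:
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()              # the optimizer steps every batch
    if epoch >= swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()          # SWALR is stepped once per epoch
    else:
        scheduler.step()              # the regular scheduler here also steps per epoch

update_bn(loader, swa_model)          # recompute BatchNorm statistics for the averaged weights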