Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

LearningRateFinder creates errors for schedulers in `val` stage #20355

Open DeanLa opened 1 month ago

DeanLa commented 1 month ago

Bug description

I have a LightningModule which logs the metric `val_loss`, and a scheduler that monitors it:

def get_plateau_scheduler(self, optimizer):
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', ...)
    scheduler = {'scheduler': plateau_scheduler, 'interval': 'epoch',
                 'monitor': 'val_loss'}  # <<<<<
    return scheduler

class MyModel(ParentModule):  # It's a LightningModule
    def configure_optimizers(self):
        ret = super().configure_optimizers()  # I get the optimizer here
        ret['lr_scheduler'] = self.get_plateau_scheduler(ret['optimizer'])
        return ret
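
For completeness, `val_loss` is logged in the usual way from the validation step. A minimal sketch (the real loss computation is omitted; `_compute_loss` is just a stand-in):

# inside MyModel (the LightningModule above)
def validation_step(self, batch, batch_idx):
    loss = self._compute_loss(batch)  # stand-in for my actual loss computation
    self.log('val_loss', loss, on_epoch=True, prog_bar=True)
    return loss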

I also have a list of callbacks, one of which is `LearningRateFinder`. I run a fit with `trainer = L.Trainer(logger=logger, callbacks=callbacks, **trainer_args)`. When the LR finder is in the list I get:

ReduceLROnPlateau conditioned on metric val_loss which is not available. Available metrics are: ['lr-AdamW', 'train_loss', 'train_loss_step', 'time/backward', 'time/train_batch', 'train_loss_epoch', 'time/train_epoch']. Condition can be set using `monitor` key in lr scheduler dict

When I remove the LR finder, training seems to work well.
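
For context, this is roughly how everything is wired up (a sketch only: `logger`, `trainer_args`, and the datamodule `dm` are placeholders from my setup, and `LearningRateMonitor` is presumably what produces the `lr-AdamW` metric listed in the error above):

import lightning as L
from lightning.pytorch.callbacks import LearningRateFinder, LearningRateMonitor

model = MyModel(...)  # the LightningModule from above
callbacks = [
    LearningRateMonitor(logging_interval='step'),
    LearningRateFinder(),  # <-- removing this callback makes training work
]
trainer = L.Trainer(logger=logger, callbacks=callbacks, **trainer_args)
trainer.fit(model, datamodule=dm)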

What version are you seeing the problem on?

v2.4

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment

```
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
```

More info

No response

DeanLa commented 1 month ago

May be similar to #19575

DeanLa commented 1 month ago

I managed a workaround by subclassing `LearningRateFinder` and doing:

from lightning.pytorch.callbacks import LearningRateFinder

class MyLearningRateFinder(LearningRateFinder):  # the name is just what I called my subclass
    def on_fit_start(self, trainer, pl_module):
        # No-op: skip the search the base callback would otherwise run here.
        return

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Instead, run the search once at the start of the first training epoch.
        if trainer.current_epoch == 0:
            self.lr_find(trainer, pl_module)
            self.log_chart()  # helper that plots the results (not shown here)

I hope it does not affect other things I'm not aware of.
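
Then I just pass the subclass in place of the stock callback (again assuming the names from my sketch above):

callbacks = [MyLearningRateFinder(), *other_callbacks]  # other_callbacks: the rest of my callback list
trainer = L.Trainer(logger=logger, callbacks=callbacks, **trainer_args)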