Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Checkpoint callback run before validation step - stale or none monitor values considered for validation metrics #20185

Open PheelaV opened 1 month ago

PheelaV commented 1 month ago

Bug description

I am doing iterative training with `check_val_every_n_epoch=None` and (example values) `val_check_interval=10` on my trainer, and with the matching argument `every_n_train_steps=10` on `ModelCheckpoint`.

e.g.

```python
checkpoint_callback = ModelCheckpoint(
    dirpath=experiment_dir.joinpath("checkpoints"),
    filename="checkpoint-{epoch}-{step:06d}-{train_loss:.2f}-{val_loss:.2f}",
    save_top_k=checkpoint_top_k,
    every_n_train_steps=checkpoint_n_step,
    monitor="val_loss",
)
```

Using `val_loss` as the monitor metric is documented usage.

The problem is that these values may not exist yet (producing a warning), or they may be stale: because the validation step runs after the checkpoint has been processed, the fresh validation metrics are never considered.
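A stdlib-only simulation of the ordering described above (purely illustrative, not Lightning's actual implementation): when the checkpoint callback fires before the validation step at the same global step, it only ever sees the previous validation's `val_loss`, and `None` on the very first save.

```python
# Illustrative simulation of the reported hook ordering: at each matching
# step the checkpoint callback reads the metrics dict *before* validation
# refreshes it, so the monitored value is missing or one interval stale.

def run_training(max_steps=30, val_check_interval=10, every_n_train_steps=10):
    metrics = {}         # stands in for trainer.callback_metrics
    checkpoint_log = []  # (step, monitored val_loss at save time)
    for step in range(1, max_steps + 1):
        # checkpoint callback runs first (as described in the report)
        if step % every_n_train_steps == 0:
            checkpoint_log.append((step, metrics.get("val_loss")))
        # validation runs afterwards and refreshes the metric
        if step % val_check_interval == 0:
            metrics["val_loss"] = 1.0 / step  # fake, decreasing loss
    return checkpoint_log

print(run_training())
# → [(10, None), (20, 0.1), (30, 0.05)]
```

The first checkpoint monitors `None` (hence the warning), and every later one monitors the loss from the previous validation interval.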

What version are you seeing the problem on?

v2.3, v2.4

How to reproduce the bug

No response

Error messages and logs

No response

Environment

Current environment

```
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
```

More info

No response

PheelaV commented 1 month ago

I am not in a situation at the moment where I would be able to produce a convenient repro, but I would be keen to take this on as a first contribution around September, if possible.

Naively, I can think of either:

1. Adding a more reasonable error message: check whether any validations have been run if the monitored metrics come from the validation step, and possibly warn the user even if the validation frequency is higher than the checkpoint frequency.
2. Making sure that at least one run of the stage the monitors are sourced from happens before checkpointing, or delaying checkpointing until those monitors are available, or ensuring the checkpointing frequency is not higher than that of the monitors.
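Option (1) above could be sketched as a setup-time frequency check. This is a hypothetical helper, not Lightning API; the function name and warning text are my own, and it only covers the frequency mismatch, not the same-step ordering problem from the report:

```python
# Hypothetical sketch of suggestion (1): warn at setup time when
# step-based checkpoints can fire before the monitored validation
# metric exists or is refreshed.
import warnings


def check_monitor_frequencies(every_n_train_steps, val_check_interval, monitor):
    """Return True if the monitored metric should be fresh at save time."""
    if val_check_interval is None:
        warnings.warn(
            f"ModelCheckpoint(monitor={monitor!r}) saves every "
            f"{every_n_train_steps} steps but validation never runs; "
            "the monitored value will be missing."
        )
        return False
    if every_n_train_steps < val_check_interval:
        warnings.warn(
            f"Checkpoints every {every_n_train_steps} steps will see stale "
            f"or missing {monitor!r}, which is refreshed only every "
            f"{val_check_interval} steps."
        )
        return False
    return True
```

Note that even when the two intervals are equal, as in the report, the stale-value problem remains because of hook ordering; that part would need fix (2).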

PheelaV commented 1 month ago

Technically a dup