Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

ModelCheckpoint does not work when using the monitor #19929

Closed QianhangFeng closed 3 months ago

QianhangFeng commented 3 months ago

Bug description

def on_validation_epoch_end(self):
    # spider score calculation on different beam size
    # ...
    self.log("spider", value=max_spider, sync_dist=True)

#################

checkpoint_callback = ModelCheckpoint(
    dirpath=model_output_dir,
    filename='{epoch:02d}-{spider:.4f}',
    monitor='spider',
    save_top_k=1,
    mode='max',
    auto_insert_metric_name=True,
    verbose=True
)
logger = CSVLogger(log_output_dir, name=config['exp_name'])
trainer = pl.Trainer(
    max_epochs=50,
    devices=[0, 1, 2, 3],
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback],
    logger=logger,
)

I want to calculate the spider score at the end of each validation epoch and keep only the checkpoint with the highest score, but the checkpointing doesn't work. The logger, however, works fine and records spider for every epoch. I originally planned to save the model manually, but I found that on_validation_epoch_end is executed on every GPU and that the results differ between GPUs within the same epoch.
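To illustrate why the per-GPU results matter for checkpoint selection, here is a rough pure-Python sketch. It is not Lightning's implementation: it assumes that sync_dist=True averages the logged value across ranks (the default mean reduction), and the per-rank scores and the helper names sync_mean and best_epoch are made up for the example.

```python
from statistics import mean


def sync_mean(per_rank_values):
    """Mimic a mean all-reduce: every rank ends up with the same averaged value."""
    reduced = mean(per_rank_values)
    return [reduced] * len(per_rank_values)


def best_epoch(monitored, mode="max"):
    """Return the index of the epoch that a top-1 checkpoint policy would keep."""
    pick = max if mode == "max" else min
    return pick(range(len(monitored)), key=lambda i: monitored[i])


# Hypothetical spider scores: one row per epoch, one column per GPU rank.
per_rank_scores = [
    [0.50, 0.54, 0.52, 0.56],
    [0.70, 0.62, 0.66, 0.60],
    [0.60, 0.64, 0.58, 0.68],
]

# Without syncing, each rank ranks the epochs by its own local value.
rank0 = [scores[0] for scores in per_rank_scores]
rank3 = [scores[3] for scores in per_rank_scores]

# With a mean reduction, all ranks monitor the same value per epoch.
synced = [sync_mean(scores)[0] for scores in per_rank_scores]

print("rank 0 would keep epoch", best_epoch(rank0))
print("rank 3 would keep epoch", best_epoch(rank3))
print("synced ranks would keep epoch", best_epoch(synced))
```

In this toy data the ranks disagree on the best epoch until the values are reduced, which is why manual saving from a single rank's local score is unreliable. Note that the reduced value is an average of the per-rank scores, not the maximum over ranks.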

What version are you seeing the problem on?

master

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```

More info

No response

cc @carmocca @awaelchli

awaelchli commented 3 months ago

Hi @QianhangFeng

I'm happy to help here, but I'd like to see a clear description and runnable code example to understand what is not working. Without this, it's unclear what is happening or what you are expecting. Thanks!