Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
28.37k stars, 3.38k forks

ModelCheckpoint() callback does not save any checkpoint #17877

Closed (zxyl1003 closed this issue 1 year ago)

zxyl1003 commented 1 year ago

Bug description

Calling trainer.test() after calling trainer.fit() to train the model raises an error:

ValueError: `.test(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.

I checked the file system and found that ModelCheckpoint() neither created the specified folder nor saved any checkpoints. The other callbacks and the logger work fine. I am confused about why this happens.
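One way to narrow this down is to inspect the checkpoint callback right after trainer.fit() returns. The sketch below is only a diagnostic aid, not a fix: `checkpoint_report` is a hypothetical helper name, and it assumes the trainer exposes `checkpoint_callback` and `callback_metrics` and that the callback exposes `dirpath`, `monitor`, `best_model_path` and `last_model_path`. The error above appears to be raised when `best_model_path` is still an empty string, so that is the main field to look at. The stand-in objects exist only so the snippet runs without Lightning installed.

```python
from types import SimpleNamespace


def checkpoint_report(trainer):
    """Summarize the state of the trainer's checkpoint callback after fit().

    Hypothetical helper: it only reads attributes and changes nothing.
    """
    ckpt = getattr(trainer, "checkpoint_callback", None)
    if ckpt is None:
        return {"configured": False}

    monitor = getattr(ckpt, "monitor", None)
    # Names of the metrics that were logged during fit().
    logged = set(getattr(trainer, "callback_metrics", {}))
    return {
        "configured": True,
        "dirpath": ckpt.dirpath,
        "monitor": monitor,
        # False means the monitored key never showed up among the logged metrics.
        "monitor_was_logged": monitor is None or monitor in logged,
        "best_model_path": ckpt.best_model_path,
        "last_model_path": ckpt.last_model_path,
        "logged_metrics": sorted(logged),
    }


# Stand-in objects (assumptions, not real Lightning classes) that mimic a run
# in which nothing was saved and "val_psnr" was never logged.
fake_trainer = SimpleNamespace(
    checkpoint_callback=SimpleNamespace(
        dirpath="checkpoints/model",
        monitor="val_psnr",
        best_model_path="",
        last_model_path="",
    ),
    callback_metrics={"train_loss": 0.1},
)
report = checkpoint_report(fake_trainer)
print(report)
```

In the real script this would be called as `checkpoint_report(trainer)` between trainer.fit() and trainer.test(). An empty `best_model_path` together with `monitor_was_logged` being False would point at the monitored metric name rather than at the callback itself.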

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import argparse
from typing import Union

import torch
from lightning.pytorch import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar, LearningRateMonitor
from lightning.pytorch.loggers import CSVLogger

from data_script import DInterface
from models import SrInterface, SegInterface

def main(args):
    seed_everything(args.seed)
    torch.set_float32_matmul_precision("high")
    # Set the model name used when saving
    args.save_model_name = args.model_name
    # Data module and model training module
    data_module = DInterface(**vars(args))
    model = SrInterface(**vars(args)) if args.task_type == "sr" else SegInterface(**vars(args))
    # logger and callbacks
    logger = CSVLogger(save_dir=args.log_dir, name=args.save_model_name)
    ckpt_fn = "best-{epoch}-{val_psnr:.4f}-{val_ssim:.4f}" if args.task_type == "sr" else "best-{epoch}-{val_acc:.4f}"
    monitor_index = "val_psnr" if args.task_type == "sr" else "val_acc"
    callbacks = [ModelCheckpoint(monitor=monitor_index,
                                 dirpath=args.checkpoint_dir + "/" + args.save_model_name,
                                 filename=ckpt_fn,
                                 save_top_k=1,
                                 mode="max",
                                 save_last=True),
                 TQDMProgressBar(refresh_rate=1),
                 LearningRateMonitor(logging_interval="epoch")]

    trainer = Trainer(logger=logger,
                      callbacks=callbacks,
                      accelerator="gpu",
                      max_epochs=1,
                      fast_dev_run=False,
                      precision=args.precision,
                      log_every_n_steps=args.flush_logs_every_n_steps)
    trainer.fit(model, data_module, ckpt_path="last" if args.resume_from_ckpt else None)
    trainer.test(model, data_module, ckpt_path="best")
    trainer.predict(model, data_module, ckpt_path="best")

Error messages and logs

`Trainer.fit` stopped: `max_epochs=1` reached.
Traceback (most recent call last):
  File "F:\Python\mycode\CropedSR\main.py", line 166, in <module>
    main(args)
  File "F:\Python\mycode\CropedSR\main.py", line 43, in main
    trainer.test(model, data_module, ckpt_path="best")
  File "D:\miniconda3\envs\torch2lighting\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 706, in test
    return call._call_and_handle_interrupt(
  File "D:\miniconda3\envs\torch2lighting\lib\site-packages\lightning\pytorch\trainer\call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "D:\miniconda3\envs\torch2lighting\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 746, in _test_impl
    ckpt_path = self._checkpoint_connector._select_ckpt_path(
  File "D:\miniconda3\envs\torch2lighting\lib\site-packages\lightning\pytorch\trainer\connectors\checkpoint_connector.py", line 107, in _select_ckpt_path
    ckpt_path = self._parse_ckpt_path(
  File "D:\miniconda3\envs\torch2lighting\lib\site-packages\lightning\pytorch\trainer\connectors\checkpoint_connector.py", line 174, in _parse_ckpt_path
    raise ValueError(
ValueError: `.test(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.
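Until the root cause is found, the script can avoid this ValueError by not hard-coding ckpt_path="best". The sketch below is a workaround under stated assumptions, not a fix for the missing checkpoints: `pick_ckpt_path` is a hypothetical helper, and it assumes the checkpoint callback exposes `best_model_path` and `last_model_path` as strings that stay empty when nothing was saved. As far as I understand, passing ckpt_path=None together with a model makes trainer.test() use the weights currently in memory.

```python
from types import SimpleNamespace


def pick_ckpt_path(trainer):
    """Return the best checkpoint path, else the last one, else None.

    Hypothetical helper: None is meant to be passed on as ckpt_path=None.
    """
    ckpt = getattr(trainer, "checkpoint_callback", None)
    if ckpt is None:
        return None
    # Empty strings are falsy, so `or` falls through to the next candidate.
    return ckpt.best_model_path or ckpt.last_model_path or None


# Stand-in objects (assumptions, not real Lightning classes) for three cases.
nothing_saved = SimpleNamespace(
    checkpoint_callback=SimpleNamespace(best_model_path="", last_model_path="")
)
only_last = SimpleNamespace(
    checkpoint_callback=SimpleNamespace(best_model_path="", last_model_path="ckpt/last.ckpt")
)
best_saved = SimpleNamespace(
    checkpoint_callback=SimpleNamespace(
        best_model_path="ckpt/best-epoch=0.ckpt", last_model_path="ckpt/last.ckpt"
    )
)

print(pick_ckpt_path(nothing_saved))
print(pick_ckpt_path(only_last))
print(pick_ckpt_path(best_saved))
```

In the script above, the call would then read `trainer.test(model, data_module, ckpt_path=pick_ckpt_path(trainer))`. This only hides the symptom: if the helper returns None, no checkpoint was written at all and the ModelCheckpoint configuration still needs to be investigated.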

Environment

Current environment

```
#- Lightning Component: Trainer, LightningModule
#- PyTorch Lightning Version: 2.0.0
#- PyTorch Version: 2.0.0+cu117
#- Python version: 3.10
#- OS: Windows
#- CUDA/cuDNN version: cu117
#- GPU models and configuration: rtx 3060 6g
#- How you installed Lightning(`conda`, `pip`, source): pip
```

More info

No response

pcwanan commented 8 months ago

I get the same error. How can I solve it?