Lightning-AI / pytorch-lightning

trainer.test() with given checkpoint logs last epoch instead of checkpoint epoch #20052

Open markussteindl opened 3 weeks ago

markussteindl commented 3 weeks ago

Bug description

Calling trainer.test() with an explicit checkpoint path logs the epoch and step of the final training state instead of the values stored in the specified checkpoint:

trainer = Trainer(..., max_epochs=10)
lightning_module = MyLightningModule(...)
datamodule = MyDatamodule()

trainer.fit(lightning_module, datamodule=datamodule)

trainer.test(lightning_module, datamodule=datamodule, ckpt_path="last")     # <-- ok: logs correct epoch and step

ckpt_path = "/.../checkpoints/epoch=2-step=396.ckpt"
trainer.test(lightning_module, datamodule=datamodule, ckpt_path=ckpt_path)  # <-- incorrect: logs last epoch and step

The second test logs epoch 10 instead of epoch 2, and the logged step is likewise the final training step rather than the step stored in the checkpoint (396).
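For reference, a self-contained script in the same spirit can reproduce the report. A minimal sketch: ToyModule and the toy data below are illustrative stand-ins, not from the original report, and the checkpoint filename relies on Lightning's default epoch=E-step=S naming.

import glob
import os

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

class ToyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def test_step(self, batch, batch_idx):
        x, y = batch
        # self.log() records the metric against trainer.current_epoch and
        # trainer.global_step, which is where the stale values surface.
        self.log("test_loss", torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

loader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=16)

# Keep a checkpoint for every epoch so an intermediate one can be tested.
checkpoint_cb = ModelCheckpoint(save_top_k=-1, every_n_epochs=1)
trainer = Trainer(max_epochs=3, callbacks=[checkpoint_cb])
model = ToyModule()
trainer.fit(model, train_dataloaders=loader)

# Test from the epoch=0 checkpoint (4 batches per epoch -> epoch=0-step=4.ckpt).
# Expected: the test metric is logged against epoch 0 / step 4.
# Observed: it is logged against the end-of-training epoch and step.
first_ckpt = sorted(glob.glob(os.path.join(checkpoint_cb.dirpath, "*.ckpt")))[0]
trainer.test(model, dataloaders=loader, ckpt_path=first_ckpt)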

What version are you seeing the problem on?

v2.2.1

heth27 commented 3 weeks ago

I suspect this has the same cause as https://github.com/Lightning-AI/pytorch-lightning/issues/18060: the checkpoint callback is not the last callback called, so some loop counters are not updated. Have a look at the fields mentioned in https://github.com/Lightning-AI/pytorch-lightning/issues/18060#issuecomment-2080180970 and check whether they explain the behavior you are seeing; they may also point you to a workaround.
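Until this is fixed upstream, patching those counters by hand before the second test() call might serve as a stopgap. A hedged sketch, assuming (as in the linked comment) that the stale values live in the fit loop's progress trackers: these are private internals, and the attribute paths below are assumptions based on the Lightning 2.x loops rather than a documented API, so inspect trainer.fit_loop on your installed version before relying on them.

import torch

# Read the epoch/step the checkpoint was saved at.
ckpt = torch.load(ckpt_path, map_location="cpu")

# trainer.current_epoch is read from this counter (assumed path).
trainer.fit_loop.epoch_progress.current.completed = ckpt["epoch"]

# trainer.global_step (with automatic optimization) is read from this
# counter (assumed path; it differs under manual optimization).
trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = ckpt["global_step"]

trainer.test(lightning_module, datamodule=datamodule, ckpt_path=ckpt_path)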