godaup opened 1 week ago
I noticed an issue today that is (I think) a manifestation of this bug.
I was training a model with:
```python
save_top_k=5
save_last=True
mode='max'
```
However, I was monitoring a "lower is better" signal by mistake, so the 5 checkpoints that were saved were from the first 5 epochs (where the error was highest). After loading up `last.ckpt`, the validation performance was identical to the final state of the model as reported in TensorBoard (and did not match the performance of any of the other saved checkpoints).
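For concreteness, my setup was along these lines (a sketch; the monitor name is illustrative, not my exact code):

```python
from lightning.pytorch.callbacks import ModelCheckpoint

# Sketch of the misconfiguration; "val_error" stands in for my real metric.
checkpoint_cb = ModelCheckpoint(
    monitor="val_error",  # a lower-is-better signal...
    mode="max",           # ...so mode="max" keeps the five *worst* epochs
    save_top_k=5,
    save_last=True,
)
```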
We can see that `last.ckpt` was saved well after the others:
```
drwxr-xr-x 1 ecole domain users   0 Aug 27 23:31 .
drwxr-xr-x 1 ecole domain users   0 Aug 27 23:29 ..
-rw------- 1 ecole domain users 83M Aug 27 23:29 'epoch=0-step=30.ckpt'
-rw------- 1 ecole domain users 83M Aug 27 23:30 'epoch=1-step=60.ckpt'
-rw------- 1 ecole domain users 83M Aug 27 23:30 'epoch=2-step=90.ckpt'
-rw------- 1 ecole domain users 83M Aug 27 23:30 'epoch=3-step=120.ckpt'
-rw------- 1 ecole domain users 83M Aug 27 23:31 'epoch=4-step=150.ckpt'
-rw------- 1 ecole domain users 83M Aug 27 23:37 last.ckpt
```
I was confused by this behavior, but it seems like this bug could explain it.
Bug description
According to the discussion in #4335, the intended functionality of `save_last` is deterministic access to the last stored checkpoint in order to resume training; the documentation states the same. But unfortunately, `_save_last_checkpoint` may be called without any previous checkpoint existing. Some example configuration:
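(The snippet itself did not survive here; the following is a hedged reconstruction of the kind of setup meant. The monitor name and numbers are mine.)

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Top-k checkpointing only fires every 5th epoch here, yet save_last
# produces/refreshes last.ckpt at the end of *every* epoch.
ckpt = ModelCheckpoint(
    monitor="val_loss",   # illustrative monitor name
    save_top_k=1,
    every_n_epochs=5,
    save_last=True,
)
trainer = Trainer(callbacks=[ckpt], max_epochs=20)
```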
This will create a `last.ckpt` at the end of every epoch (either train or validation, depending on `save_on_train_epoch_end`), even though no other checkpointing has triggered before. This is due to `_save_last_checkpoint` being called either way at the epoch end, roughly as in the paraphrase below.
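(The original quote of the source is missing; this is a paraphrase from memory of the 2.x hook, not a verbatim copy.)

```python
# Paraphrase of ModelCheckpoint.on_validation_end in Lightning 2.x (from memory):
def on_validation_end(self, trainer, pl_module):
    if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
        monitor_candidates = self._monitor_candidates(trainer)
        if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
            self._save_topk_checkpoint(trainer, monitor_candidates)
        # Note: called unconditionally, whether or not anything was saved above.
        self._save_last_checkpoint(trainer, monitor_candidates)
```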
Hence, in the end, `save_last` (ironically) does what it was not intended to do: save the last epoch. In the case of `save_last='link'` this is resolved after the first triggering; afterwards, the regular `_save_last_checkpoint` calls will only write identical soft links to the actual checkpoint file. There are more possible configurations for this to happen, for example a step-based one:
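(Again the snippet was lost; a reconstruction of a step-based configuration of that kind, with illustrative values:)

```python
from lightning.pytorch.callbacks import ModelCheckpoint

# Top-k checkpointing is driven by training steps here, but the epoch-end
# hooks still run _save_last_checkpoint every epoch regardless.
ckpt = ModelCheckpoint(
    monitor="train_loss",      # illustrative monitor name
    save_top_k=3,
    every_n_train_steps=2000,
    save_last=True,
)
```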
Even if after 2000 steps `_save_topk_checkpoint` omits the creation of a new file (via `_save_monitor_checkpoint` detecting no improvement), `_save_last_checkpoint` is still triggered.

These additional (unintended) creations of checkpoints may cause severe delays between epochs. I am also pretty sure this is the cause of numerous issues; a quick search revealed e.g. #20075, #19845, #17916.
Another problem is that this actually breaks the interplay of multiple ModelCheckpoints, creating the chance that `last.ckpt` is not the latest checkpoint. Say one callback, cpt_a, keeps the checkpoint with the best validation loss, while a second one, cpt_b, periodically writes a `temp_backup.ckpt`. The intended behavior would be to have (after some while) one checkpoint file according to the best loss value, one `temp_backup.ckpt` file, and a softlink `last.ckpt` pointing to whichever of the above was stored later (in case training gets interrupted and needs to be resumed). But the following may happen:

- cpt_a stores a new best checkpoint, and `last.ckpt` points to that file
- cpt_b stores `temp_backup.ckpt`, and `last.ckpt` is updated to point towards `temp_backup.ckpt`

The `last.ckpt` pointer is thus only determined by the order of checkpoint execution: for both cpt_a and cpt_b, whenever another epoch finishes, the call to `_save_last_checkpoint` will make `last.ckpt` point to their checkpoint. So even if cpt_a detects a new best validation loss, as long as cpt_b is triggered after cpt_a, `last.ckpt` will for the majority of the time point to `temp_backup.ckpt`. A sketch of such a setup follows below.
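(The kind of two-callback setup meant, with cpt_a/cpt_b as above; the monitor names and exact arguments are illustrative:)

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# cpt_a keeps the best validation-loss checkpoint.
cpt_a = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1, save_last="link")
# cpt_b writes a rolling backup under a fixed name.
cpt_b = ModelCheckpoint(filename="temp_backup", save_top_k=1, save_last="link")

# Callbacks run in list order, so cpt_b's _save_last_checkpoint always runs
# after cpt_a's: last.ckpt mostly ends up pointing at temp_backup.ckpt,
# even right after cpt_a saved a new best checkpoint.
trainer = Trainer(callbacks=[cpt_a, cpt_b])
```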
Also note this example, where `last.ckpt` will be a corrupted symlink: during the first "end of epoch" call of `_save_last_checkpoint` it is detected that no previous checkpoint was stored (`self._last_checkpoint_saved` is initialized as an empty string), so `last.ckpt` is written as a regular checkpoint file. But in the next iteration (next epoch) the first case is satisfied and ModelCheckpoint will symlink `last.ckpt` to itself...
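The branch in question looks roughly like this (the original quote did not survive here; this is paraphrased from memory of the 2.x source):

```python
# Paraphrase of the decision inside _save_last_checkpoint (from memory):
if self._fs.protocol == "file" and self._last_checkpoint_saved and self.save_last == "link":
    # Second epoch onwards: _last_checkpoint_saved may already *be* last.ckpt
    # (if no other checkpoint was saved before), in which case this links
    # last.ckpt to itself.
    self._link_checkpoint(trainer, self._last_checkpoint_saved, filepath)
else:
    # First epoch: no previous checkpoint exists, so last.ckpt is written as
    # a real file and recorded in self._last_checkpoint_saved.
    self._save_checkpoint(trainer, filepath)
```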
As a clean solution, I would assume that no "top-level" `_save_last_checkpoint` should be required; the logic to update/generate `last.ckpt` should instead live downstream, based on a trigger from `_save_checkpoint`. As an ad-hoc solution I tested a change along the lines sketched below, which resolved all the problems mentioned above.
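(My patch did not survive here either; the following hypothetical subclass sketches the idea: only refresh `last.ckpt` when this callback actually wrote a checkpoint during the hook, instead of unconditionally. It is not the exact change I tested, and it omits the `every_n_epochs` and skip gates of the original hook for brevity.)

```python
from lightning.pytorch.callbacks import ModelCheckpoint

class GuardedModelCheckpoint(ModelCheckpoint):
    """Hypothetical workaround sketch, not the exact tested patch."""

    def on_validation_end(self, trainer, pl_module):
        step_before = self._last_global_step_saved
        monitor_candidates = self._monitor_candidates(trainer)
        self._save_topk_checkpoint(trainer, monitor_candidates)
        # Only touch last.ckpt if _save_topk_checkpoint actually saved something.
        if self._last_global_step_saved != step_before:
            self._save_last_checkpoint(trainer, monitor_candidates)
```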
As an additional speed improvement, the `_save_last_checkpoint` calls at the end of each epoch could also be slightly re-arranged.

What version are you seeing the problem on?
v2.3
How to reproduce the bug
No response
Error messages and logs
Environment
Current environment
```
#- PyTorch Lightning Version: 2.3.3
#- PyTorch Version: 2.4.0
#- Python version: 3.8.13
#- OS: Linux
#- CUDA version: 12.2
#- GPU models and configuration: Nvidia GeForce RTX 3090
#- How you installed Lightning: via pip inside a conda env
```

More info
No response