Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

ModelCheckpoint Doesn't Delete Old Best Checkpoints When Resuming Training #18687

Open danielzeng-gt opened 9 months ago

danielzeng-gt commented 9 months ago

Bug description

Description: When using ModelCheckpoint with save_top_k=1 and monitor='val_loss' in a single, uninterrupted training run, the behavior is as expected: only one 'best_val_confidence-epoch...' checkpoint is retained.

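However, in cloud-based training, where instances may be preempted and training restarted from a checkpoint, the previous best checkpoint is not deleted after resuming, so more than one 'best_val_confidence-epoch...' checkpoint ends up in the checkpoint directory.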

Note that we read and write checkpoints with fsspec, which lets checkpoints be written to and loaded directly from Google Cloud Storage (GCS).

Code Details:

There are two ModelCheckpoint callbacks in use:

  1. The first is for saving the latest checkpoint:

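    # Import used by both snippets; self.checkpoint_dir and _CHECKPOINT_NAME_LAST come from our surrounding training code
    from pytorch_lightning.callbacks import ModelCheckpoint

    last_ckpt_callback = ModelCheckpoint(
        save_top_k=-1,        # keep every checkpoint
        save_last=True,       # also maintain a "last" checkpoint
        dirpath=self.checkpoint_dir,
    )
    # File name (without extension) used for the "last" checkpoint; the default is "last"
    last_ckpt_callback.CHECKPOINT_NAME_LAST = _CHECKPOINT_NAME_LAST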
  2. The second is for saving the best validation loss checkpoint:

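    best_val_loss_ckpt_callback = ModelCheckpoint(
        monitor='val_loss',
        mode='min',           # lower val_loss is better
        save_top_k=1,         # keep only the single best checkpoint
        auto_insert_metric_name=False,
        # Plain (non-f) string, so placeholders use single braces
        filename='best_val_confidence-epoch{epoch}-val_loss{val_loss:.4e}',
        dirpath=self.checkpoint_dir,
    )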
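For reference, the filename template of the second callback is filled in from the epoch and the logged metrics. The sketch below approximates that expansion with plain `str.format`, assuming `auto_insert_metric_name=False` so only the literal text and the formatted values appear; Lightning's internal formatting logic differs in detail, and `render_checkpoint_name` is a hypothetical helper.

```python
# Approximation of how the best-val-loss filename template expands.
# Assumption: plain str.format stands in for Lightning's internal formatting.
TEMPLATE = "best_val_confidence-epoch{epoch}-val_loss{val_loss:.4e}"


def render_checkpoint_name(epoch: int, val_loss: float) -> str:
    """Fill the template and append the .ckpt extension."""
    return TEMPLATE.format(epoch=epoch, val_loss=val_loss) + ".ckpt"


if __name__ == "__main__":
    print(render_checkpoint_name(1, 0.5))  # best_val_confidence-epoch1-val_loss5.0000e-01.ckpt
    print(render_checkpoint_name(2, 0.4))  # best_val_confidence-epoch2-val_loss4.0000e-01.ckpt
```

The `.4e` spec renders the loss in scientific notation with four decimal places, which is why the names in this report contain values such as `5.0000e-01`.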

Environment:

  • Lightning Component: ModelCheckpoint object
  • PyTorch Lightning Version: 1.9.2
  • PyTorch Version: 1.13.0
  • Python Version: 3.10.12
  • OS: Linux
  • CUDA/cuDNN version: Build cuda_11.6.r11.6/compiler.31057947_0
  • GPU models: Nvidia A100
  • How you installed Lightning: Conda
  • Cloud: Running on GCP Cluster

What version are you seeing the problem on?

v1.9

How to reproduce the bug

1. Set up a training loop on the cloud with the aforementioned `ModelCheckpoint` callbacks.
2. Intentionally interrupt the training to simulate preemption.
3. Resume the training from the "last.ckpt".
4. After resuming, inspect the stored checkpoints: there are two 'best_val_confidence-epoch...' checkpoints instead of one.

**Expected behavior**: Only one 'best_val_confidence-epoch...' checkpoint should remain after resumption.

**Actual behavior**: Multiple 'best_val_confidence-epoch...' checkpoints are observed after training preemption and resumption.
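The inspection in step 4 can be scripted. The helper below is a hypothetical sketch that lists best-checkpoint files in a local directory by their filename prefix; for a `gs://` path you would use an fsspec filesystem's `glob` instead of `pathlib` (not shown). The file names in the demo are illustrative only.

```python
import tempfile
from pathlib import Path
from typing import List

# Hypothetical helper: matches the prefix used by the best-val-loss filename template.
BEST_PREFIX = "best_val_confidence-epoch"


def list_best_checkpoints(dirpath: str, prefix: str = BEST_PREFIX) -> List[str]:
    """Return the sorted names of matching .ckpt files in a local directory."""
    return sorted(p.name for p in Path(dirpath).glob(f"{prefix}*.ckpt"))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        # Simulate the directory contents reported after one preemption and resume.
        for name in ("last.ckpt",
                     "best_val_confidence-epoch1-val_loss5.0000e-01.ckpt",
                     "best_val_confidence-epoch2-val_loss4.0000e-01.ckpt"):
            (Path(d) / name).touch()
        found = list_best_checkpoints(d)
        # With save_top_k=1 behaving as expected, this would print 1.
        print(len(found), found)
```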


Environment

Current environment

  • Lightning Component: ModelCheckpoint object
  • PyTorch Lightning Version: 1.9.2
  • PyTorch Version: 1.13.0
  • Python Version: 3.10.12
  • OS: Linux
  • CUDA/cuDNN version: Build cuda_11.6.r11.6/compiler.31057947_0
  • GPU models: Nvidia A100
  • How you installed Lightning: Conda
  • Running environment of LightningApp: Cloud, Running on GCP A100 instance

More info

No response

cc @carmocca @awaelchli

awaelchli commented 9 months ago

@danielzeng-gt Thanks for submitting the issue.

I read your description multiple times but I don't understand the problem. Can you try to formulate it with an example? Is it related to #17912?

danielzeng-gt commented 9 months ago

Hey Adrian, thanks for the prompt response! I looked at #17912 and it doesn't seem to be related.

I generated an example with GPT-4; I read it over and it describes the problem accurately. Please let me know if it's still confusing:

Example:

Suppose Alice is training a neural network to classify images of cats and dogs on a cloud-based preemptible instance. She's interested in keeping two kinds of checkpoints:

  1. The latest checkpoint, irrespective of its performance on validation data.
  2. The checkpoint with the best validation loss.

To achieve this, Alice uses two ModelCheckpoint callbacks as described.

Training Run 1:

  1. Alice starts her training.
  2. After epoch 1, the validation loss is 0.5. The system saves:
    • last.ckpt (The latest checkpoint)
    • best_val_confidence-epoch1-val_loss5.0000e-01 (The best checkpoint based on validation loss)
  3. Suddenly, the preemptible instance is terminated.

Training Resumption:

  1. Alice's setup detects the preemption and decides to restart the training from the last checkpoint.
  2. It loads last.ckpt and continues training.
  3. After epoch 2, the validation loss improves to 0.4. The system now tries to save:
    • A new last.ckpt (Replacing the older one)
    • best_val_confidence-epoch2-val_loss4.0000e-01 (A new best checkpoint)

Expected Behavior: Since Alice specified save_top_k=1 for the best validation loss checkpoint, she expects to find only one such checkpoint in her directory, i.e., best_val_confidence-epoch2-val_loss4.0000e-01.

Actual Behavior: Alice finds two best validation loss checkpoints: the old epoch-1 checkpoint and the new epoch-2 checkpoint.

This indicates that the ModelCheckpoint callback did not delete the older "best" checkpoint upon resumption, leading to multiple "best" checkpoints being saved.

Implication: This is especially problematic if Alice trains for many epochs and faces multiple preemptions. Over time, she accumulates several "best" checkpoints, which makes it confusing to identify the genuine best one.

Conclusion:

The bug seems to arise from a state restoration issue in the ModelCheckpoint callback when resuming training from a checkpoint. It fails to remember its previous "best" state and does not delete older checkpoints as it should.
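The suspected mechanism can be illustrated with a toy model. `ToyTopKCheckpoint` below is hypothetical and far simpler than Lightning's `ModelCheckpoint` (minimization only, no last checkpoint, local files): it tracks saved paths and their scores, deletes the worst file once more than `save_top_k` are kept, and can export and import that tracking state. It only demonstrates the hypothesis stated above; it does not show where Lightning's own restore logic goes wrong.

```python
import tempfile
from pathlib import Path
from typing import Dict


class ToyTopKCheckpoint:
    """Hypothetical, simplified stand-in for top-k checkpoint tracking (mode='min')."""

    def __init__(self, dirpath: str, save_top_k: int = 1) -> None:
        self.dirpath = Path(dirpath)
        self.save_top_k = save_top_k
        self.best_k: Dict[str, float] = {}  # file path -> monitored score

    def save(self, epoch: int, val_loss: float) -> None:
        # Skip if the score does not beat the worst checkpoint currently kept.
        if len(self.best_k) >= self.save_top_k and val_loss >= max(self.best_k.values()):
            return
        path = self.dirpath / f"best-epoch{epoch}-val_loss{val_loss:.4e}.ckpt"
        path.touch()
        self.best_k[str(path)] = val_loss
        # Delete the worst tracked files until only save_top_k remain.
        while len(self.best_k) > self.save_top_k:
            worst = max(self.best_k, key=self.best_k.get)
            del self.best_k[worst]
            Path(worst).unlink()

    def state_dict(self) -> Dict[str, Dict[str, float]]:
        return {"best_k": dict(self.best_k)}

    def load_state_dict(self, state: Dict[str, Dict[str, float]]) -> None:
        self.best_k = dict(state["best_k"])


def simulate(restore_state: bool) -> int:
    """Epoch 1, simulated preemption, fresh callback on resume, epoch 2; count best files."""
    with tempfile.TemporaryDirectory() as d:
        run1 = ToyTopKCheckpoint(d)
        run1.save(epoch=1, val_loss=0.5)
        saved = run1.state_dict()  # the state a resumable checkpoint would carry

        run2 = ToyTopKCheckpoint(d)  # new process after preemption
        if restore_state:
            run2.load_state_dict(saved)
        run2.save(epoch=2, val_loss=0.4)
        return len(list(Path(d).glob("best-*.ckpt")))


if __name__ == "__main__":
    print(simulate(restore_state=True))   # 1: the epoch-1 file is deleted
    print(simulate(restore_state=False))  # 2: the epoch-1 file is orphaned
```

Without the restored state, the resumed callback has no record of the epoch-1 file, so it never becomes a deletion candidate.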

leng-yue commented 1 month ago

I ran into the same issue. I understand that fixing it may be a breaking change; could we add an option to handle it?
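Until such an option exists, one possible user-side workaround is to prune stale best checkpoints manually, for example after training finishes. The function below is a hypothetical sketch, not a Lightning API: it keeps the file passed as `keep_path` (such as the callback's `best_model_path`) and deletes every other matching file in a local directory. It only makes sense with `save_top_k=1`, and a GCS bucket would need an fsspec filesystem instead of `pathlib`.

```python
import tempfile
from pathlib import Path
from typing import List


# Hypothetical workaround, not part of Lightning; only meaningful for save_top_k=1.
def prune_stale_best_checkpoints(dirpath: str, keep_path: str,
                                 prefix: str = "best_val_confidence-epoch") -> List[str]:
    """Delete matching .ckpt files other than keep_path; return the deleted names."""
    keep = Path(keep_path).resolve()
    deleted = []
    for path in sorted(Path(dirpath).glob(f"{prefix}*.ckpt")):
        if path.resolve() != keep:
            path.unlink()
            deleted.append(path.name)
    return deleted


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        old = Path(d) / "best_val_confidence-epoch1-val_loss5.0000e-01.ckpt"
        new = Path(d) / "best_val_confidence-epoch2-val_loss4.0000e-01.ckpt"
        last = Path(d) / "last.ckpt"
        for p in (old, new, last):
            p.touch()
        # Deletes only the epoch-1 file; last.ckpt and the epoch-2 file remain.
        print(prune_stale_best_checkpoints(d, keep_path=str(new)))
```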