Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

ModelCheckpoint: save_top_k > 1 cannot recognize ordering of models from ckpt names #9944

Open breznak opened 3 years ago

breznak commented 3 years ago

🐛 Bug

Short: the ModelCheckpoint callback with save_top_k does not encode the ranking of the models in the checkpoint filenames (the names carry no semantic meaning about which model is which).

Say for save_top_k=5 it saves the 5 best models into 5 files, but the user cannot tell which is which! (The ModelCheckpoint knows, but only for the current training session, not if we want to resume later.)

My motivation: I want to access "2nd best" model.

To Reproduce

Reproducible code: https://github.com/PyTorchLightning/pytorch-lightning/issues/9944#issuecomment-944610705

Using this ModelCheckpoint config for the example


model_checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: 'val/loss'       # name of the logged metric
  save_top_k: 5             # save k best models (-1 save all, 0 don't save)
  save_last: True           # always save model from last epoch
  verbose: True             # show more detailed info during training
  mode: min                 # can be "max" or "min"
  dirpath: 'xxx'
  filename: 'best'          # use the current epoch number for naming the checkpoint, metrics may also be used

The ModelCheckpoint just cyclically saves the new best model as best-v{current_epoch % k}.ckpt, as seen from the following training log:

Training: -1it [00:00, ?it/s]EPOCH 0
Epoch 0: 100%|██████████| 2100/2100 [00:21<00:00, 96.04it/s, lostrain error:            0.3088467121
        validation error:       0.1557712555               
Epoch 0, global step 2055: val/loss reached 0.15577 (best 0.15577), saving model to "xxx/best.ckpt" as top 5
EPOCH 1
Epoch 1: 100%|██████████| 2100/2100 [00:20<00:00, 100.72it/s, lotrain error:            0.3018136621
        validation error:       0.1548990458               
Epoch 1, global step 4111: val/loss reached 0.15490 (best 0.15490), saving model to "xxx/best-v1.ckpt" as top 5
EPOCH 2
Epoch 2: 100%|██████████| 2100/2100 [00:21<00:00, 97.58it/s, lostrain error:            0.2986578345
        validation error:       0.1544228494               
Epoch 2, global step 6167: val/loss reached 0.15442 (best 0.15442), saving model to "xxx/best-v2.ckpt" as top 5
EPOCH 3
Epoch 3: 100%|██████████| 2100/2100 [00:21<00:00, 98.33it/s, lostrain error:            0.2965526581
        validation error:       0.1539662182               
Epoch 3, global step 8223: val/loss reached 0.15397 (best 0.15397), saving model to "xxx/best-v3.ckpt" as top 5
EPOCH 4
Epoch 4: 100%|██████████| 2100/2100 [00:21<00:00, 99.14it/s, lostrain error:            0.2950037122
        validation error:       0.1536256075               
Epoch 4, global step 10279: val/loss reached 0.15363 (best 0.15363), saving model to "xxx/best-v4.ckpt" as top 5
EPOCH 5
Epoch 5: 100%|██████████| 2100/2100 [00:21<00:00, 97.18it/s, lostrain error:            0.2937672734
        validation error:       0.1534034163               
Epoch 5, global step 12335: val/loss reached 0.15340 (best 0.15340), saving model to "xxx/best.ckpt" as top 5
EPOCH 6
Epoch 6: 100%|██████████| 2100/2100 [00:21<00:00, 98.24it/s, lostrain error:            0.2926785052
        validation error:       0.1531589478               
Epoch 6, global step 14391: val/loss reached 0.15316 (best 0.15316), saving model to "xxx/best-v1.ckpt" as top 5
EPOCH 7
Epoch 7: 100%|██████████| 2100/2100 [00:21<00:00, 96.33it/s, lostrain error:            0.2916747034
        validation error:       0.1529426873               
Epoch 7, global step 16447: val/loss reached 0.15294 (best 0.15294), saving model to "xxx/best-v2.ckpt" as top 5
EPOCH 8
Epoch 8: 100%|██████████| 2100/2100 [00:22<00:00, 92.43it/s, lostrain error:            0.2907347977
        validation error:       0.1527983993               
Epoch 8, global step 18503: val/loss reached 0.15280 (best 0.15280), saving model to "xxx/best-v3.ckpt" as top 5
EPOCH 9
Epoch 9: 100%|██████████| 2100/2100 [00:20<00:00, 101.61it/s, lotrain error:        ]   0.2898018062
        validation error:       0.1526378989               
Epoch 9, global step 20559: val/loss reached 0.15264 (best 0.15264), saving model to "xxx/best-v4.ckpt" as top 5
Epoch 9: 100%|██████████| 2100/2100 [00:20<00:00, 101.42it/s, loss=0.289, v_num=gdir]Saving latest checkpoint...
Setting period from checkpoint test_set

In this example, the real best model is:

Epoch 9, global step 20559: val/loss reached 0.15264 (best 0.15264), saving model to "xxx/best-v4.ckpt"

Now the trick is that in trainer.test(..., ckpt_path="best") (same for validate() and predict()), "best" is not "best.ckpt" but the best filename that only the callback knows, as seen from the following code used by the above methods:


    def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
        if ckpt_path is None:
            return

        fn = self.state.fn.value

        if ckpt_path == "best":
            # if user requests the best checkpoint but we don't have it, error
            if not self.checkpoint_callback.best_model_path:
                if self.fast_dev_run:
                    raise MisconfigurationException(
                        f"You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do"
                        f" `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting."
                    )
                raise MisconfigurationException(
                    f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
                )
            # load best weights
            ckpt_path = self.checkpoint_callback.best_model_path

...
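As a side note, within the same session the ranking is still recoverable from the callback object itself: the best_k_models attribute maps each saved checkpoint path to its monitored value. The sketch below (assuming a LightningModule model and dataloaders like those in the report) only illustrates that this knowledge lives in memory and is lost once you restart from the checkpoint directory alone:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

ckpt_cb = ModelCheckpoint(monitor="val/loss", mode="min", save_top_k=5, filename="best")
trainer = Trainer(callbacks=[ckpt_cb], max_epochs=10)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

# best_k_models: {checkpoint path: monitored value}; for mode="min", lower is better
ranked = sorted(ckpt_cb.best_k_models.items(), key=lambda kv: kv[1])
second_best_path = ranked[1][0]
trainer.test(model, dataloaders=test_data, ckpt_path=second_best_path)

# After a restart, only the filenames best.ckpt, best-v1.ckpt, ... remain and the ranking is gone.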

Problems:

Expected behavior

The ModelCheckpoint callback would keep an ordering of the top_k best models and re-save (or rename) them accordingly on each change, so that best.ckpt is always the best model and best-v4.ckpt the 5th best. The filenames (the semantics) would stay stable during training; only the contents (the models) of the files would update.
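To make the expectation concrete, here is a hypothetical sketch (not existing functionality; RankCheckpointsOnFitEnd and the renaming scheme are made up for illustration) that renames the kept checkpoints by rank once fitting ends, using the callback's real best_k_models / dirpath / mode attributes:

import os

from pytorch_lightning.callbacks import Callback

class RankCheckpointsOnFitEnd(Callback):
    """Hypothetical sketch of the expected behaviour: rename kept checkpoints by rank after fit."""

    def __init__(self, checkpoint_callback):
        self.checkpoint_callback = checkpoint_callback

    def on_fit_end(self, trainer, pl_module):
        cb = self.checkpoint_callback
        reverse = cb.mode == "max"  # best score first
        ranked = sorted(cb.best_k_models.items(), key=lambda kv: kv[1], reverse=reverse)
        # two passes so we never overwrite a checkpoint that has not been renamed yet
        tmp_paths = []
        for rank, (path, _score) in enumerate(ranked):
            tmp = os.path.join(cb.dirpath, f".rank{rank}.tmp.ckpt")
            os.rename(path, tmp)
            tmp_paths.append(tmp)
        for rank, tmp in enumerate(tmp_paths):
            suffix = "" if rank == 0 else f"-v{rank}"  # best.ckpt = best, best-v4.ckpt = 5th best
            os.rename(tmp, os.path.join(cb.dirpath, f"best{suffix}.ckpt"))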

Benefits:

Environment

latest PL 1.4.9

Programmer-RD-AI commented 3 years ago

hi, can you send the code?

With best regards, Ranuga

breznak commented 3 years ago

hi, it's the standard PL ModelCheckpoint callback with save_top_k: 5 -- sorry, I've updated the above post, it should say "5". I'll try to make a reproducible example with the BoringModel.

Programmer-RD-AI commented 3 years ago

hi, it's the standard PL ModelCheckpoint callback with save_top_k: 5 -- sorry, I've updated the above post, it should say "5". I'll try to make a reproducible example with the BoringModel.

ok, can you send the reproducible code with the BoringModel?

With best regards, Ranuga

breznak commented 3 years ago

so do you want the files saved under a "2nd best model" file name?

no, what I want to achieve is:

I'll post a BoringModel reproducible soon. Thank you!

Programmer-RD-AI commented 3 years ago

ok,

With best regards

breznak commented 3 years ago

BoringModel reproducible:

import os

import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning
from pytorch_lightning import LightningModule, Trainer

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # random to make some epochs be not better than previous
        loss = torch.randn(1)*self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=8)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=8)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=8)

    model = BoringModel()

    # new:
    best_5_models = pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(
        monitor='valid_loss',
        mode='min',
        # problem is when the name does not have an identifier (epoch, metric, ...)
        # ... but that needs to be the case for automatic processing of the checkpoints.
        filename='best',
        save_top_k=5,  # problem is when > 1
        every_n_epochs=1,
        auto_insert_metric_name=False,
        verbose=True,
    )
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=11,
        weights_summary=None,
        callbacks=[best_5_models,]
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)

if __name__ == "__main__":
    run()

Output:

/home/marek/anaconda3/envs/blackfox/lib/python3.6/site-packages/pl_examples/bug_report_model.py
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

Epoch 0:  50%|█████     | 8/16 [00:00<00:00, 280.29it/s, loss=-3.01, v_num=17]
Validating: 0it [00:00, ?it/s]
Epoch 0: 100%|██████████| 16/16 [00:00<00:00, 428.06it/s, loss=-3.01, v_num=17]
Epoch 1:  38%|███▊      | 6/16 [00:00<00:00, 472.26it/s, loss=28.2, v_num=17]Epoch 0, global step 7: valid_loss reached 38.09753 (best 38.09753), saving model to "/home/marek/anaconda3/envs/blackfox/lib/python3.6/site-packages/pl_examples/lightning_logs/version_17/checkpoints/best.ckpt" as top 5
Epoch 1:  50%|█████     | 8/16 [00:00<00:00, 465.22it/s, loss=26, v_num=17]  
Validating: 0it [00:00, ?it/s]
Epoch 1: 100%|██████████| 16/16 [00:00<00:00, 698.15it/s, loss=26, v_num=17]
Epoch 2:  50%|█████     | 8/16 [00:00<00:00, 493.90it/s, loss=31.7, v_num=17]
Epoch 1, global step 15: valid_loss reached 22.19783 (best 22.19783), saving model to "/home/marek/anaconda3/envs/blackfox/lib/python3.6/site-packages/pl_examples/lightning_logs/version_17/checkpoints/best-v1.ckpt" as top 5
Validating: 0it [00:00, ?it/s]
Epoch 2: 100%|██████████| 16/16 [00:00<00:00, 725.31it/s, loss=31.7, v_num=17]
Epoch 3:  38%|███▊      | 6/16 [00:00<00:00, 499.78it/s, loss=17, v_num=17]  Epoch 2, global step 23: valid_loss reached -34.51896 (best -34.51896), saving model to "/home/marek/anaconda3/envs/blackfox/lib/python3.6/site-packages/pl_examples/lightning_logs/version_17/checkpoints/best-v2.ckpt" as top 5
Epoch 3:  50%|█████     | 8/16 [00:00<00:00, 487.56it/s, loss=10.7, v_num=17]
Validating: 0it [00:00, ?it/s]
Epoch 3: 100%|██████████| 16/16 [00:00<00:00, 727.47it/s, loss=10.7, v_num=17]
Epoch 4:  50%|█████     | 8/16 [00:00<00:00, 497.47it/s, loss=17.1, v_num=17]
Epoch 3, global step 31: valid_loss reached 17.03111 (best -34.51896), saving model to "/home/marek/anaconda3/envs/blackfox/lib/python3.6/site-packages/pl_examples/lightning_logs/version_17/checkpoints/best-v3.ckpt" as top 5
Validating: 0it [00:00, ?it/s]
Epoch 4: 100%|██████████| 16/16 [00:00<00:00, 734.55it/s, loss=17.1, v_num=17]
Epoch 5:  50%|█████     | 8/16 [00:00<00:00, 497.49it/s, loss=9.89, v_num=17]
Epoch 4, global step 39: valid_loss reached -15.45444 (best -34.51896), saving model to "/home/marek/anaconda3/envs/blackfox/lib/python3.6/site-packages/pl_examples/lightning_logs/version_17/checkpoints/best-v4.ckpt" as top 5
Validating: 0it [00:00, ?it/s]
Epoch 5: 100%|██████████| 16/16 [00:00<00:00, 741.88it/s, loss=9.89, v_num=17]
Epoch 6:  31%|███▏      | 5/16 [00:00<00:00, 501.86it/s, loss=12.7, v_num=17]Epoch 5, global step 47: valid_loss reached -19.44014 (best -34.51896), saving model to "/home/marek/anaconda3/envs/blackfox/lib/python3.6/site-packages/pl_examples/lightning_logs/version_17/checkpoints/best.ckpt" as top 5
Epoch 6:  50%|█████     | 8/16 [00:00<00:00, 483.14it/s, loss=5.15, v_num=17]
Validating: 0it [00:00, ?it/s]
Epoch 6: 100%|██████████| 16/16 [00:00<00:00, 723.16it/s, loss=5.15, v_num=17]
Epoch 7:  31%|███▏      | 5/16 [00:00<00:00, 439.16it/s, loss=-6.2, v_num=17] Epoch 6, global step 55: valid_loss reached -18.78973 (best -34.51896), saving model to "/home/marek/anaconda3/envs/blackfox/lib/python3.6/site-packages/pl_examples/lightning_logs/version_17/checkpoints/best-v1.ckpt" as top 5
Epoch 7:  50%|█████     | 8/16 [00:00<00:00, 441.48it/s, loss=-19.9, v_num=17]
Validating: 0it [00:00, ?it/s]
Epoch 7: 100%|██████████| 16/16 [00:00<00:00, 672.24it/s, loss=-19.9, v_num=17]
Epoch 8:  38%|███▊      | 6/16 [00:00<00:00, 394.70it/s, loss=7.07, v_num=17]Epoch 7, global step 63: valid_loss reached -40.86721 (best -40.86721), saving model to "/home/marek/anaconda3/envs/blackfox/lib/python3.6/site-packages/pl_examples/lightning_logs/version_17/checkpoints/best-v3.ckpt" as top 5
Epoch 8:  50%|█████     | 8/16 [00:00<00:00, 341.67it/s, loss=3.88, v_num=17]
Validating: 0it [00:00, ?it/s]
Epoch 8: 100%|██████████| 16/16 [00:00<00:00, 494.01it/s, loss=3.88, v_num=17]
Epoch 8, global step 71: valid_loss was not in top 5
Epoch 9:  50%|█████     | 8/16 [00:00<00:00, 381.04it/s, loss=1.42, v_num=17]
Validating: 0it [00:00, ?it/s]
Epoch 9: 100%|██████████| 16/16 [00:00<00:00, 574.34it/s, loss=1.42, v_num=17]
Epoch 10:  31%|███▏      | 5/16 [00:00<00:00, 322.56it/s, loss=6.33, v_num=17]Epoch 9, global step 79: valid_loss was not in top 5
Epoch 10:  50%|█████     | 8/16 [00:00<00:00, 311.79it/s, loss=16.6, v_num=17]
Validating: 0it [00:00, ?it/s]
Epoch 10: 100%|██████████| 16/16 [00:00<00:00, 479.55it/s, loss=16.6, v_num=17]
Epoch 10: 100%|██████████| 16/16 [00:00<00:00, 453.05it/s, loss=16.6, v_num=17]
Testing: 0it [00:00, ?it/s]--------------------------------------------------------------------------------
Epoch 10, global step 87: valid_loss reached -53.42523 (best -53.42523), saving model to "/home/marek/anaconda3/envs/blackfox/lib/python3.6/site-packages/pl_examples/lightning_logs/version_17/checkpoints/best-v4.ckpt" as top 5

DATALOADER:0 TEST RESULTS
{'test_loss': -43.91978073120117}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00, 1259.79it/s]

Process finished with exit code 0

UPDATE: the problem comes from using a filename without a "dynamic identifier", i.e. not filename="best-{epoch}" but just filename="best". That is needed when we don't know (and don't care) which epoch is the best but want the checkpoints for automated processing, so the names have to be known in advance.
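For contrast, a minimal sketch (not part of the original report) of the two configurations, both using existing ModelCheckpoint arguments:

from pytorch_lightning.callbacks import ModelCheckpoint

# static name: predictable for tooling, but the top-k ranking is not recoverable from disk
static_ckpt = ModelCheckpoint(monitor="valid_loss", mode="min", save_top_k=5,
                              filename="best")

# dynamic identifier: every checkpoint gets a unique, informative name,
# but you have to parse or inspect the files to find the best one
dynamic_ckpt = ModelCheckpoint(monitor="valid_loss", mode="min", save_top_k=5,
                               filename="best-{epoch}-{valid_loss:.4f}")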

CC @Programmer-RD-AI I have a reproducible example and details on the cause of the naming problem.

Programmer-RD-AI commented 3 years ago

hi @breznak

I will check this issue tomorrow,

With best regards, Ranuga

Programmer-RD-AI commented 3 years ago

hi, I changed the file names in https://github.com/PyTorchLightning/pytorch-lightning/pull/9962.

Can you check the pull request and confirm that this solves your issue?

With best regards, Ranuga

Programmer-RD-AI commented 3 years ago

lightning_logs.zip

Programmer-RD-AI commented 3 years ago

hi @breznak,

What do you mean by the part about save_top_k > 1?

What is the issue with it?

How do you want the functionality to change?

With best regards, Ranuga

Programmer-RD-AI commented 3 years ago

hi, #9970 is the new PR

carmocca commented 3 years ago

Hi @breznak! Thanks for the thoughtful write-up.

Correct me if I'm wrong: you are saying that the filenames of the "best" checkpoints do not indicate which is the best in any way. This information is only available during runtime or by looking at the logs.

And you want to be able to know which was the best checkpoint by looking at a directory of checkpoints, without loading them or checking the logs.

Our internal saving mechanism is very simple:

https://github.com/PyTorchLightning/pytorch-lightning/blob/7b4df7bf919acfd7f7b39d780faeadb54aec9ade/pytorch_lightning/callbacks/model_checkpoint.py#L696-L719

Where the -vN suffix is only used to avoid filename clashes, not to indicate any sort of order.
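Roughly paraphrased (a simplification, not a verbatim copy of the linked source; self is the ModelCheckpoint instance and monitor_candidates the current metrics dict), that clash avoidance amounts to:

import os

# simplified paraphrase of the linked logic; the real code also handles
# distributed file-existence checks and the checkpoint being replaced
filepath = self.format_checkpoint_name(monitor_candidates)  # e.g. "best.ckpt"
version_cnt = 1
while os.path.exists(filepath):
    # "best-v1.ckpt", "best-v2.ckpt", ... -- purely to avoid overwriting an existing file
    filepath = self.format_checkpoint_name(monitor_candidates, ver=version_cnt)
    version_cnt += 1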

I like the idea.

We could have a fixed-size heap as the data structure to manage this (with heapq.heappushpop and checking the size). If we want to avoid re-naming (up to) all checkpoints on every heap update, we could pass a tuple (monitor_value, current_filepath) as the items of the heap, and in on_train_end get the ordered list of values and convert the current filepaths into the expected filepaths that do indicate the top-K order.
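A minimal sketch of that heap idea, assuming plain float monitor values (the class and names are illustrative, not the callback's internals):

import heapq

class TopKTracker:
    """Illustrative sketch: track the k best (monitor value, checkpoint path) pairs."""

    def __init__(self, k, mode="min"):
        self.k = k
        # key the heap so that its root is always the *worst* entry we are keeping
        self.sign = -1 if mode == "min" else 1
        self.heap = []  # min-heap of (key, value, path)

    def update(self, value, filepath):
        """Record a new checkpoint; return the path that fell out of the top-k (possibly the new one)."""
        item = (self.sign * value, value, filepath)
        if len(self.heap) < self.k:
            heapq.heappush(self.heap, item)
            return None
        evicted = heapq.heappushpop(self.heap, item)
        return evicted[2]

    def ranked_paths(self):
        """All kept checkpoint paths, best first."""
        return [path for _, _, path in sorted(self.heap, reverse=True)]

Renaming to rank-encoding filenames (best_top1.ckpt, best_top2.ckpt, ...) would then happen only once, in on_train_end, instead of on every heap update.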

Open to more ideas.

Some more thoughts:

Programmer-RD-AI commented 3 years ago

@carmocca does my PR not help with understanding what a certain saved file means?

If it doesn't, I would love to work on this issue.

With best regards, Ranuga

carmocca commented 3 years ago

It does not, because the original poster could already pass ModelCheckpoint(filename="{epoch}-{step}-{val/loss}") if that's what he wanted.

Programmer-RD-AI commented 3 years ago

It does not, because the original poster could already pass ModelCheckpoint(filename="{epoch}-{step}-{val/loss}") if that's what he wanted.

oh, ok

With best regards, Ranuga

breznak commented 3 years ago

Hi @carmocca ,

Correct me if I'm wrong, you are saying that the filename of the "best" checkpoints does not indicate which is the best in any way. This information is only available during runtime or by looking at the logs. And you want to be able to know which was the best checkpoint by looking at a directory of checkpoints, without loading them or checking the logs.

exactly, this is the idea.

Where the -vN suffix is only used to avoid filename clashes, not to indicate any sort of order.

oh, I see. So the ModelCheckpoint actually does not work correctly with a static filename ("best"), i.e. without modifiers such as {epoch}. It was always trying to overwrite the same file, and what I was seeing as new filenames was the "avoid filename clashes" mechanism being misused there.

Some more thoughts: Would we need to deprecate the old naming scheme and have this be opt-in until removed?

I'd say no, as this seems like a bug-fix for broken functionality with save_top_k > 1, and the naming of "new versions" as {filename}-v{N}.ckpt should remain and still be used for that purpose.

So I'd keep the -vN suffix and additionally add a new _top{k} suffix. Both can be combined, so you could have best_top2-v1.ckpt (the 1st copy of the 2nd best model).

What would be the best final filepath format? .ckpt, _2nd.ckpt, ...?

I'd suggest {filename}_top{k}.ckpt (which can be easily parsed, unlike "2nd", "3rd"), and additionally mapping {filename}.ckpt -> {filename}_top1.ckpt for consistency and ease of use.
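Under that proposed scheme (not current behaviour), picking the k-th best checkpoint from a directory would be a trivial filename parse, e.g.:

import re
from pathlib import Path

def kth_best_checkpoint(ckpt_dir, k):
    """Return the k-th best checkpoint under the proposed {filename}_top{k}.ckpt naming."""
    pattern = re.compile(r"_top(\d+)\.ckpt$")
    ranked = {}
    for path in Path(ckpt_dir).glob("*.ckpt"):
        match = pattern.search(path.name)
        if match:
            ranked[int(match.group(1))] = path
    return ranked[k]

# e.g. evaluate the 2nd best model:
# trainer.test(model, dataloaders=test_data, ckpt_path=str(kth_best_checkpoint("xxx", 2)))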

Furthermore,

Where the -vN suffix is only used to avoid filename clashes, not to indicate any sort of order.

(This might be a separate issue, but we might as well solve it here.) I find this a security risk from a deployment perspective! You'd expect your latest model to be saved in {filename}.ckpt, but it might be saved to {filename}-vN.ckpt with just a warning somewhere in the logs. You might then easily deploy a different (old) model to production, causing real problems.

For this I suggest adding an argument automatic_version_increment_filename: bool (which can default to True for consistency). If set to False, either raise an error or overwrite the file if it already exists.
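A rough sketch of how the proposed flag might be used; automatic_version_increment_filename is the proposal from this comment, not an existing ModelCheckpoint argument:

from pytorch_lightning.callbacks import ModelCheckpoint

# hypothetical usage of the *proposed* argument -- it does not exist today
checkpoint = ModelCheckpoint(
    monitor="val/loss",
    mode="min",
    filename="best",
    save_top_k=1,
    # proposed: never silently write best-v1.ckpt; raise (or overwrite) instead
    automatic_version_increment_filename=False,
)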

Thank you for analyzing this issue,

xwinxu commented 1 year ago

Curious, did this end up getting fixed?

carmocca commented 1 year ago

No. It's up for grabs.