Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

ModelCheckpoint Callback state loader with missing dir #15705

Open · Stack-Attack opened this issue 2 years ago

Stack-Attack commented 2 years ago

Bug description

Loading a checkpoint with the ModelCheckpoint callback on a different machine (or with a missing/moved "best_model_path" dir) results in an error and crash.

A common use case for me is to train a model (with the .ckpt stored elsewhere, e.g. Neptune) and then pull that checkpoint onto another machine to continue training later. This used to work in older versions but now breaks. Currently, the code handles the case where the checkpoint directory has changed, but not larger changes in the absolute file structure.

How to reproduce the bug

1. Run a model with ModelCheckpoint
2. Save the whole model as a .ckpt
3. Resume the .ckpt run on another machine (a minimal sketch of these steps follows).
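
A minimal sketch of the above, assuming the BoringModel demo module is available in the installed version and using placeholder paths. The NeptuneLogger from the original setup is omitted here (it needs credentials), but it is the piece that turns the dirpath mismatch into the ValueError shown below:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel

# --- machine A: train with ModelCheckpoint (the original setup also attaches a NeptuneLogger) ---
ckpt_cb = ModelCheckpoint(dirpath="checkpoints/", save_last=True, every_n_train_steps=50)
trainer = pl.Trainer(max_steps=100, callbacks=[ckpt_cb])
trainer.fit(BoringModel())

# --- machine B: copy checkpoints/last.ckpt over (machine A's dirpath does not exist here) ---
ckpt_cb = ModelCheckpoint(dirpath="new_checkpoints/", save_last=True, every_n_train_steps=50)
trainer = pl.Trainer(max_steps=200, callbacks=[ckpt_cb])
trainer.fit(BoringModel(), ckpt_path="last.ckpt")  # restored ModelCheckpoint state still points at machine A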

Error messages and logs


Restoring states from the checkpoint path at /home/kyle/v6x_wsl/vrew/TTS/training/nv_fast_pitch/cache/TTS-781/step=100000.ckpt
/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:345: UserWarning: The dirpath has changed from '/home/kyle/v6x/vrew/TTS/training/nv_fast_pitch/TTS-781/.neptune/Untitled/TTS-781/checkpoints' to '/home/kyle/v6x_wsl/vrew/TTS/training/nv_fast_pitch/TTS-798/.neptune/Untitled/TTS-798/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.

Error executing job with overrides: []
Traceback (most recent call last):
  File "/home/kyle/v6x_wsl/vrew/TTS/train.py", line 25, in train
    trainer.fit(model, datamodule=datamodule, ckpt_path=checkpoint)
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
    self._call_and_handle_interrupt(
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 648, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
    return function(*args, **kwargs)
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
    return self._run_train()
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1283, in _run_train
    self.fit_loop.run()
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 271, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 201, in run
    self.on_advance_end()
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 241, in on_advance_end
    self._run_validation()
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 299, in _run_validation
    self.val_loop.run()
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 207, in run
    output = self.on_run_end()
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 201, in on_run_end
    self._on_evaluation_end()
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 265, in _on_evaluation_end
    self.trainer._call_callback_hooks(hook_name, *args, **kwargs)
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1597, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 320, in on_validation_end
    self._save_last_checkpoint(trainer, monitor_candidates)
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 653, in _save_last_checkpoint
    self._save_checkpoint(trainer, filepath)
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 394, in _save_checkpoint
    logger.after_save_checkpoint(proxy(self))
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/utilities/rank_zero.py", line 32, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/loggers/neptune.py", line 572, in after_save_checkpoint
    model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
  File "/home/kyle/anaconda3/envs/tts/lib/python3.8/site-packages/pytorch_lightning/loggers/neptune.py", line 596, in _get_full_model_name
    raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
ValueError: /home/kyle/v6x/vrew/TTS/training/nv_fast_pitch/TTS-781/.neptune/Untitled/TTS-781/checkpoints/step=100000.ckpt was expected to start with /home/kyle/v6x_wsl/vrew/TTS/training/nv_fast_pitch/TTS-798/.neptune/Untitled/TTS-798/checkpoints/.

Environment


- PyTorch Lightning Version: 1.7.7
- Lightning App Version: N/A
- PyTorch Version: 1.12.1
- Python version: 3.8
- OS: Linux (WSL)
- CUDA/cuDNN version: 11.3 / 8.3.2
- GPU models and configuration: RTX 3090 x 2
- How you installed Lightning (conda, pip, source): conda
- Running environment of LightningApp (local, cloud): local

More info

The simplest and most straightforward solution seems to be to not restore the ModelCheckpoint state at all when the directory has changed. There are more complex solutions (such as checking each field individually), but given how tightly this particular callback's state is coupled to the file structure, those seem ill-advised.
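
A rough sketch of that idea, written as a user-side subclass rather than a change to Lightning itself; it assumes the callback's saved state contains the original 'dirpath' (as recent versions record), and the class name is made up:

from pytorch_lightning.callbacks import ModelCheckpoint

class PortableModelCheckpoint(ModelCheckpoint):
    # Hypothetical: drop the restored state entirely when the checkpoint
    # was written with a different dirpath, instead of keeping stale paths.
    def load_state_dict(self, state_dict):
        if state_dict.get("dirpath") != self.dirpath:
            return
        super().load_state_dict(state_dict)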

Stack-Attack commented 2 years ago

I wonder if it would be beneficial to allow the 'ckpt_path' argument of Trainer.fit() to accept a dict already loaded from a .ckpt file with torch.load. Then you could manually remove problematic state dicts if required.

e.g.

# hypothetical API: pass an already-loaded checkpoint dict instead of a path
weights = torch.load(checkpoint, map_location=model.device)
del weights['callbacks']  # drop the ModelCheckpoint (and other callback) state
trainer.fit(ckpt=weights)
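
Until something like that is supported, a workaround with the existing ckpt_path API is to strip the callback state from a copy of the checkpoint file and resume from that copy (file names here are placeholders):

weights = torch.load("step=100000.ckpt", map_location="cpu")
weights.pop("callbacks", None)  # drop ModelCheckpoint state tied to the old machine's paths
torch.save(weights, "step=100000_no_callbacks.ckpt")
trainer.fit(model, datamodule=datamodule, ckpt_path="step=100000_no_callbacks.ckpt")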
ArtemSivtsov commented 1 year ago

@Stack-Attack Hello! I'm facing the same issue, but on the same machine.

I do the following:

1) train an experiment with the Neptune logger and the ModelCheckpoint callback
2) start finetuning from the saved full .ckpt (from step 1) under a different experiment name on the same machine
3) hit this error even though all paths exist and have not changed

Is using torch.load and removing the callback state the best way to handle this issue?

Stack-Attack commented 1 year ago

@ArtemSivtsov Yes, for now, whenever I make any large changes to a model or experiment, I start a new run, load the weights manually, and train without passing a checkpoint path. Roughly the following logic:

if cfg.trainer_cfg.new_run and checkpoint is not None:
    # load only the model weights from the old checkpoint...
    weights = torch.load(checkpoint, map_location=model.device)
    model.load_state_dict(weights["state_dict"])
    # ...and start training without restoring Trainer/callback state
    checkpoint = None
trainer.fit(model, datamodule=datamodule, ckpt_path=checkpoint)
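
Note that this restores only the model weights: optimizer and LR scheduler state, the global step/epoch counters, and all callback state (including ModelCheckpoint) start fresh, which is usually acceptable when deliberately starting a new run.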
ArtemSivtsov commented 1 year ago

@Stack-Attack Thank you so much for the quick reply! I hope the Lightning team will fix this behavior later :)

stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

AleksanderWWW commented 1 year ago

Hey @Stack-Attack, @ArtemSivtsov, Aleksander here, an engineer at Neptune.ai. I came across this issue a couple of days ago. Would it be possible for you to create/share a minimal code snippet that would help reproduce the bug?