jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License
3.85k stars 609 forks source link

Training cannot be resumed from checkpoint (optimizer's radam parameter missing) #1233

Open terbed opened 1 year ago

terbed commented 1 year ago

I wanted to resume TFT training from the last checkpoint:

# Resume Temporal Fusion Transformer training from the most recent checkpoint.
# Reproducer for the KeyError: 'radam_buffer' reported in this issue.
ckpt_path = "logs/lightning_logs/version_4/checkpoints/last.ckpt"

# configure callbacks, logger and trainer
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger("logs")  # log results to TensorBoard under "logs"
# keep the single best checkpoint by val_loss, plus last.ckpt (save_last=True)
model_checkpoint = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    save_last=True,

)

pl.seed_everything(42)

# build a fresh trainer; its training state (optimizers, LR schedulers, ...) is
# restored later by passing ckpt_path to trainer.fit()
trainer = pl.Trainer(
    max_epochs=200,
    accelerator='gpu',
    gradient_clip_val=0.1,
    # limit_train_batches=30,  # comment in for training, running validation every 30 batches
    # fast_dev_run=True,  # comment in to check that network or dataset has no serious bugs
    callbacks=[lr_logger, early_stop_callback, model_checkpoint],
    logger=logger,
    # resume_from_checkpoint= ckpt_path  # not used: ckpt_path is passed to trainer.fit() instead
)

# load the model (weights) from the last checkpoint; this step succeeds
tft = TemporalFusionTransformer.load_from_checkpoint(checkpoint_path=ckpt_path)

# resume fitting; with ckpt_path set, Lightning also restores the optimizer
# state, which is where KeyError: 'radam_buffer' is raised for the default
# Ranger optimizer
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    ckpt_path = ckpt_path
)

The model weights are loaded successfully, but restoring the optimizer state fails because of a missing key:

KeyError Traceback (most recent call last) Input In [22], in <cell line: 55>() 30 tft = TemporalFusionTransformer.load_from_checkpoint( 31 checkpoint_path=ckpt_path, 32 learning_rate=0.01, (...) 39 reduce_on_plateau_patience=5, 40 ) 42 # tft = TemporalFusionTransformer.from_dataset( 43 # training, 44 # learning_rate=0.01, (...) 53 54 # resume fitting network ---> 55 trainer.fit( 56 tft, 57 train_dataloaders=train_dataloader, 58 val_dataloaders=val_dataloader, 59 #ckpt_path = ckpt_path 60 )

File /usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py:608, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 606 raise TypeError(f"Trainer.fit() requires a LightningModule, got: {model.__class__.__qualname__}") 607 self.strategy._lightning_module = model --> 608 call._call_and_handle_interrupt( 609 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path 610 )

File /usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/call.py:38, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs) 36 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs) 37 else: ---> 38 return trainer_fn(*args, **kwargs) 40 except _TunerExitException: 41 trainer._call_teardown_hook()

File /usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py:650, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 643 ckpt_path = ckpt_path or self.resume_from_checkpoint 644 self._ckpt_path = self._checkpoint_connector._set_ckpt_path( 645 self.state.fn, 646 ckpt_path, # type: ignore[arg-type] 647 model_provided=True, 648 model_connected=self.lightning_module is not None, 649 ) --> 650 self._run(model, ckpt_path=self.ckpt_path) 652 assert self.state.stopped 653 self.training = False

File /usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py:1099, in Trainer._run(self, model, ckpt_path) 1097 # restore optimizers, etc. 1098 log.detail(f"{self.__class__.__name__}: restoring training state") -> 1099 self._checkpoint_connector.restore_training_state() 1101 self._checkpoint_connector.resume_end() 1103 results = self._run_stage()

File /usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py:286, in CheckpointConnector.restore_training_state(self) 283 assert self.trainer.state.fn is not None 284 if self.trainer.state.fn == TrainerFn.FITTING: 285 # restore optimizers and schedulers state --> 286 self.restore_optimizers_and_schedulers()

File /usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py:382, in CheckpointConnector.restore_optimizers_and_schedulers(self) 377 if "optimizer_states" not in self._loaded_checkpoint: 378 raise KeyError( 379 "Trying to restore optimizer state but checkpoint contains only the model." 380 " This is probably due to ModelCheckpoint.save_weights_only being set to True." 381 ) --> 382 self.restore_optimizers() 384 if "lr_schedulers" not in self._loaded_checkpoint: 385 raise KeyError( 386 "Trying to restore learning rate scheduler state but checkpoint contains only the model." 387 " This is probably due to ModelCheckpoint.save_weights_only being set to True." 388 )

File /usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py:397, in CheckpointConnector.restore_optimizers(self) 394 return 396 # restore the optimizers --> 397 self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)

File /usr/local/lib/python3.9/dist-packages/pytorch_lightning/strategies/strategy.py:368, in Strategy.load_optimizer_state_dict(self, checkpoint) 366 optimizer_states = checkpoint["optimizer_states"] 367 for optimizer, opt_state in zip(self.optimizers, optimizer_states): --> 368 optimizer.load_state_dict(opt_state) 369 _optimizer_to_device(optimizer, self.root_device)

File /usr/local/lib/python3.9/dist-packages/torch/optim/optimizer.py:210, in Optimizer.load_state_dict(self, state_dict) 207 return new_group 208 param_groups = [ 209 update_group(g, ng) for g, ng in zip(groups, saved_groups)] --> 210 self.__setstate__({'state': state, 'param_groups': param_groups})

File /usr/local/lib/python3.9/dist-packages/pytorch_forecasting/optim.py:133, in Ranger.__setstate__(self, state) 131 def __setstate__(self, state: dict) -> None: 132 super().__setstate__(state) --> 133 self.radam_buffer = state["radam_buffer"] 134 self.alpha = state["alpha"] 135 self.k = state["k"]

KeyError: 'radam_buffer'

I used the default optimizer, which is "ranger". ~~The problem might be that the hyperparameters of this optimizer are not correctly initialized: https://pytorch-forecasting.readthedocs.io/en/stable/_modules/pytorch_forecasting/optim.html~~

I will try a run with the plain Adam optimizer and see whether its state can be restored without errors.

Update: Loading the state of another optimizer (Adam) fails as well. Because one epoch takes very long, these checkpoints were saved before the epoch ended — could that be the problem? Training should be able to continue from the last saved step.

terbed commented 1 year ago

Update: The problem persists even when the checkpoint is saved after an epoch has completed.

kvyaswanth commented 1 year ago

Facing the same issue

torch '1.13.1+cu116' pytorch_lightning '1.9.0' pytorch_forecasting '0.10.3' Python 3.8.10

# Minimal reproducer: the learning-rate finder fails with the default optimizer.
trainer = pl.Trainer(
    gpus=1,
    # clipping gradients is a hyperparameter and important to prevent divergence
    # of the gradient for recurrent neural networks
    gradient_clip_val=0.1,
)

# no optimizer argument is given, so the default ("ranger") is used
tft = TemporalFusionTransformer.from_dataset(
    training,
    # not meaningful for finding the learning rate but otherwise very important
    learning_rate=0.03,
    hidden_size=16,  # most important hyperparameter apart from learning rate
    # number of attention heads. Set to up to 4 for large datasets
    attention_head_size=1,
    dropout=0.1,  # between 0.1 and 0.3 are good values
    hidden_continuous_size=8,  # set to <= hidden_size
    output_size=7,  # 7 quantiles by default
    loss=QuantileLoss(),
    # reduce learning rate if no improvement in validation loss after x epochs
    reduce_on_plateau_patience=4,
)

# find optimal learning rate
# NOTE: lr_find saves a temporary checkpoint and restores it when it finishes;
# that restore reloads the optimizer state, which is where the default Ranger
# optimizer raises KeyError: 'radam_buffer'
res = trainer.tuner.lr_find(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    max_lr=10.0,
    min_lr=1e-6,
)
WARNING:pytorch_lightning.loggers.tensorboard:Missing logger folder: /lightning_logs INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] Finding best initial lr: 100% 100/100 [00:30<00:00, 4.54it/s] INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=100` reached. INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /.lr_find_962f4d20-6765-4e02-842a-adb3f6880c99.ckpt --------------------------------------------------------------------------- KeyError Traceback (most recent call last) [](https://localhost:8080/#) in 1 # find optimal learning rate ----> 2 res = trainer.tuner.lr_find( 3 tft, 4 train_dataloaders=train_dataloader, 5 val_dataloaders=val_dataloader, 15 frames [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/tuner/tuning.py](https://localhost:8080/#) in lr_find(self, model, train_dataloaders, val_dataloaders, dataloaders, datamodule, method, min_lr, max_lr, num_training, mode, early_stop_threshold, update_attr) 265 self.trainer.callbacks = [lr_finder_callback] + self.trainer.callbacks 266 --> 267 self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule) 268 269 self.trainer.callbacks = [cb for cb in self.trainer.callbacks if cb is not lr_finder_callback] [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 606 raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}") 607 self.strategy._lightning_module = model --> 608 call._call_and_handle_interrupt( 609 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path 610 ) [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/call.py](https://localhost:8080/#) in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs) 36 return trainer.strategy.launcher.launch(trainer_fn, *args, 
trainer=trainer, **kwargs) 37 else: ---> 38 return trainer_fn(*args, **kwargs) 39 40 except _TunerExitException: [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 648 model_connected=self.lightning_module is not None, 649 ) --> 650 self._run(model, ckpt_path=self.ckpt_path) 651 652 assert self.state.stopped [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _run(self, model, ckpt_path) 1086 # hook 1087 if self.state.fn == TrainerFn.FITTING: -> 1088 self._call_callback_hooks("on_fit_start") 1089 self._call_lightning_module_hook("on_fit_start") 1090 [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _call_callback_hooks(self, hook_name, *args, **kwargs) 1383 if callable(fn): 1384 with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"): -> 1385 fn(self, self.lightning_module, *args, **kwargs) 1386 1387 if pl_module: [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/callbacks/lr_finder.py](https://localhost:8080/#) in on_fit_start(self, trainer, pl_module) 120 121 def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: --> 122 self.lr_find(trainer, pl_module) [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/callbacks/lr_finder.py](https://localhost:8080/#) in lr_find(self, trainer, pl_module) 105 def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 106 with isolate_rng(): --> 107 self.optimal_lr = lr_find( 108 trainer, 109 pl_module, [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/tuner/lr_finder.py](https://localhost:8080/#) in lr_find(trainer, model, min_lr, max_lr, num_training, mode, early_stop_threshold, update_attr) 271 272 # Restore initial state of model --> 273 trainer._checkpoint_connector.restore(ckpt_path) 
274 trainer.strategy.remove_checkpoint(ckpt_path) 275 trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py](https://localhost:8080/#) in restore(self, checkpoint_path) 232 233 # restore training state --> 234 self.restore_training_state() 235 self.resume_end() 236 [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py](https://localhost:8080/#) in restore_training_state(self) 284 if self.trainer.state.fn == TrainerFn.FITTING: 285 # restore optimizers and schedulers state --> 286 self.restore_optimizers_and_schedulers() 287 288 def restore_precision_plugin_state(self) -> None: [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py](https://localhost:8080/#) in restore_optimizers_and_schedulers(self) 380 " This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`." 
381 ) --> 382 self.restore_optimizers() 383 384 if "lr_schedulers" not in self._loaded_checkpoint: [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py](https://localhost:8080/#) in restore_optimizers(self) 395 396 # restore the optimizers --> 397 self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint) 398 399 def restore_lr_schedulers(self) -> None: [/usr/local/lib/python3.8/dist-packages/pytorch_lightning/strategies/strategy.py](https://localhost:8080/#) in load_optimizer_state_dict(self, checkpoint) 366 optimizer_states = checkpoint["optimizer_states"] 367 for optimizer, opt_state in zip(self.optimizers, optimizer_states): --> 368 optimizer.load_state_dict(opt_state) 369 _optimizer_to_device(optimizer, self.root_device) 370 [/usr/local/lib/python3.8/dist-packages/torch/optim/optimizer.py](https://localhost:8080/#) in load_state_dict(self, state_dict) 242 param_groups = [ 243 update_group(g, ng) for g, ng in zip(groups, saved_groups)] --> 244 self.__setstate__({'state': state, 'param_groups': param_groups}) 245 246 def zero_grad(self, set_to_none: bool = False): [/usr/local/lib/python3.8/dist-packages/pytorch_forecasting/optim.py](https://localhost:8080/#) in __setstate__(self, state) 131 def __setstate__(self, state: dict) -> None: 132 super().__setstate__(state) --> 133 self.radam_buffer = state["radam_buffer"] 134 self.alpha = state["alpha"] 135 self.k = state["k"] KeyError: 'radam_buffer'
ar-mccabe commented 1 year ago

Bumping this issue; still seeing it when attempting to find optimal learning rate.

kvyaswanth commented 1 year ago

Alternatively, you can use the Adam optimizer for now.

terbed commented 1 year ago

Note that the original issue is not related to the LR finder: it concerns resuming from a checkpoint, and it is not specific to the Ranger optimizer — the Adam optimizer's parameters cannot be restored either.

kvyaswanth commented 1 year ago

Restoring the optimizer state happens inside the optimizer code itself, and that code is also called by lr_find.

For me, Adam worked.

ar-mccabe commented 1 year ago

Thanks @kvyaswanth, that worked for me as well.

antolu commented 1 year ago

Changing to Adam is a poor workaround because it forces you to use another optimizer...

I fixed this problem by adding the following code:

Implement the Ranger.state_dict method:

    def state_dict(self) -> Dict[str, Any]:
        """Return the optimizer state dict, including Ranger's extra attributes.

        Starts from ``torch.optim.Optimizer.state_dict()`` so the result keeps
        the standard format that ``load_state_dict`` expects: parameters are
        referenced by index in ``"state"`` and ``"param_groups"``, not by
        tensor object. ``radam_buffer``, ``alpha``, ``k`` and
        ``N_sma_threshhold`` are added to the ``"state"`` sub-dictionary.
        ``load_state_dict`` forwards only ``"state"`` and ``"param_groups"`` to
        ``__setstate__``, so extra top-level keys would be dropped on load.

        Returns:
            The state dict with the extra attributes stored under
            ``state_dict["state"]``.
        """
        state_dict = super().state_dict()
        # NOTE(review): relies on Optimizer.state_dict() building a new packed
        # "state" mapping, as it does in torch 1.13/2.x. That way these writes
        # do not leak into the live self.state; confirm for other torch versions.
        packed_state = state_dict["state"]
        for name in ("radam_buffer", "alpha", "k", "N_sma_threshhold"):
            packed_state[name] = getattr(self, name)
        return state_dict

As well as saving and loading the extra attributes in the state['state'] sub-dictionary in __getstate__ and __setstate__:

    def __getstate__(self) -> dict:
        """Return the picklable state, including Ranger's extra attributes.

        ``radam_buffer``, ``alpha``, ``k`` and ``N_sma_threshhold`` are stored
        in the ``"state"`` sub-dictionary, which is where ``__setstate__`` reads
        them. That sub-dictionary is copied first, because
        ``torch.optim.Optimizer.__getstate__`` hands back the optimizer's live
        ``self.state`` object. Writing into it directly would add string keys
        to the running optimizer's per-parameter state as a side effect of
        pickling or copying.
        """
        state = super().__getstate__()
        # .copy() preserves the mapping type: if self.state is a defaultdict,
        # the copy keeps its default factory
        inner_state = state["state"].copy()
        inner_state["radam_buffer"] = self.radam_buffer
        inner_state["alpha"] = self.alpha
        inner_state["k"] = self.k
        inner_state["N_sma_threshhold"] = self.N_sma_threshhold
        state["state"] = inner_state
        return state

    def __setstate__(self, state: dict) -> None:
        """Restore optimizer state, including Ranger's extra attributes.

        ``torch.optim.Optimizer.load_state_dict`` calls this method with only
        the ``"state"`` and ``"param_groups"`` entries. Therefore
        ``radam_buffer``, ``alpha``, ``k`` and ``N_sma_threshhold`` are
        expected inside the ``state["state"]`` sub-dictionary. They are popped
        from it before delegating to ``super().__setstate__``, so that string
        keys do not end up in the optimizer's per-parameter ``self.state``.

        A missing entry is tolerated when the attribute is still available in
        one of two ways. It may be a top-level key of ``state``, which
        ``super().__setstate__`` applies. Or it may already be set on the
        instance; ``load_state_dict`` runs on a constructed optimizer, so the
        attribute presumably comes from ``__init__``. This allows loading a
        checkpoint whose optimizer state dict did not store these values.

        Args:
            state: Dictionary with at least a ``"state"`` entry.

        Raises:
            KeyError: If an extra attribute is neither present in ``state`` nor
                already set on the instance.
        """
        extra_names = ("radam_buffer", "alpha", "k", "N_sma_threshhold")
        inner_state = state["state"]
        restored = {}
        for name in extra_names:
            if name in inner_state:
                # remove from the per-parameter mapping before it becomes self.state
                restored[name] = inner_state.pop(name)
        super().__setstate__(state)
        for name in extra_names:
            if name in restored:
                setattr(self, name, restored[name])
            elif not hasattr(self, name):
                raise KeyError(name)

This is because PyTorch only passes state and param_groups from the optimizer state_dict into the optimizer __setstate__: https://github.com/pytorch/pytorch/blob/c99895ca6f98eab834611e007de772534fd57fb9/torch/optim/optimizer.py#L432