sktime / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Optimizer state not loading from checkpoint #959

Open Vodolay opened 2 years ago

Vodolay commented 2 years ago

Expected behavior

I want to resume training from a checkpoint:

trainer.fit(
    tuft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    ckpt_path='lightning_logs/best.ckpt',
)

Actual behavior

KeyError                                  Traceback (most recent call last)
/tmp/ipykernel_97035/4236316375.py in <module>
      3     train_dataloaders=train_dataloader,
      4     val_dataloaders=val_dataloader,
----> 5     ckpt_path = 'lightning_logs/best.ckpt',
      6 )

~/anaconda3/envs/tft/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    770         self.strategy.model = model
    771         self._call_and_handle_interrupt(
--> 772             self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    773         )
    774

~/anaconda3/envs/tft/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    722                 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
    723             else:
--> 724                 return trainer_fn(*args, **kwargs)
    725         # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    726         except KeyboardInterrupt as exception:

~/anaconda3/envs/tft/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    810             ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
    811         )
--> 812         results = self._run(model, ckpt_path=self.ckpt_path)
    813
    814         assert self.state.stopped

~/anaconda3/envs/tft/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1231         # restore optimizers, etc.
   1232         log.detail(f"{self.__class__.__name__}: restoring training state")
-> 1233         self._checkpoint_connector.restore_training_state()
   1234
   1235         self._checkpoint_connector.resume_end()

~/anaconda3/envs/tft/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py in restore_training_state(self)
    202         if self.trainer.state.fn == TrainerFn.FITTING:
    203             # restore optimizers and schedulers state
--> 204             self.restore_optimizers_and_schedulers()
    205
    206     def restore_precision_plugin_state(self) -> None:

~/anaconda3/envs/tft/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py in restore_optimizers_and_schedulers(self)
    308                 " This is probably due to ModelCheckpoint.save_weights_only being set to True."
    309             )
--> 310         self.restore_optimizers()
    311
    312         if "lr_schedulers" not in self._loaded_checkpoint:

~/anaconda3/envs/tft/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py in restore_optimizers(self)
    323
    324         # restore the optimizers
--> 325         self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)
    326
    327     def restore_lr_schedulers(self) -> None:

~/anaconda3/envs/tft/lib/python3.7/site-packages/pytorch_lightning/strategies/strategy.py in load_optimizer_state_dict(self, checkpoint)
    322         optimizer_states = checkpoint["optimizer_states"]
    323         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
--> 324             optimizer.load_state_dict(opt_state)
    325             optimizer_to_device(optimizer, self.root_device)
    326

~/anaconda3/envs/tft/lib/python3.7/site-packages/torch/optim/optimizer.py in load_state_dict(self, state_dict)
    185         param_groups = [
    186             update_group(g, ng) for g, ng in zip(groups, saved_groups)]
--> 187         self.__setstate__({'state': state, 'param_groups': param_groups})
    188
    189     def zero_grad(self, set_to_none: bool = False):

~/anaconda3/envs/tft/lib/python3.7/site-packages/pytorch_forecasting/optim.py in __setstate__(self, state)
    131     def __setstate__(self, state: dict) -> None:
    132         super().__setstate__(state)
--> 133         self.radam_buffer = state["radam_buffer"]
    134         self.alpha = state["alpha"]
    135         self.k = state["k"]

KeyError: 'radam_buffer'
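The traceback suggests the error comes from the interaction between torch's optimizer serialization and the `__setstate__` override in pytorch_forecasting/optim.py: `Optimizer.state_dict()` only serializes the `state` and `param_groups` keys, so the extra attributes the Ranger implementation expects on restore (`radam_buffer`, `alpha`, `k`) are never present in the checkpoint. A minimal standalone sketch of the same mechanism, using a toy optimizer rather than the actual pytorch-forecasting class:

```python
import torch


class ToyRanger(torch.optim.SGD):
    """Toy optimizer that, like pytorch-forecasting's Ranger, expects extra keys in __setstate__."""

    def __init__(self, params, lr=0.01):
        super().__init__(params, lr=lr)
        self.radam_buffer = []

    def __setstate__(self, state: dict) -> None:
        super().__setstate__(state)
        # Fails on restore: 'radam_buffer' is never part of Optimizer.state_dict()
        self.radam_buffer = state["radam_buffer"]


param = torch.nn.Parameter(torch.zeros(1))
opt = ToyRanger([param])
saved = opt.state_dict()    # contains only 'state' and 'param_groups'
opt.load_state_dict(saved)  # raises KeyError: 'radam_buffer'
```

Lightning restores each optimizer via `optimizer.load_state_dict(...)`, which is exactly the call that fails here.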

Code to reproduce the problem

Paste the command(s) you ran and the output. Including a link to a Colab notebook will speed up issue resolution. If there was a crash, please include the traceback here. The code used to initialize the TimeSeriesDataSet and model should also be included.
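The report leaves this section unfilled. A hypothetical minimal setup along the lines of the documentation tutorial (all column names, lengths, and the checkpoint path below are illustrative placeholders, not the reporter's actual code) that exercises the same code path would look roughly like this:

```python
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet

# toy data: one series with a single real-valued target
data = pd.DataFrame(
    {
        "time_idx": np.arange(200),
        "series": "a",
        "value": np.sin(np.arange(200) / 10.0),
    }
)

training = TimeSeriesDataSet(
    data,
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    max_encoder_length=24,
    max_prediction_length=6,
    time_varying_unknown_reals=["value"],
)
train_dataloader = training.to_dataloader(train=True, batch_size=64)

# the default optimizer is "ranger", which is what trips the KeyError on resume
tuft = TemporalFusionTransformer.from_dataset(training)

trainer = pl.Trainer(max_epochs=1)
trainer.fit(tuft, train_dataloaders=train_dataloader)

# resuming from the checkpoint written above raises KeyError: 'radam_buffer'
trainer = pl.Trainer(max_epochs=2)
trainer.fit(
    tuft,
    train_dataloaders=train_dataloader,
    ckpt_path="lightning_logs/best.ckpt",  # placeholder path, as in the report
)
```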

rrtjr commented 2 years ago

I experienced the same issue and previously reported the same bug in #928. The only workaround I currently have is to change the default optimizer from ranger to adam, although switching to adam can have model performance implications.

If anyone else has a suggestion, please let me know too. Thanks. A sketch of the workaround follows below.
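For reference, a sketch of that workaround, assuming the model is created with `from_dataset` (the `optimizer` keyword is forwarded to the model constructor, where "ranger" is the default):

```python
from pytorch_forecasting import TemporalFusionTransformer

# Build the model with plain Adam instead of the default "ranger" optimizer;
# Adam's state restores cleanly from Lightning checkpoints, at the cost of the
# performance implications mentioned above.
tft = TemporalFusionTransformer.from_dataset(
    training,            # the TimeSeriesDataSet used for fitting
    optimizer="adam",
)
```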

lee335 commented 2 years ago

Hi, I think I found the reason for this problem and a new workaround:

https://github.com/jdb78/pytorch-forecasting/issues/928#issuecomment-1256432343