sktime / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Getting key error when fitting from ckpt_path to resume training #928

Open rrtjr opened 2 years ago

rrtjr commented 2 years ago

Expected behavior

I am currently training my TFT model, and it initially works fine. However, the process was interrupted, so I added ckpt_path to resume training. After adding ckpt_path, I get a KeyError. I expect fitting to simply continue from the checkpoint.


trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

Actual behavior

Added ckpt_path so my model can resume training.


trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    ckpt_path="lightning_logs/default/version_0/checkpoints/epoch=4-step=37349.ckpt"
)

I am getting this error when I do so:

KeyError: 'radam_buffer'

Here is the stack trace for your reference:


Restoring states from the checkpoint path at lightning_logs/default/version_0/checkpoints/epoch=4-step=37349.ckpt
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:343: UserWarning: The dirpath has changed from 'lightning_logs/default/version_0/checkpoints' to 'lightning_logs/lightning_logs/version_0/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.
  f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0     
1  | logging_metrics                    | ModuleList                      | 0     
2  | input_embeddings                   | MultiEmbedding                  | 137   
3  | prescalers                         | ModuleDict                      | 24    
4  | static_variable_selection          | VariableSelectionNetwork        | 312   
5  | encoder_variable_selection         | VariableSelectionNetwork        | 437   
6  | decoder_variable_selection         | VariableSelectionNetwork        | 354   
7  | static_context_variable_selection  | GatedResidualNetwork            | 88    
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 88    
9  | static_context_initial_cell_lstm   | GatedResidualNetwork            | 88    
10 | static_context_enrichment          | GatedResidualNetwork            | 88    
11 | lstm_encoder                       | LSTM                            | 160   
12 | lstm_decoder                       | LSTM                            | 160   
13 | post_lstm_gate_encoder             | GatedLinearUnit                 | 40    
14 | post_lstm_add_norm_encoder         | AddNorm                         | 8     
15 | static_enrichment                  | GatedResidualNetwork            | 104   
16 | multihead_attn                     | InterpretableMultiHeadAttention | 76    
17 | post_attn_gate_norm                | GateAddNorm                     | 48    
18 | pos_wise_ff                        | GatedResidualNetwork            | 88    
19 | pre_output_gate_norm               | GateAddNorm                     | 48    
20 | output_layer                       | Linear                          | 35    
----------------------------------------------------------------------------------------
2.4 K     Trainable params
0         Non-trainable params
2.4 K     Total params
0.009     Total estimated model params size (MB)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-19-c17b7f032035> in <module>()
      6     train_dataloaders=train_dataloader,
      7     val_dataloaders=val_dataloader,
----> 8     ckpt_path="lightning_logs/default/version_0/checkpoints/epoch=4-step=37349.ckpt"
      9 )

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    770         self.strategy.model = model
    771         self._call_and_handle_interrupt(
--> 772             self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    773         )
    774 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    722                 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
    723             else:
--> 724                 return trainer_fn(*args, **kwargs)
    725         # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    726         except KeyboardInterrupt as exception:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    810             ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
    811         )
--> 812         results = self._run(model, ckpt_path=self.ckpt_path)
    813 
    814         assert self.state.stopped

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1231         # restore optimizers, etc.
   1232         log.detail(f"{self.__class__.__name__}: restoring training state")
-> 1233         self._checkpoint_connector.restore_training_state()
   1234 
   1235         self._checkpoint_connector.resume_end()

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py in restore_training_state(self)
    202         if self.trainer.state.fn == TrainerFn.FITTING:
    203             # restore optimizers and schedulers state
--> 204             self.restore_optimizers_and_schedulers()
    205 
    206     def restore_precision_plugin_state(self) -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py in restore_optimizers_and_schedulers(self)
    308                     " This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`."
    309                 )
--> 310             self.restore_optimizers()
    311 
    312         if "lr_schedulers" not in self._loaded_checkpoint:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py in restore_optimizers(self)
    323 
    324         # restore the optimizers
--> 325         self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)
    326 
    327     def restore_lr_schedulers(self) -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/strategies/strategy.py in load_optimizer_state_dict(self, checkpoint)
    322         optimizer_states = checkpoint["optimizer_states"]
    323         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
--> 324             optimizer.load_state_dict(opt_state)
    325             optimizer_to_device(optimizer, self.root_device)
    326 

/usr/local/lib/python3.7/dist-packages/torch/optim/optimizer.py in load_state_dict(self, state_dict)
    185         param_groups = [
    186             update_group(g, ng) for g, ng in zip(groups, saved_groups)]
--> 187         self.__setstate__({'state': state, 'param_groups': param_groups})
    188 
    189     def zero_grad(self, set_to_none: bool = False):

/usr/local/lib/python3.7/dist-packages/pytorch_forecasting/optim.py in __setstate__(self, state)
    131     def __setstate__(self, state: dict) -> None:
    132         super().__setstate__(state)
--> 133         self.radam_buffer = state["radam_buffer"]
    134         self.alpha = state["alpha"]
    135         self.k = state["k"]

KeyError: 'radam_buffer'

Code to reproduce the problem

Relevant code for your reference as well


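import os

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss

# model params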
demand_model_path = os.path.join(model_path, "demand")
max_prediction_length = 12 * 24 * 7
max_encoder_length = 12 * 24 * 7 * 2
training_cutoff = df["time_idx"].max() - max_prediction_length
target = 'mkt_reqt'
group_id = 'region_name'
batch_size = 32
max_epochs = 8
learning_rate = 0.06
hidden_size=4
hidden_continuous_size=2
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=True, mode="min")
lr_logger = LearningRateMonitor()  
logger = TensorBoardLogger("lightning_logs")

training = TimeSeriesDataSet(
    df[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target=target,
    group_ids=[group_id],
    min_encoder_length=max_encoder_length // 2, 
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=[group_id],
    time_varying_known_categoricals=["a", "b", "c", "d"],
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=[
        target
    ],
    target_normalizer=GroupNormalizer(
        groups=[group_id], transformation="softplus"
    ),  
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)

validation = TimeSeriesDataSet.from_dataset(training, df, predict=True, stop_randomization=True)
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=4)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=4)

pl.seed_everything(42)

trainer = pl.Trainer(
    max_epochs=max_epochs,
    gpus=1,
    weights_summary="top",
    gradient_clip_val=0.1,
    # limit_train_batches=30,  # comment in for training, running validation every 30 batches
    # fast_dev_run=True,  # comment in to check that network or dataset has no serious bugs
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=learning_rate,
    hidden_size=hidden_size,
    attention_head_size=1,
    dropout=0.1,  
    hidden_continuous_size=hidden_continuous_size,
    output_size=7, 
    loss=QuantileLoss(),
    reduce_on_plateau_patience=4,
)

trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    # I get the error when I uncomment this line.
    #ckpt_path="lightning_logs/default/version_0/checkpoints/epoch=4-step=37349.ckpt"
)

I appreciate all the support I can get, thank you very much.

rrtjr commented 2 years ago

Further testing shows that this only happens with the default Ranger optimizer. Setting the optimizer to Adam works fine.

lee335 commented 2 years ago

Hi, I also encountered this issue and I found the reason why it is happening.

First of all, this problem does not occur with pytorch-forecasting==0.9.2.

The reason is that the __setstate__ method of Ranger changed in pytorch-forecasting>=0.10.0:

https://github.com/jdb78/pytorch-forecasting/blob/4ad86c3f847d6281cd0b2deccce5fe873cf692f9/pytorch_forecasting/optim.py#L131

However, PyTorch's Optimizer.load_state_dict calls __setstate__ like this:

self.__setstate__({'state': state, 'param_groups': param_groups})

So Ranger's __setstate__ should read something like this:

self.radam_buffer = state["state"]["radam_buffer"]
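The mechanism described above can be reproduced without PyTorch. Below is a minimal mock, assuming (as the traceback shows) that Optimizer.load_state_dict ends by calling self.__setstate__ with only the 'state' and 'param_groups' keys. MiniOptimizer, BrokenRanger and PatchedRanger are hypothetical stand-ins, not PyTorch or pytorch-forecasting code. PatchedRanger takes a different, defensive approach than the state["state"] lookup suggested above: it only overwrites an attribute when the key is actually present.

```python
class MiniOptimizer:
    """Hypothetical stand-in for torch.optim.Optimizer (illustration only)."""

    def __init__(self) -> None:
        self.state = {}
        self.param_groups = [{"lr": 1e-3}]

    def __setstate__(self, state: dict) -> None:
        # Assumption: like torch.optim.Optimizer, just update the instance dict
        self.__dict__.update(state)

    def load_state_dict(self, state_dict: dict) -> None:
        # Mirrors the traceback: only 'state' and 'param_groups' are passed on
        self.__setstate__({
            "state": dict(state_dict["state"]),
            "param_groups": list(state_dict["param_groups"]),
        })


class BrokenRanger(MiniOptimizer):
    """Mimics a __setstate__ that expects keys load_state_dict() never passes."""

    def __init__(self) -> None:
        super().__init__()
        # Illustrative values only
        self.radam_buffer = [[None, None, None] for _ in range(10)]
        self.alpha = 0.5
        self.k = 6

    def __setstate__(self, state: dict) -> None:
        super().__setstate__(state)
        self.radam_buffer = state["radam_buffer"]  # raises KeyError here
        self.alpha = state["alpha"]
        self.k = state["k"]


class PatchedRanger(BrokenRanger):
    """Defensive variant: keep the current value when a key is absent."""

    def __setstate__(self, state: dict) -> None:
        # Skip BrokenRanger.__setstate__ and call the base implementation directly
        MiniOptimizer.__setstate__(self, state)
        for key in ("radam_buffer", "alpha", "k"):
            if key in state:
                setattr(self, key, state[key])


def try_load(opt: MiniOptimizer) -> str:
    """Load a minimal saved state and report what happened."""
    saved = {"state": {}, "param_groups": [{"lr": 1e-3}]}
    try:
        opt.load_state_dict(saved)
    except KeyError as exc:
        return f"KeyError: {exc}"
    return "loaded"


print(try_load(BrokenRanger()))   # KeyError: 'radam_buffer'
print(try_load(PatchedRanger()))  # loaded
```

The trade-off of this defensive variant is that the Ranger-specific attributes are not restored from the checkpoint; they keep the values set in __init__.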

KaikePing commented 1 year ago


This bug still exists in version 0.10.3. As a temporary workaround, we can add state = state["state"] at line 133 of optim.py to avoid the KeyError.