Nixtla / neuralforecast

Scalable and user friendly neural :brain: forecasting algorithms.
https://nixtlaverse.nixtla.io/neuralforecast
Apache License 2.0

Hyperparameter Tuning of iTransformer using Optuna #1009

Closed SujayC66 closed 1 month ago

SujayC66 commented 1 month ago

```python
import numpy as np
import optuna

from neuralforecast.auto import AutoiTransformer
from neuralforecast.losses.pytorch import MSE


def config_itransformer(trial):
    # Search space returned for each Optuna trial
    return {
        "h": 14,
        "input_size": trial.suggest_int("input_size", 60, 360, step=60),    # Size of input window
        "n_series": 1,
        "hidden_size": trial.suggest_int("hidden_size", 48, 512, step=48),
        "max_steps": 10,                                                     # Number of SGD steps
        "learning_rate": trial.suggest_loguniform("learning_rate", 1e-6, 1e-3),
        "val_check_steps": 10,
        "random_seed": trial.suggest_int("random_seed", 1, 10),
        "n_heads": trial.suggest_int("n_heads", 4, 24, step=4),
        "e_layers": trial.suggest_int("e_layers", 64, 512, step=64),
        "d_layers": trial.suggest_int("d_layers", 64, 1024, step=64),
        "d_ff": trial.suggest_int("d_ff", 128, 2048, step=128),
        "factor": trial.suggest_int("factor", 1, 3, step=1),
        "dropout": trial.suggest_categorical("dropout", list(np.arange(0.04, 0.2, 0.02))),
        "step_size": trial.suggest_categorical("step_size", [90, 120, 140, 180]),
        "use_norm": True,
        "batch_size": 32,
        "early_stop_patience_steps": 2,
    }


model = AutoiTransformer(
    h=14,
    n_series=1,
    loss=MSE(),
    valid_loss=MSE(),
    config=config_itransformer,
    search_alg=optuna.samplers.TPESampler(),
    backend='optuna',
    num_samples=1,
)
```
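For context, the NeuralForecast object is then created and fit exactly as in the cell shown in the traceback (`train_data` is the training DataFrame, not included here):

```python
from neuralforecast import NeuralForecast

# train_data: long-format DataFrame with unique_id, ds and y columns
fcst = NeuralForecast(models=[model], freq='B')
fcst.fit(df=train_data, val_size=14)
```

Fitting then fails with the traceback below.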
```
RuntimeError                              Traceback (most recent call last)
Cell In[56], line 2
      1 fcst = NeuralForecast(models=[model], freq='B')
----> 2 fcst.fit(df=train_data, val_size=14)

File /opt/conda/lib/python3.10/site-packages/neuralforecast/core.py:462, in NeuralForecast.fit(self, df, static_df, val_size, sort_df, use_init_models, verbose, id_col, time_col, target_col, distributed_config)
    459 self._reset_models()
    461 for i, model in enumerate(self.models):
--> 462     self.models[i] = model.fit(
    463         self.dataset, val_size=val_size, distributed_config=distributed_config
    464     )
    466 self._fitted = True

File /opt/conda/lib/python3.10/site-packages/neuralforecast/common/_base_auto.py:424, in BaseAuto.fit(self, dataset, val_size, test_size, random_seed, distributed_config)
    412 results = self._optuna_tune_model(
    413     cls_model=self.cls_model,
    414     dataset=dataset,
    (...)
    421     distributed_config=distributed_config,
    422 )
    423 best_config = results.best_trial.user_attrs["ALL_PARAMS"]
--> 424 self.model = self._fit_model(
    425     cls_model=self.cls_model,
    426     config=best_config,
    427     dataset=dataset,
    428     val_size=val_size * self.refit_with_val,
    429     test_size=test_size,
    430     distributed_config=distributed_config,
    431 )
    432 self.results = results
    434 # Added attributes for compatibility with NeuralForecast core

File /opt/conda/lib/python3.10/site-packages/neuralforecast/common/_base_auto.py:357, in BaseAuto._fit_model(self, cls_model, config, dataset, val_size, test_size, distributed_config)
    353 def _fit_model(
    354     self, cls_model, config, dataset, val_size, test_size, distributed_config=None
    355 ):
    356     model = cls_model(**config)
--> 357     model = model.fit(
    358         dataset,
    359         val_size=val_size,
    360         test_size=test_size,
    361         distributed_config=distributed_config,
    362     )
    363     return model

File /opt/conda/lib/python3.10/site-packages/neuralforecast/common/_base_multivariate.py:537, in BaseMultivariate.fit(self, dataset, val_size, test_size, random_seed, distributed_config)
    533 if distributed_config is not None:
    534     raise ValueError(
    535         "multivariate models cannot be trained using distributed data parallel."
    536     )
--> 537 return self._fit(
    538     dataset=dataset,
    539     batch_size=self.n_series,
    540     valid_batch_size=self.n_series,
    541     val_size=val_size,
    542     test_size=test_size,
    543     random_seed=random_seed,
    544     shuffle_train=False,
    545     distributed_config=None,
    546 )

File /opt/conda/lib/python3.10/site-packages/neuralforecast/common/_base_model.py:219, in BaseModel._fit(self, dataset, batch_size, valid_batch_size, val_size, test_size, random_seed, shuffle_train, distributed_config)
    217 model = self
    218 trainer = pl.Trainer(**model.trainer_kwargs)
--> 219 trainer.fit(model, datamodule=datamodule)
    220 model.metrics = trainer.callback_metrics
    221 model.__dict__.pop("_trainer", None)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542 self.state.status = TrainerStatus.RUNNING
    543 self.training = True
--> 544 call._call_and_handle_interrupt(
    545     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546 )

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42 if trainer.strategy.launcher is not None:
     43     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44 return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    573 assert self.state.fn is not None
    574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    575     self.state.fn,
    576     ckpt_path,
    577     model_provided=True,
    578     model_connected=self.lightning_module is not None,
    579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
    582 assert self.state.stopped
    583 self.training = False

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:987, in Trainer._run(self, model, ckpt_path)
    982 self._signal_connector.register_signal_handlers()
    984 # ----------------------------
    985 # RUN THE TRAINER
    986 # ----------------------------
--> 987 results = self._run_stage()
    989 # ----------------------------
    990 # POST-Training CLEAN UP
    991 # ----------------------------
    992 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1033, in Trainer._run_stage(self)
   1031 self._run_sanity_check()
   1032 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1033     self.fit_loop.run()
   1034     return None
   1035 raise RuntimeError(f"Unexpected state {self.state}")

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:205, in _FitLoop.run(self)
    203 try:
    204     self.on_advance_start()
--> 205     self.advance()
    206     self.on_advance_end()
    207     self._restarting = False

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:363, in _FitLoop.advance(self)
    361 with self.trainer.profiler.profile("run_training_epoch"):
    362     assert self._data_fetcher is not None
--> 363     self.epoch_loop.run(self._data_fetcher)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:141, in _TrainingEpochLoop.run(self, data_fetcher)
    139 try:
    140     self.advance(data_fetcher)
--> 141     self.on_advance_end(data_fetcher)
    142     self._restarting = False
    143 except StopIteration:

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:295, in _TrainingEpochLoop.on_advance_end(self, data_fetcher)
    291 if not self._should_accumulate():
    292     # clear gradients to not leave any unused memory during validation
    293     call._call_lightning_module_hook(self.trainer, "on_validation_model_zero_grad")
--> 295 self.val_loop.run()
    296 self.trainer.training = True
    297 self.trainer._logger_connector._first_loop_iter = first_loop_iter

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:182, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    180 context_manager = torch.no_grad
    181 with context_manager():
--> 182     return loop_run(self, *args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:142, in _EvaluationLoop.run(self)
    140 self._restarting = False
    141 self._store_dataloader_outputs()
--> 142 return self.on_run_end()

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:268, in _EvaluationLoop.on_run_end(self)
    265 self.trainer._logger_connector.log_eval_end_metrics(all_logged_outputs)
    267 # hook
--> 268 self._on_evaluation_end()
    270 # enable train mode again
    271 self._on_evaluation_model_train()

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:313, in _EvaluationLoop._on_evaluation_end(self, *args, **kwargs)
    311 trainer = self.trainer
    312 hook_name = "on_test_end" if trainer.testing else "on_validation_end"
--> 313 call._call_callback_hooks(trainer, hook_name, *args, **kwargs)
    314 call._call_lightning_module_hook(trainer, hook_name, *args, **kwargs)
    315 call._call_strategy_hook(trainer, hook_name, *args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:208, in _call_callback_hooks(trainer, hook_name, *args, monitoring_callbacks, **kwargs)
    206 if callable(fn):
    207     with trainer.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
--> 208         fn(trainer, trainer.lightning_module, *args, **kwargs)
    210 if pl_module:
    211     # restore current_fx when nested context
    212     pl_module._current_fx_name = prev_fx_name

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/early_stopping.py:196, in EarlyStopping.on_validation_end(self, trainer, pl_module)
    194 if self._check_on_train_epoch_end or self._should_skip_check(trainer):
    195     return
--> 196 self._run_early_stopping_check(trainer)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/early_stopping.py:202, in EarlyStopping._run_early_stopping_check(self, trainer)
    199 """Checks whether the early stopping condition is met and if so tells the trainer to stop the training."""
    200 logs = trainer.callback_metrics
--> 202 if trainer.fast_dev_run or not self._validate_condition_metric(  # disable early_stopping with fast_dev_run
    203     logs
    204 ):  # short circuit if metric not present
    205     return
    207 current = logs[self.monitor].squeeze()

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/early_stopping.py:153, in EarlyStopping._validate_condition_metric(self, logs)
    151 if monitor_val is None:
    152     if self.strict:
--> 153         raise RuntimeError(error_msg)
    154     if self.verbose > 0:
    155         rank_zero_warn(error_msg, category=RuntimeWarning)

RuntimeError: Early stopping conditioned on metric `ptl/val_loss` which is not available. Pass in or modify your `EarlyStopping` callback to use any of the following: `train_loss`, `train_loss_step`, `train_loss_epoch`
```

SujayC66 commented 1 month ago

Please help me figure out the issue.

elephaint commented 1 month ago

Hi,

Can you provide a standalone piece of code that I can run that reproduces the issue?

That said, the RuntimeError seems to provide a clue. I'd remove the val_check_steps and early_stop_patience_steps from the config and see what happens. But this is a bit of guesswork; a piece of code that I can run would help a lot.
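To make that suggestion concrete, here is a minimal sketch (the wrapper name is just for illustration; it reuses the search space from the original post and only drops the two validation-related keys):

```python
# Sketch only: reuse the original search space, minus the keys tied to
# validation-time early stopping.
def config_itransformer_no_val(trial):
    config = config_itransformer(trial)
    config.pop("val_check_steps", None)
    config.pop("early_stop_patience_steps", None)
    return config
```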

Some generic tips based on the above code:

SujayC66 commented 1 month ago

Yeah, when I remove the early_stop_patience_steps parameter the code runs properly, but it's of no use because for every trial the model will run for all epochs. Please find the code here: https://github.com/SujayC66/iTransformer

jmoralez commented 1 month ago

You have to set refit_with_val=True in the AutoiTransformer constructor.

@elephaint we may want to do this automatically if early_stopping_patience_steps is positive because refit_with_val defaults to False and causes this interaction.
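Applied to the snippet from the original post, that would look roughly like this (all other arguments unchanged):

```python
model = AutoiTransformer(
    h=14,
    n_series=1,
    loss=MSE(),
    valid_loss=MSE(),
    config=config_itransformer,
    search_alg=optuna.samplers.TPESampler(),
    backend='optuna',
    num_samples=1,
    refit_with_val=True,  # refit the best config with a validation set so
                          # ptl/val_loss exists for the EarlyStopping callback
)
```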

elephaint commented 1 month ago

> You have to set refit_with_val=True in the AutoiTransformer constructor.
>
> @elephaint we may want to do this automatically if early_stopping_patience_steps is positive because refit_with_val defaults to False and causes this interaction.

Yes, that makes a lot of sense.

SujayC66 commented 1 month ago

refit_with_val = True worked. Thank you so much.