awslabs / gluonts

Probabilistic time series modeling in Python
https://ts.gluon.ai
Apache License 2.0

wrong dtype in PatchTSTEstimator #3198

Open LeGmask opened 1 week ago

LeGmask commented 1 week ago

Description

When I try to train a PatchTSTEstimator, I get the following error: mat1 and mat2 must have the same dtype, but got Double and Float

To Reproduce

Train a PatchTSTEstimator:

estimator = PatchTSTEstimator(prediction_length=prediction_length, patch_len=16, context_length=12)
predictor = estimator.train(training_dataset)
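
A likely cause is that the dataset's target arrays are float64 (numpy's default), while the model's linear layers use float32 weights. As a possible workaround until the estimator casts inputs itself, converting the targets to float32 before training may avoid the mismatch. This is a minimal sketch, assuming training_dataset is an iterable of dict entries with a "target" array (the usual gluonts dataset layout); the name training_dataset_f32 is just illustrative:

import numpy as np

# Cast each entry's target to float32 so it matches the model's parameter dtype.
training_dataset_f32 = [
    {**entry, "target": np.asarray(entry["target"], dtype=np.float32)}
    for entry in training_dataset
]

predictor = estimator.train(training_dataset_f32)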

Error message or code output

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 2
      1 estimator = PatchTSTEstimator(prediction_length=prediction_length, patch_len=16, context_length=12)
----> 2 predictor = estimator.train_model(training_dataset)

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/gluonts/torch/model/estimator.py:209, in PyTorchLightningEstimator.train_model(self, training_data, validation_data, from_predictor, shuffle_buffer_length, cache_data, ckpt_path, **kwargs)
    200 custom_callbacks = self.trainer_kwargs.pop("callbacks", [])
    201 trainer = pl.Trainer(
    202     **{
    203         "accelerator": "auto",
   (...)
    206     }
    207 )
--> 209 trainer.fit(
    210     model=training_network,
    211     train_dataloaders=training_data_loader,
    212     val_dataloaders=validation_data_loader,
    213     ckpt_path=ckpt_path,
    214 )
    216 if checkpoint.best_model_path != "":
    217     logger.info(
    218         f"Loading best model from {checkpoint.best_model_path}"
    219     )

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542 self.state.status = TrainerStatus.RUNNING
    543 self.training = True
--> 544 call._call_and_handle_interrupt(
    545     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546 )

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42     if trainer.strategy.launcher is not None:
     43         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44     return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    573 assert self.state.fn is not None
    574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    575     self.state.fn,
    576     ckpt_path,
    577     model_provided=True,
    578     model_connected=self.lightning_module is not None,
    579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
    582 assert self.state.stopped
    583 self.training = False

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:989, in Trainer._run(self, model, ckpt_path)
    984 self._signal_connector.register_signal_handlers()
    986 # ----------------------------
    987 # RUN THE TRAINER
    988 # ----------------------------
--> 989 results = self._run_stage()
    991 # ----------------------------
    992 # POST-Training CLEAN UP
    993 # ----------------------------
    994 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:1035, in Trainer._run_stage(self)
   1033         self._run_sanity_check()
   1034     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1035         self.fit_loop.run()
   1036     return None
   1037 raise RuntimeError(f"Unexpected state {self.state}")

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:202, in _FitLoop.run(self)
    200 try:
    201     self.on_advance_start()
--> 202     self.advance()
    203     self.on_advance_end()
    204     self._restarting = False

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:359, in _FitLoop.advance(self)
    357 with self.trainer.profiler.profile("run_training_epoch"):
    358     assert self._data_fetcher is not None
--> 359     self.epoch_loop.run(self._data_fetcher)

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py:136, in _TrainingEpochLoop.run(self, data_fetcher)
    134 while not self.done:
    135     try:
--> 136         self.advance(data_fetcher)
    137         self.on_advance_end(data_fetcher)
    138         self._restarting = False

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py:240, in _TrainingEpochLoop.advance(self, data_fetcher)
    237 with trainer.profiler.profile("run_training_batch"):
    238     if trainer.lightning_module.automatic_optimization:
    239         # in automatic optimization, there can only be one optimizer
--> 240         batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
    241     else:
    242         batch_output = self.manual_optimization.run(kwargs)

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py:187, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
    180         closure()
    182 # ------------------------------
    183 # BACKWARD PASS
    184 # ------------------------------
    185 # gradient update with accumulated gradients
    186 else:
--> 187     self._optimizer_step(batch_idx, closure)
    189 result = closure.consume_result()
    190 if result.loss is None:

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py:265, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
    262     self.optim_progress.optimizer.step.increment_ready()
    264 # model hook
--> 265 call._call_lightning_module_hook(
    266     trainer,
    267     "optimizer_step",
    268     trainer.current_epoch,
    269     batch_idx,
    270     optimizer,
    271     train_step_and_backward_closure,
    272 )
    274 if not should_accumulate:
    275     self.optim_progress.optimizer.step.increment_completed()

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:157, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    154 pl_module._current_fx_name = hook_name
    156 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 157     output = fn(*args, **kwargs)
    159 # restore current_fx when nested context
    160 pl_module._current_fx_name = prev_fx_name

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/core/module.py:1291, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
   1252 def optimizer_step(
   1253     self,
   1254     epoch: int,
   (...)
   1257     optimizer_closure: Optional[Callable[[], Any]] = None,
   1258 ) -> None:
   1259     r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
   1260     the optimizer.
   1261 
   (...)
   1289 
   1290     """
-> 1291     optimizer.step(closure=optimizer_closure)

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/core/optimizer.py:151, in LightningOptimizer.step(self, closure, **kwargs)
    148     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    150 assert self._strategy is not None
--> 151 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
    153 self._on_after_step()
    155 return step_output

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py:230, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
    228 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
    229 assert isinstance(model, pl.LightningModule)
--> 230 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py:117, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
    115 """Hook to run the optimizer step."""
    116 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 117 return optimizer.step(closure=closure, **kwargs)

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/torch/optim/optimizer.py:391, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    386         else:
    387             raise RuntimeError(
    388                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    389             )
--> 391 out = func(*args, **kwargs)
    392 self._optimizer_step_code()
    394 # call optimizer step post hooks

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/torch/optim/optimizer.py:76, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
     74     torch.set_grad_enabled(self.defaults['differentiable'])
     75     torch._dynamo.graph_break()
---> 76     ret = func(self, *args, **kwargs)
     77 finally:
     78     torch._dynamo.graph_break()

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/torch/optim/adam.py:148, in Adam.step(self, closure)
    146 if closure is not None:
    147     with torch.enable_grad():
--> 148         loss = closure()
    150 for group in self.param_groups:
    151     params_with_grad = []

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py:104, in Precision._wrap_closure(self, model, optimizer, closure)
     91 def _wrap_closure(
     92     self,
     93     model: "pl.LightningModule",
     94     optimizer: Optimizer,
     95     closure: Callable[[], Any],
     96 ) -> Any:
     97     """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
     98     hook is called.
     99 
   (...)
    102 
    103     """
--> 104     closure_result = closure()
    105     self._after_closure(model, optimizer)
    106     return closure_result

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py:140, in Closure.__call__(self, *args, **kwargs)
    139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 140     self._result = self.closure(*args, **kwargs)
    141     return self._result.loss

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py:126, in Closure.closure(self, *args, **kwargs)
    124 @torch.enable_grad()
    125 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 126     step_output = self._step_fn()
    128     if step_output.closure_loss is None:
    129         self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py:315, in _AutomaticOptimization._training_step(self, kwargs)
    312 trainer = self.trainer
    314 # manually capture logged metrics
--> 315 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
    316 self.trainer.strategy.post_training_step()  # unused hook - call anyway for backward compatibility
    318 return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:309, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    306     return None
    308 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 309     output = fn(*args, **kwargs)
    311 # restore current_fx when nested context
    312 pl_module._current_fx_name = prev_fx_name

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py:382, in Strategy.training_step(self, *args, **kwargs)
    380 if self.model != self.lightning_module:
    381     return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
--> 382 return self.lightning_module.training_step(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/gluonts/torch/model/patch_tst/lightning_module.py:64, in PatchTSTLightningModule.training_step(self, batch, batch_idx)
     60 def training_step(self, batch, batch_idx: int):  # type: ignore
     61     """
     62     Execute training step.
     63     """
---> 64     train_loss = self.model.loss(
     65         **select(self.inputs, batch),
     66         future_target=batch["future_target"],
     67         future_observed_values=batch["future_observed_values"],
     68     ).mean()
     70     self.log(
     71         "train_loss",
     72         train_loss,
   (...)
     75         prog_bar=True,
     76     )
     77     return train_loss

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/gluonts/torch/model/patch_tst/module.py:228, in PatchTSTModel.loss(self, past_target, past_observed_values, future_target, future_observed_values)
    221 def loss(
    222     self,
    223     past_target: torch.Tensor,
   (...)
    226     future_observed_values: torch.Tensor,
    227 ) -> torch.Tensor:
--> 228     distr_args, loc, scale = self(
    229         past_target=past_target, past_observed_values=past_observed_values
    230     )
    231     loss = self.distr_output.loss(
    232         target=future_target, distr_args=distr_args, loc=loc, scale=scale
    233     )
    234     return weighted_average(loss, weights=future_observed_values, dim=-1)

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/gluonts/torch/model/patch_tst/module.py:206, in PatchTSTModel.forward(self, past_target, past_observed_values)
    203 inputs = torch.cat((past_target_patches, expanded_static_feat), dim=-1)
    205 # project patches
--> 206 enc_in = self.patch_proj(inputs)
    207 embed_pos = self.positional_encoding(enc_in.size())
    209 # transformer encoder with positional encoding

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/.cache/pypoetry/virtualenvs/iaops-lab-hoqTDlRF-py3.12/lib/python3.12/site-packages/torch/nn/modules/linear.py:116, in Linear.forward(self, input)
    115 def forward(self, input: Tensor) -> Tensor:
--> 116     return F.linear(input, self.weight, self.bias)

Environment