Nixtla / neuralforecast

Scalable and user friendly neural :brain: forecasting algorithms.
https://nixtlaverse.nixtla.io/neuralforecast
Apache License 2.0

example for `TimeSeriesDataset` #317

Closed deven-gqc closed 1 year ago

deven-gqc commented 1 year ago

Hello,

I'm trying to use the TFT model on my custom dataset. To do that, I created a custom PyTorch dataset, but when I call the fit method on it, I get an error about an unexpected type.

Here is the construction of my dataset:

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset


class BSM2Dataset(Dataset):
  """BSM2 dataset."""

  def __init__(self):
    self.inf = pd.read_csv('/content/bsm2_influent.csv').values
    self.eff = pd.read_csv('/content/bsm2_effluent.csv').values
    # Build (past influent, future effluent) window pairs once, up front.
    self.X, self.Y = self.split_series(1, 1)

  def __len__(self):
    # One item per window, not per raw row.
    return len(self.X)

  def split_series(self, n_past, n_future):
    X, y = list(), list()
    for window_start in range(len(self.inf)):
      past_end = window_start + n_past
      future_end = past_end + n_future
      if future_end > len(self.inf):
        break
      # slicing the past and future parts of the window
      past, future = self.inf[window_start:past_end, :], self.eff[past_end - 1:future_end - 1, :]
      X.append(past)
      y.append(future)
    return np.array(X), np.array(y)

  def __getitem__(self, idx):
    if torch.is_tensor(idx):
      idx = idx.tolist()

    return self.X[idx], self.Y[idx]

Here is the call to the fit method:

from neuralforecast.models import TFT
from neuralforecast.losses.pytorch import MQLoss

foo = BSM2Dataset()
model = TFT(h=12,
            input_size=48,
            hidden_size=100,
            stat_exog_list=['airline1'],
            hist_exog_list=['y_[lag12]'],
            futr_exog_list=['trend'],
            max_epochs=300,
            learning_rate=0.01,
            scaler_type='robust',
            loss=MQLoss(level=[80, 90]),
            windows_batch_size=None,
            enable_progress_bar=True)
model.fit(foo)  # raises TypeError: Unknown <class 'tuple'>

The call fails. Here is the full stack trace:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-14-a8a91adbe35a> in <module>
----> 1 model.fit(foo)

18 frames
/usr/local/lib/python3.7/dist-packages/neuralforecast/common/_base_windows.py in fit(self, dataset, val_size, test_size)
    399 
    400         trainer = pl.Trainer(**self.trainer_kwargs)
--> 401         trainer.fit(self, datamodule=datamodule)
    402 
    403     def predict(self, dataset, test_size=None, step_size=1, **data_module_kwargs):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    769         self.strategy.model = model
    770         self._call_and_handle_interrupt(
--> 771             self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    772         )
    773 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    721                 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
    722             else:
--> 723                 return trainer_fn(*args, **kwargs)
    724         # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    725         except KeyboardInterrupt as exception:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    809             ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
    810         )
--> 811         results = self._run(model, ckpt_path=self.ckpt_path)
    812 
    813         assert self.state.stopped

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1234         self._checkpoint_connector.resume_end()
   1235 
-> 1236         results = self._run_stage()
   1237 
   1238         log.detail(f"{self.__class__.__name__}: trainer tearing down")

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run_stage(self)
   1321         if self.predicting:
   1322             return self._run_predict()
-> 1323         return self._run_train()
   1324 
   1325     def _pre_training_routine(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
   1343 
   1344         with isolate_rng():
-> 1345             self._run_sanity_check()
   1346 
   1347         # enable train mode

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run_sanity_check(self)
   1411             # run eval step
   1412             with torch.no_grad():
-> 1413                 val_loop.run()
   1414 
   1415             self._call_callback_hooks("on_sanity_check_end")

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    202             try:
    203                 self.on_advance_start(*args, **kwargs)
--> 204                 self.advance(*args, **kwargs)
    205                 self.on_advance_end()
    206                 self._restarting = False

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in advance(self, *args, **kwargs)
    153         if self.num_dataloaders > 1:
    154             kwargs["dataloader_idx"] = dataloader_idx
--> 155         dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
    156 
    157         # store batch level output per dataloader

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    202             try:
    203                 self.on_advance_start(*args, **kwargs)
--> 204                 self.advance(*args, **kwargs)
    205                 self.on_advance_end()
    206                 self._restarting = False

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in advance(self, data_fetcher, dl_max_batches, kwargs)
    110         if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
    111             batch_idx = self.batch_progress.current.ready
--> 112             batch = next(data_fetcher)
    113         else:
    114             batch_idx, batch = next(data_fetcher)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/fetching.py in __next__(self)
    182 
    183     def __next__(self) -> Any:
--> 184         return self.fetching_function()
    185 
    186     def reset(self) -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/fetching.py in fetching_function(self)
    257             # this will run only when no pre-fetching was done.
    258             try:
--> 259                 self._fetch_next_batch(self.dataloader_iter)
    260                 # consume the batch we just fetched
    261                 batch = self.batches.pop(0)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/fetching.py in _fetch_next_batch(self, iterator)
    271     def _fetch_next_batch(self, iterator: Iterator) -> None:
    272         start_output = self.on_fetch_start()
--> 273         batch = next(iterator)
    274         self.fetched += 1
    275         if not self.prefetch_batches and self._has_len:

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    679                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
    680                 self._reset()  # type: ignore[call-arg]
--> 681             data = self._next_data()
    682             self._num_yielded += 1
    683             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    719     def _next_data(self):
    720         index = self._next_index()  # may raise StopIteration
--> 721         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    722         if self._pin_memory:
    723             data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     50         else:
     51             data = self.dataset[possibly_batched_index]
---> 52         return self.collate_fn(data)

/usr/local/lib/python3.7/dist-packages/neuralforecast/tsdataset.py in _collate_fn(self, batch)
     61                         temporal_cols = elem['temporal_cols'])
     62 
---> 63         raise TypeError(f'Unknown {elem_type}')
     64 
     65 # %% ../nbs/tsdataset.ipynb 7

TypeError: Unknown <class 'tuple'>

I searched the site and couldn't find an example of the TimeSeriesDataset being used. Would it be possible to have a tutorial notebook that shows how to do this?

Thanks!
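
The last frames of the traceback hint at why the tuple is rejected: the collate function in `neuralforecast/tsdataset.py` indexes a batch element by key (`elem['temporal_cols']`) and raises `TypeError` for any element type it does not recognise, while `BSM2Dataset.__getitem__` returns an `(X, y)` tuple. The sketch below is a simplified stand-in for that kind of type dispatch, not the library's actual code; `toy_collate` and the `temporal` key are made up for illustration.

```python
from collections.abc import Mapping


def toy_collate(batch):
    """Simplified stand-in for a collate function that only accepts mapping items."""
    elem = batch[0]
    if isinstance(elem, Mapping):
        # Gather each key's values across the batch.
        return {key: [item[key] for item in batch] for key in elem}
    # Anything else (e.g. the (X, y) tuples from BSM2Dataset) is rejected.
    raise TypeError(f"Unknown {type(elem)}")


# Dict-style items pass through.
ok = toy_collate([{"temporal": [1.0, 2.0]}, {"temporal": [3.0, 4.0]}])

# Tuple items trigger the same kind of error as in the traceback.
try:
    toy_collate([([1.0], [2.0])])
except TypeError as err:
    message = str(err)
```

Under this reading, the fix is less about changing `__getitem__` and more about handing the model a dataset type the library builds itself.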

kdgutier commented 1 year ago

Hi @deven-gqc,

Thanks for the comment. We have yet to add a tutorial with usage examples of the TimeSeriesDataset.

You can take a look at:

Let us know if you manage to run TFT on your data.
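
A minimal sketch of one way to get from raw values to a `TimeSeriesDataset`, assuming the long format with `unique_id`, `ds`, and `y` columns. The helper names, the series name, the start date, and the daily step are hypothetical. The return value of `TimeSeriesDataset.from_df` may differ between versions, so only its first element is used here, and the training function has not been run against the BSM2 files.

```python
from datetime import datetime, timedelta


def to_long_rows(values, unique_id="effluent",
                 start=datetime(2020, 1, 1), step=timedelta(days=1)):
    """Turn a 1-D sequence into long-format rows (unique_id, ds, y)."""
    return [
        {"unique_id": unique_id, "ds": start + i * step, "y": float(v)}
        for i, v in enumerate(values)
    ]


def fit_tft_on_rows(rows, horizon=12, input_size=48):
    """Build a TimeSeriesDataset from long-format rows and fit TFT on it."""
    # Imported here so the reshaping helper above works without these packages.
    import pandas as pd
    from neuralforecast.models import TFT
    from neuralforecast.tsdataset import TimeSeriesDataset

    df = pd.DataFrame(rows)
    # The dataset comes first in the returned tuple; the rest is ignored.
    dataset, *_ = TimeSeriesDataset.from_df(df)
    model = TFT(h=horizon, input_size=input_size)
    model.fit(dataset=dataset)
    return model


# Reshaping alone needs only the standard library.
rows = to_long_rows([0.5, 0.7, 0.6])
```

Calling `fit_tft_on_rows(rows)` on a real series would need at least `input_size + horizon` observations.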

deven-gqc commented 1 year ago

Hi @kdgutier, thanks for the prompt reply. I went through those. I'm not very familiar with PyTorch Lightning, but it seems very similar, so I will give that a try. Also, as a side question, does Nixtla have a forum or a Discord server? If so, I'd like to join it.

kdgutier commented 1 year ago

Cool, cool. Let me know how TFT goes.

We have a Slack channel. Does this link work?

deven-gqc commented 1 year ago

@kdgutier I can't join the Slack community as I don't have an email address with these domains. (image attached)

mergenthaler commented 1 year ago

Use this link to join.

deven-gqc commented 1 year ago

Thanks @mergenthaler, I am in!

Ftrejo23 commented 1 year ago

@deven-gqc were you ever able to get this working? I'm doing something similar but I'm not sure how to use TimeSeriesDataset. I'm thinking it might be easier to use the NeuralForecast class and pass a single model in the params.
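
The higher-level route could look roughly like this, assuming a long-format DataFrame with `unique_id`, `ds`, and `y` columns. The function name, horizon, input size, and frequency are placeholders, not values taken from this thread's data.

```python
def forecast_with_single_tft(df, horizon=12, input_size=48, freq="D"):
    """Fit one TFT model through the NeuralForecast wrapper and forecast `horizon` steps."""
    # Imported inside the function so the definition itself has no heavy dependencies.
    from neuralforecast import NeuralForecast
    from neuralforecast.models import TFT

    nf = NeuralForecast(models=[TFT(h=horizon, input_size=input_size)], freq=freq)
    nf.fit(df=df)        # df columns: unique_id, ds, y
    return nf.predict()  # forecasts as a DataFrame
```

With this wrapper the `TimeSeriesDataset` is built internally from the DataFrame, so no custom PyTorch `Dataset` is needed.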

deven-gqc commented 1 year ago

@Ftrejo23 Not really. I didn't get a chance to explore it further.