sktime / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Training a TemporalFusionTransformer batch-wise #359

Open TomSteenbergen opened 3 years ago

TomSteenbergen commented 3 years ago

First of all, thanks for the amazing work on this package!

I have a question about batch-wise training. I'd like to use the Temporal Fusion Transformer model on a very large data set that I cannot load into memory all at once. I would therefore like to train the TFT model batch-wise: fetch a single batch from the database, perform all necessary preprocessing on that batch, and run one forward and backward pass, with only one batch in memory at any time.

I couldn't find a clear example in the docs. However, in the docs for TimeSeriesDataSet, I found the following note:

Large datasets:

Currently the class is limited to in-memory operations (that can be sped up by an existing installation of numba). If you have extremely large data, however, you can pass prefitted encoders and scalers, as well as a subset of sequences, to the class to construct a valid dataset (plus, the EncoderNormalizer should likely be used to normalize targets). When fitting a network, you would then need to create a custom DataLoader that rotates through the datasets. There is currently no in-built method to do this.
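For illustration, here is a minimal sketch of what this note seems to describe: fit the encoders and scalers once on a sample that fits in memory, then hand them to the dataset. The column names and the fetch_sample helper are placeholders, not something from the library or the question.

```python
from sklearn.preprocessing import StandardScaler

from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data import EncoderNormalizer, NaNLabelEncoder

# Fit encoders/scalers once on a representative sample that fits in memory.
sample = db_conn.fetch_sample(...)  # hypothetical helper returning a pandas DataFrame
group_encoder = NaNLabelEncoder(add_nan=True)
group_encoder.fit(sample["series_id"])
month_scaler = StandardScaler().fit(sample[["month"]])

dataset = TimeSeriesDataSet(
    sample,  # any subset of sequences works once the encoders are prefitted
    time_idx="time_idx",
    target="value",
    group_ids=["series_id"],
    max_encoder_length=24,
    max_prediction_length=6,
    time_varying_known_reals=["month"],
    categorical_encoders={"series_id": group_encoder},  # prefitted encoder
    scalers={"month": month_scaler},                     # prefitted scaler
    # EncoderNormalizer normalizes the target per sample from its encoder window,
    # so it does not need to see the full data set up front.
    target_normalizer=EncoderNormalizer(),
)
```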

So there is currently no in-built method to do this. Would an approach along the following lines be the way to go?

```python
validation_data = db_conn.fetch_all(...)  # Fetch all data of the validation set.
validation_dataset = TimeSeriesDataSet(validation_data, ...)
validation_dataloader = validation_dataset.to_dataloader(...)

trainer = Trainer(...)
tft = TemporalFusionTransformer(...)

for batch in db_conn.fetch_batches(...):  # Fetch training batches from the database using some generator.
    train_dataset = TimeSeriesDataSet(batch, ...)
    train_dataloader = train_dataset.to_dataloader(...)
    trainer.fit(tft, train_dataloader, validation_dataloader)
```
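As a side note on the snippet above: TimeSeriesDataSet.from_dataset can rebuild a dataset for each chunk while reusing the column definitions and prefitted encoders of an existing dataset, so the full constructor does not have to be repeated per chunk. This is only a sketch; fetch_batches and the other placeholders are carried over from the question.

```python
from pytorch_forecasting import TimeSeriesDataSet

for chunk in db_conn.fetch_batches(...):  # generator yielding one DataFrame per chunk
    # Reuse parameters, encoders and normalizers from the existing validation dataset.
    train_dataset = TimeSeriesDataSet.from_dataset(validation_dataset, chunk)
    train_dataloader = train_dataset.to_dataloader(train=True, batch_size=128)
    trainer.fit(tft, train_dataloader, validation_dataloader)
```

Be aware that calling trainer.fit once per chunk runs a separate fitting loop per call; a single DataLoader that rotates through the chunks (see the sketch after the reply below) may fit Lightning's epoch handling more naturally.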



Many thanks in advance!
jdb78 commented 3 years ago

I am not planning an out-of-memory version very soon, because the vast majority of datasets should fit into memory (particularly if you use a cloud server). Indeed, only the encoders and scalers need to be prefitted. Your approach looks reasonable to me. A more integrated solution could be to wrap the multiple datasets in one dataset. Could be an interesting PR.
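To make the "wrap the multiple datasets in one dataset" idea concrete, below is a rough sketch (not an existing pytorch-forecasting feature) of an IterableDataset that rotates through database chunks, builds a TimeSeriesDataSet per chunk from a template dataset, and streams its batches. db_conn, fetch_batches and the use of the validation dataset as template are assumptions carried over from the question.

```python
from torch.utils.data import DataLoader, IterableDataset

from pytorch_forecasting import TimeSeriesDataSet


class ChunkedTimeSeriesDataset(IterableDataset):
    """Streams batches chunk by chunk so only one chunk is in memory at a time."""

    def __init__(self, db_conn, template: TimeSeriesDataSet, batch_size: int = 128):
        self.db_conn = db_conn
        self.template = template  # provides column definitions and prefitted encoders
        self.batch_size = batch_size

    def __iter__(self):
        for chunk in self.db_conn.fetch_batches(...):  # hypothetical generator of DataFrames
            dataset = TimeSeriesDataSet.from_dataset(self.template, chunk)
            # The chunk and its dataset can be garbage collected once their batches
            # have been consumed, keeping memory bounded to roughly one chunk.
            yield from dataset.to_dataloader(
                train=True, batch_size=self.batch_size, num_workers=0
            )


# batch_size=None passes the already-collated batches through unchanged.
train_dataloader = DataLoader(
    ChunkedTimeSeriesDataset(db_conn, template=validation_dataset), batch_size=None
)
trainer.fit(tft, train_dataloader, validation_dataloader)
```

With this, a single trainer.fit call sees one continuous stream of batches, which is closer to what Lightning's training loop expects than calling fit once per chunk.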