Open TomSteenbergen opened 3 years ago
I am not planning an out-of-memory version very soon, because the vast majority of datasets should fit into memory (particularly if you use a cloud server). Indeed, only encoders and scalers need to be prefit. Your approach looks reasonable to me. A more integrated solution could be to wrap the multiple datasets in one dataset. That could be an interesting PR.
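The "wrap the multiple datasets in one dataset" idea can be sketched independently of pytorch-forecasting; `torch.utils.data.ConcatDataset` performs this kind of index-offset concatenation. Below is a minimal pure-Python version of the same idea, with plain lists standing in for `TimeSeriesDataSet` instances (the class name `ConcatenatedDataset` is illustrative, not part of the library):

```python
import bisect
from itertools import accumulate

class ConcatenatedDataset:
    """Present several map-style datasets as one (ConcatDataset-style).

    A global index is translated into (dataset, local index) via
    cumulative lengths, so the wrapper adds no memory overhead beyond
    the underlying datasets themselves.
    """

    def __init__(self, datasets):
        self.datasets = list(datasets)
        # Cumulative sizes, e.g. lengths [3, 2, 4] -> [3, 5, 9].
        self.cumulative = list(accumulate(len(d) for d in self.datasets))

    def __len__(self):
        return self.cumulative[-1]

    def __getitem__(self, idx):
        # Find which underlying dataset holds this global index.
        ds_idx = bisect.bisect_right(self.cumulative, idx)
        offset = self.cumulative[ds_idx - 1] if ds_idx > 0 else 0
        return self.datasets[ds_idx][idx - offset]

# Usage: three "datasets" exposed as one of length 9.
combined = ConcatenatedDataset([[10, 11, 12], [20, 21], [30, 31, 32, 33]])
print(len(combined))  # 9
print(combined[4])    # 21 (second element of the second dataset)
```

Note this only merges indexing; it does not remove the need to build and prefit each underlying dataset first.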
First of all, thanks for the amazing work on this package!
I have a question on batch-wise training. I'd like to use the Temporal Fusion Transformer model on a very large dataset. However, I cannot load this dataset into memory due to its size, so I would like to train the TFT model batch-wise: fetch a single batch from the database, perform all necessary preprocessing on it, and run one forward and backward pass, with only one batch in memory at any time.
I couldn't find a clear example in the docs. However, the docs for `TimeSeriesDataSet` include a note indicating that there is currently no built-in method to do this.
First question: Is this anywhere on the roadmap for the short term?
Second question: If not, would something like the code snippet below work? Can I repeatedly call `TimeSeriesDataSet` and `trainer.fit`, or should I use a custom loss function and optimizer in a manual loop like this one? Also, besides encoders and scalers, is there anything else that needs to be prefitted before being passed to `TimeSeriesDataSet`? Lastly: is there an overall better approach to this problem?

```python
validation_data = db_conn.fetch_all(...)  # Fetch all data of the validation set.
validation_dataset = TimeSeriesDataSet(validation_data, ...)
validation_dataloader = validation_dataset.to_dataloader(...)

trainer = Trainer(...)
tft = TemporalFusionTransformer(...)

# Fetch training batches from the database using some generator.
for batch in db_conn.fetch_batches(...):
    train_dataset = TimeSeriesDataSet(batch, ...)
    train_dataloader = train_dataset.to_dataloader(...)
    trainer.fit(tft, train_dataloader, validation_dataloader)
```
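The `db_conn.fetch_batches(...)` generator above is left abstract. One way such chunked fetching could be implemented, so that only one batch of rows is ever in memory, is a cursor-based generator; here is a self-contained sketch using stdlib `sqlite3` and a hypothetical `series` table (the function and table names are illustrative, not part of any library):

```python
import sqlite3

def fetch_batches(conn, query, batch_size=2):
    """Yield rows in fixed-size chunks via cursor.fetchmany,
    so only one chunk is held in memory at a time."""
    cursor = conn.execute(query)
    while True:
        rows = cursor.fetchmany(batch_size)
        if not rows:
            break
        yield rows

# Demo with an in-memory database standing in for the real one.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE series (time_idx INTEGER, value REAL)")
conn.executemany("INSERT INTO series VALUES (?, ?)",
                 [(i, float(i) * 0.5) for i in range(5)])

batches = list(fetch_batches(conn, "SELECT time_idx, value FROM series"))
print([len(b) for b in batches])  # [2, 2, 1]
```

In the training loop, each chunk would then be turned into a `TimeSeriesDataSet`; note that each chunk would have to contain complete series windows (full encoder plus decoder lengths) for the samples built from it to be valid.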