Open seanyboi opened 3 years ago
I had the same question and found this quite confusing at first. I have seen people pass filtered data as the data parameter for validation, or add min_prediction_idx=training.index.time.max() + 1 when creating the validation set. After reading the source code, my understanding is that, in predict mode, TimeSeriesDataSet.from_dataset keeps only the last max_sequence_length time steps of each group series in the data passed for validation. The code is in _construct_index:
if predict_mode:  # keep longest element per series (i.e. the first element that spans to the end of the series)
    # filter all elements that are longer than the allowed maximum sequence length
    df_index = df_index[
        lambda x: (x["time_last"] - x["time"] + 1 <= max_sequence_length)
        & (x["sequence_length"] >= min_sequence_length)
    ]
    # choose longest sequence
    df_index = df_i
trainer.fit(
    tft,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
)
In this call, val_dataloaders is the data used to calculate the validation loss.
So which validation set should we use? I think it depends on the use case.
If you want to train on the training data and validate on a fixed part of the data, specify a fixed range with predict=False, e.g. validation = TimeSeriesDataSet.from_dataset(training, data[fixed_filtered_range_for_validation], predict=False, stop_randomization=True).
If you want to validate on the last max_sequence_length time steps of each group series, do as in the example: validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True).
You can also limit the validation set either by specifying the min_prediction_idx parameter or by passing filtered data to from_dataset.
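To make the "pass filtered data" option concrete, here is a minimal pandas-only sketch of splitting a long-format frame at a cutoff. The column names (series, time_idx, value), the toy data, and the two length settings are all made up for illustration and are not taken from the demand forecasting example.

```python
import pandas as pd

# Hypothetical long-format frame: two series, ten time steps each.
data = pd.DataFrame({
    "series": ["a"] * 10 + ["b"] * 10,
    "time_idx": list(range(10)) * 2,
    "value": list(range(20)),
})

max_prediction_length = 3  # assumed forecast horizon
max_encoder_length = 4     # assumed history length fed to the encoder

# Last time step that belongs to training: 9 - 3 = 6.
training_cutoff = data["time_idx"].max() - max_prediction_length

# Rows the training dataset would be built from.
train_df = data[lambda x: x.time_idx <= training_cutoff]

# Option A: strict split, no rows before the cutoff are kept.
val_strict = data[lambda x: x.time_idx > training_cutoff]

# Option B: also keep max_encoder_length earlier steps, so that each
# validation sample still has history to use as encoder input.
val_with_history = data[lambda x: x.time_idx > training_cutoff - max_encoder_length]

print(len(train_df), len(val_strict), len(val_with_history))
```

Option A leaves no earlier rows to serve as encoder input, which is probably why people reach for min_prediction_idx instead: as far as I understand, it lets you pass the full data while only starting predictions after the given index.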
As a test, if the data has shape (198, 13):
validation = TimeSeriesDataSet.from_dataset(training, data, predict=False, stop_randomization=True)
# you might see a number other than 45 depending on your data, as sequences are filtered by min_sequence_length
len(validation)  # 45
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
len(validation)  # 1
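The drop to one sample per series can be reproduced on a toy index with plain pandas. This is only a sketch: the df_index frame below is a simplified stand-in for the library's internal index, the two length settings are invented, and the last step (keeping the longest remaining row per series) is my own reading of the "choose longest sequence" comment, not a quote of the library code.

```python
import pandas as pd

max_sequence_length = 3  # assumed encoder + decoder length
min_sequence_length = 2  # assumed minimum sample length

# Hypothetical index: one row per candidate sample start, two series.
df_index = pd.DataFrame({
    "sequence_id": [0, 0, 0, 1, 1],
    "time": [0, 1, 2, 0, 1],
    "time_last": [3, 3, 3, 2, 2],
})
df_index["sequence_length"] = df_index["time_last"] - df_index["time"] + 1

# Same filter as in the quoted snippet: drop rows that would span more
# than max_sequence_length steps or fewer than min_sequence_length.
filtered = df_index[
    lambda x: (x["time_last"] - x["time"] + 1 <= max_sequence_length)
    & (x["sequence_length"] >= min_sequence_length)
]

# Assumption: keep the single longest remaining row for each series.
longest = filtered.loc[filtered.groupby("sequence_id")["sequence_length"].idxmax()]

print(len(filtered), len(longest))  # 4 2
```

With two series in the toy index, two rows survive, one per series. A dataset with a single group would end up with one sample, which matches the len(validation) of 1 above.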
I was hoping someone could clear up the use of from_dataset() in the demand forecasting example using TFT. When creating the validation set using from_dataset(), isn't it using the whole dataset that was used to create the training dataset, and therefore leaking data into the validation set? Shouldn't the data used be the data after the training_cutoff, like so: data[lambda x: x.date > training_cutoff]?