jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Training TFT on a simple signal/target problem results in no convergence #753

Open void-t-light-87 opened 2 years ago

void-t-light-87 commented 2 years ago

Hi,

I've been running a set of smoke tests on the TFT model to see whether it would work for my use case. I can train it on simple univariate sinusoids, but every multivariate problem seems to fall apart immediately. For example, the attached file has just two time series: one called the signal and one called the target. Every time the signal switches from 0 to 1, the target increases in magnitude; otherwise, the target stays at 1. An LSTM quickly picks up on this pattern and predicts accordingly, but this model ignores it completely and keeps predicting 1s. Is there something I should do differently? Below is one of the examples I ran to train the model. I also tried Optuna hyperparameter optimization, and none of the runs has converged so far.

```python
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import EncoderNormalizer
from pytorch_forecasting.metrics import QuantileLoss

# ex.csv is attached to this issue:
# https://github.com/jdb78/pytorch-forecasting/files/7535930/ex.csv
data = pd.read_csv('path_to_ex_csv_here')

data.reset_index(inplace=True)
data["time_idx"] = data.index  # integer time index required by TimeSeriesDataSet
# data["month"] = data.date.dt.month.astype(str).astype("category")

max_prediction_length = 30
max_encoder_length = 30
training_cutoff = data["time_idx"].max() - max_prediction_length  # hold out the last 30 steps

training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="target",
    group_ids=["wave_label"],
    min_encoder_length=max_encoder_length,
    max_encoder_length=max_encoder_length,
    min_prediction_length=max_prediction_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=["wave_label"],
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=["target", "signal"],
    target_normalizer=EncoderNormalizer(),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# predict=True keeps only the last encoder/prediction window of each series
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
batch_size = 32
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)

pl.seed_everything(42)

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False,
                                    mode="min")
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger('logger_path',
                           name="model_name")  # log results to TensorBoard

# `gpus` and `weights_summary` are PyTorch Lightning 1.x arguments
trainer = pl.Trainer(
    default_root_dir="checkpoint_dir",
    max_epochs=30,
    gpus=[0],
    weights_summary="top",
    gradient_clip_val=0.1,
    limit_train_batches=30,  # at most 30 training batches per epoch
    # fast_dev_run=True,
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=30,
    attention_head_size=10,
    dropout=0.1,
    hidden_continuous_size=30,
    output_size=7,  # 7 quantiles by default
    loss=QuantileLoss(),
    log_interval=10,  # log every 10 batches; comment out when running the learning rate finder
    reduce_on_plateau_patience=4,
)

trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
```

Thanks

void-t-light-87 commented 2 years ago

ex.csv
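The file itself is not reproduced here. For anyone trying to reproduce the issue without it, a minimal sketch of a hypothetical generator for data of the described shape follows; it is not the attached ex.csv. The signal period, the lag between a 0 to 1 switch and the target spike, and the spike height (`period`, `lag`, `spike`) are all assumptions.

```python
def make_signal_target(n_steps, period=20, lag=1, spike=5.0):
    """Hypothetical data generator (assumed shape, not ex.csv).

    signal: alternating runs of 0s and 1s, each period // 2 steps long.
    target: 1.0 everywhere, except `lag` steps after each 0 -> 1 switch,
            where it jumps to `spike`.
    """
    if period < 2:
        raise ValueError("period must be at least 2")
    half = period // 2
    signal = [1 if (t // half) % 2 == 1 else 0 for t in range(n_steps)]
    target = [1.0] * n_steps
    for t in range(1, n_steps):
        # a rising edge of the signal schedules a spike `lag` steps later
        if signal[t - 1] == 0 and signal[t] == 1 and t + lag < n_steps:
            target[t + lag] = spike
    return signal, target


signal, target = make_signal_target(60)
print([t for t, y in enumerate(target) if y != 1.0])  # [11, 31, 51]
```

To match the script's column names, the two lists can be put into a DataFrame with columns "signal" and "target" plus a constant "wave_label" column.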

edcxan commented 2 years ago

Your validation data has no signal.
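One way to check this claim is to look at which time steps the validation set actually covers. With `predict=True`, `TimeSeriesDataSet.from_dataset` keeps only the last window of each series (here 30 encoder steps followed by 30 prediction steps), so a single series yields a single validation sample. The plain-Python sketch below, with hypothetical helper names, lists the 0 to 1 switches that fall inside that last window. If there are none, a model that always predicts 1 can already reach a low `val_loss`, so the validation loss gives no credit for learning the pattern.

```python
def rising_edges(signal):
    """Indices t where the signal switches from 0 to 1 (signal[t-1] == 0, signal[t] == 1)."""
    return [t for t in range(1, len(signal)) if signal[t - 1] == 0 and signal[t] == 1]


def last_window(n_steps, encoder_length, prediction_length):
    """Encoder and decoder index ranges of the final window of one series.

    Assumes, as predict=True does, exactly one window ending at the last time step.
    """
    start = n_steps - encoder_length - prediction_length
    if start < 0:
        raise ValueError("series is shorter than encoder_length + prediction_length")
    encoder = range(start, start + encoder_length)
    decoder = range(start + encoder_length, n_steps)
    return encoder, decoder


def edges_in(signal, steps):
    """Rising edges whose index falls inside `steps`."""
    return [t for t in rising_edges(signal) if t in steps]


# A series whose only switch happens early, far before the validation window
signal = [0] * 10 + [1] * 10 + [0] * 80
encoder, decoder = last_window(len(signal), 30, 30)
print(rising_edges(signal))                                  # [10]
print(edges_in(signal, encoder), edges_in(signal, decoder))  # [] []
```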

void-t-light-87 commented 2 years ago

> Your validation data has no signal.

Would you mind pointing out how it should be different? I was wondering about this when working through the tutorial. How does training work for TFT? Is the validation data only used to compute the error during training? Thanks a lot

edcxan commented 2 years ago

https://www.geeksforgeeks.org/splitting-data-for-machine-learning-models/

void-t-light-87 commented 2 years ago

> https://www.geeksforgeeks.org/splitting-data-for-machine-learning-models/

Then do you know what https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/stallion.html is supposed to do? In that example, validation also uses only the last sample.

void-t-light-87 commented 2 years ago

> https://www.geeksforgeeks.org/splitting-data-for-machine-learning-models/

Do you have any more input? I thought TFT trains on encoder/prediction windows, so the network is validated on the prediction window. In the tutorial, the validation set seems to be used as a test set, not as data the model trains on. You appear to be claiming that the model trains on the validation set passed via val_dataloaders. Do you know this to be true?
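For reference, PyTorch Lightning does not take gradient steps on the loader passed as `val_dataloaders`; it only computes `val_loss` from it, which in the script drives `EarlyStopping` and the learning-rate reduction controlled by `reduce_on_plateau_patience`. The gradients come from windows slid along the training part of each series. The sketch below assumes stride-1 windows with equal minimum and maximum lengths (as in the script) and counts how many training windows contain a spike, meaning any target value other than a baseline of 1, in their prediction part. The helper names are hypothetical and this is not the library's sampling code.

```python
def sliding_windows(n_steps, encoder_length, prediction_length):
    """Yield (encoder_range, decoder_range) for every full-length, stride-1 window."""
    total = encoder_length + prediction_length
    for start in range(n_steps - total + 1):
        yield (range(start, start + encoder_length),
               range(start + encoder_length, start + total))


def count_windows_with_spike(target, encoder_length, prediction_length, baseline=1.0):
    """Return (number of windows, number whose decoder part contains a non-baseline value)."""
    windows = list(sliding_windows(len(target), encoder_length, prediction_length))
    with_spike = sum(1 for _, dec in windows if any(target[t] != baseline for t in dec))
    return len(windows), with_spike


target = [1.0] * 100
target[50] = 5.0  # a single spike in the middle of the series
print(count_windows_with_spike(target, 30, 30))  # (41, 21)
```

A spike fewer than `encoder_length` steps into a series never lands in a prediction window, because every decoder starts at least `encoder_length` steps in; it can only be seen as encoder input.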

edcxan commented 2 years ago

The tutorial data contains thousands of samples and multiple time series in the validation set.

void-t-light-87 commented 2 years ago

> The tutorial data contains thousands of samples and multiple time series in the validation set.

It has exactly one sample per agency/SKU combination in the validation set, and there are 350 such combinations. My validation set has only 1 because my dataset has a single value of the static categorical. It sounds like you know how TFT should be set up to handle the example I am describing. Would you mind sharing the code that predicts the spikes on your machine? I'd like to see TFT do what any LSTM would do and show that it can predict that the target will go up after the signal rises. I also tried a dataset 100x larger with 80/20 test/validation splits. None of these runs has worked so far.

Thanks a lot
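A chronological split can give a single series many validation windows instead of one. The sketch below is a hypothetical illustration in plain Python, not the library's splitting logic: it places the cutoff at 80% of the series, assigns windows whose prediction part ends at or before the cutoff to training and windows whose prediction part starts after it to validation, and drops windows that straddle the cutoff. In pytorch-forecasting, `TimeSeriesDataSet` accepts a `min_prediction_idx` argument that restricts where prediction windows may start; check its documentation before relying on it for this purpose.

```python
def chronological_split(n_steps, encoder_length, prediction_length, train_fraction=0.8):
    """Hypothetical split of stride-1 window start positions for one series.

    Returns (cutoff, train_starts, val_starts). Validation encoders may reach
    back into the training period; only their prediction parts lie after the cutoff.
    """
    cutoff = int(n_steps * train_fraction) - 1  # last time index used for training targets
    total = encoder_length + prediction_length
    train, val = [], []
    for start in range(n_steps - total + 1):
        decoder_first = start + encoder_length
        decoder_last = start + total - 1
        if decoder_last <= cutoff:
            train.append(start)
        elif decoder_first > cutoff:
            val.append(start)
        # windows whose prediction part straddles the cutoff are dropped
    return cutoff, train, val


cutoff, train, val = chronological_split(1000, 30, 30)
print(cutoff, len(train), len(val))  # 799 741 171
```

With `predict=True`, the same 1000-step series would contribute exactly one validation window.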