jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Single Element List target Parameter leads to error #1190

Status: Open. JustusMzB opened this issue 1 year ago.

JustusMzB commented 1 year ago

Expected behavior

I created a TimeSeriesDataSet with a one-element list as the target parameter. I expected it to be treated as a single-target dataset, so that a model could be built from it with the default loss.

Actual behavior

Instead, calling TemporalFusionTransformer.from_dataset on the dataset raised the following error:

Traceback (most recent call last):
  File "/home/justus/PycharmProjects/tft_gapfilling/prototyping/barebones_weather_tft.py", line 35, in <module>
    tft = TemporalFusionTransformer.from_dataset(weather_dataset) # See what happens in this non-manual approach
  File "/home/justus/anaconda3/envs/tft_gapfilling/lib/python3.9/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py", line 356, in from_dataset
    return super().from_dataset(
  File "/home/justus/anaconda3/envs/tft_gapfilling/lib/python3.9/site-packages/pytorch_forecasting/models/base_model.py", line 1484, in from_dataset
    return super().from_dataset(dataset, **new_kwargs)
  File "/home/justus/anaconda3/envs/tft_gapfilling/lib/python3.9/site-packages/pytorch_forecasting/models/base_model.py", line 1000, in from_dataset
    assert isinstance(
AssertionError: multiple targets require loss to be MultiLoss but found QuantileLoss(quantiles=[0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98])

I think this comes from how a dataset determines whether it is in multi-target mode: TimeSeriesDataSet.multi_target only checks whether the target parameter is a list or tuple, not its length.
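To illustrate the reported behavior, here is a minimal sketch contrasting a type-only check (as described above) with a length-aware one. Both functions are hypothetical stand-ins written for illustration, not the actual `TimeSeriesDataSet.multi_target` implementation.

```python
from typing import Sequence, Union

Target = Union[str, Sequence[str]]


def multi_target_type_only(target: Target) -> bool:
    """Hypothetical stand-in for the reported check: any list or tuple counts as multi-target."""
    return isinstance(target, (list, tuple))


def multi_target_length_aware(target: Target) -> bool:
    """Hypothetical alternative: only a list or tuple with more than one entry counts as multi-target."""
    return isinstance(target, (list, tuple)) and len(target) > 1


for target in ["WetBulbCelsius", ["WetBulbCelsius"], ["WetBulbCelsius", "Visibility"]]:
    # Prints the target followed by the result of each check.
    print(repr(target), multi_target_type_only(target), multi_target_length_aware(target))
```

Under the type-only rule, `["WetBulbCelsius"]` is classified as multi-target, which is why the default QuantileLoss is rejected. Whether the library should switch to a length-aware rule is a design decision: a one-element list could also be read as a deliberate request for multi-target output with a MultiLoss.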

Code to reproduce the problem

from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet

weather_dataset = TimeSeriesDataSet(
    weather_df,
    time_idx="time_id",
    group_ids=["group_id"],
    target=["WetBulbCelsius"],  # Careful: whether the dataset considers itself multi-target depends solely on whether this is a list.
    time_varying_known_reals=["time_id"],
    time_varying_unknown_reals=['WetBulbCelsius', "Visibility", 'DryBulbCelsius', 'DewPointCelsius', 'RelativeHumidity',
                                'WindSpeed', 'WindDirection', 'StationPressure', 'Altimeter'],
    max_encoder_length=30,
    min_encoder_length=10,
    max_prediction_length=10,
    min_prediction_length=1,
    allow_missing_timesteps=True)

# Raises the AssertionError above: the default QuantileLoss is not a MultiLoss.
tft = TemporalFusionTransformer.from_dataset(weather_dataset)


GNuzzarello commented 1 year ago

Just pass the target as a plain string instead of wrapping it in a list:

weather_dataset = TimeSeriesDataSet(
    weather_df,
    time_idx="time_id",
    group_ids=["group_id"],
    target="WetBulbCelsius",  # Careful: whether the dataset considers itself multi-target depends solely on whether this is a list.
    time_varying_known_reals=["time_id"],
    time_varying_unknown_reals=['WetBulbCelsius', "Visibility", 'DryBulbCelsius', 'DewPointCelsius', 'RelativeHumidity',
                                'WindSpeed', 'WindDirection', 'StationPressure', 'Altimeter'],
    max_encoder_length=30,
    min_encoder_length=10,
    max_prediction_length=10,
    min_prediction_length=1,
    allow_missing_timesteps=True)
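If the list of target columns is built programmatically (for example, from a config), a small caller-side helper can unwrap a one-element list before it reaches TimeSeriesDataSet, so the dataset is set up as single-target like in the snippet above. `normalize_target` and `target_columns` are hypothetical names for this sketch, not part of pytorch-forecasting.

```python
from typing import List, Sequence, Union


def normalize_target(target: Union[str, Sequence[str]]) -> Union[str, List[str]]:
    """Hypothetical helper: pass one target column as a plain string, several as a list."""
    if isinstance(target, str):
        return target
    targets = list(target)
    if not targets:
        raise ValueError("target must name at least one column")
    # A single column is returned as a string; several columns stay a list (multi-target).
    return targets[0] if len(targets) == 1 else targets


# Assumed usage: TimeSeriesDataSet(weather_df, ..., target=normalize_target(target_columns), ...)
print(normalize_target(["WetBulbCelsius"]))
print(normalize_target(["WetBulbCelsius", "Visibility"]))
```

Genuine multi-target lists are left untouched, so those datasets still require a MultiLoss as the assertion message says.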