jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Feature Request: TimeSeriesDataset retrieve column/series names #1003

Status: Open. Opened by LuisPerezVazquez 2 years ago.

LuisPerezVazquez commented 2 years ago

Expected behavior

Be able to easily retrieve the time series name (group_id) after prediction.

Actual behavior

Currently it is neither clear nor easy to recover the column names from the torch tensor once the data has been transformed into a pytorch-forecasting TimeSeriesDataSet. Being able to do this is crucial for using the library and for interpreting the results in a straightforward way.

Code to reproduce the problem

Any call to model.predict(val_dataloader) returns a set of columns (one per time series) holding their predictions, but it is not obvious which column belongs to which time series (group_id).
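One way to attach labels is to pair each prediction with an index table that records the group id of the same sample. In the pytorch-forecasting versions I am aware of, model.predict(..., return_index=True) also returns such a table (one row per entry along the first dimension of the predictions, including the group_ids columns), but please verify this against your installed version. The sketch below assumes that alignment and uses plain Python lists and dicts as stand-ins for the tensor and DataFrame; label_predictions is a hypothetical helper, not part of the library.

```python
from typing import Dict, Hashable, List, Sequence, Tuple


def label_predictions(
    predictions: Sequence[Sequence[float]],
    index: Sequence[Dict[str, Hashable]],
    group_cols: Sequence[str],
) -> Dict[Tuple[Hashable, ...], List[List[float]]]:
    """Group prediction rows by the group id stored in the matching index row.

    Hypothetical helper: assumes predictions[i] and index[i] describe the
    same sample, as an aligned prediction index would.
    """
    if len(predictions) != len(index):
        raise ValueError("predictions and index must have the same length")
    labelled: Dict[Tuple[Hashable, ...], List[List[float]]] = {}
    for pred_row, idx_row in zip(predictions, index):
        # Build the group key from the group_id columns of this sample.
        key = tuple(idx_row[col] for col in group_cols)
        labelled.setdefault(key, []).append(list(pred_row))
    return labelled


# Usage: three samples, two of them for APPLE and one for MSFT.
preds = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
idx = [
    {"COMPANY_ID": "APPLE", "time_idx": 10},
    {"COMPANY_ID": "MSFT", "time_idx": 10},
    {"COMPANY_ID": "APPLE", "time_idx": 11},
]
by_company = label_predictions(preds, idx, ["COMPANY_ID"])
```

Keying by a tuple keeps the helper usable when a series is identified by several group_ids columns at once.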

DominikPKaiser commented 2 years ago

Hi Luis, I just use the dataset's built-in filter method to make predictions for a specific series:

raw_prediction, x = model.predict(
    # keep only the samples whose COMPANY_ID group is "APPLE"
    validation.filter(lambda index: index.COMPANY_ID == "APPLE"),
    mode="raw",
    return_x=True,
)
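The filter approach can be extended to every series by looping over the group ids and storing each result under its own key, so the label is known by construction. The sketch below runs without pytorch-forecasting: the list comprehension plays the role of validation.filter(...) and the predict argument stands in for model.predict(...). Both, along with predictions_by_group, are hypothetical placeholders.

```python
from typing import Callable, Dict, Hashable, List, Sequence

Row = Dict[str, object]


def predictions_by_group(
    rows: Sequence[Row],
    group_col: str,
    predict: Callable[[List[Row]], List[float]],
) -> Dict[Hashable, List[float]]:
    """Call `predict` once per group on that group's rows only.

    Hypothetical helper: the list comprehension mimics
    dataset.filter(lambda index: index.<group_col> == group).
    """
    # Collect group ids in first-seen order.
    groups: List[Hashable] = []
    for row in rows:
        group = row[group_col]
        if group not in groups:
            groups.append(group)

    results: Dict[Hashable, List[float]] = {}
    for group in groups:
        subset = [row for row in rows if row[group_col] == group]
        results[group] = predict(subset)
    return results


# Usage with a stand-in "model" that predicts the mean of the observed values.
data = [
    {"COMPANY_ID": "APPLE", "value": 1.0},
    {"COMPANY_ID": "MSFT", "value": 4.0},
    {"COMPANY_ID": "APPLE", "value": 3.0},
]
mean_model = lambda subset: [sum(r["value"] for r in subset) / len(subset)]
per_company = predictions_by_group(data, "COMPANY_ID", mean_model)
```

This trades one prediction call per series for unambiguous labels; with many series, an aligned prediction index is usually cheaper.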