Open matsuobasho opened 1 month ago
@matsuobasho
This worked for me:
import pandas as pd
from transformers import PatchTSTForPrediction
from tsfm_public.toolkit.time_series_forecasting_pipeline import TimeSeriesForecastingPipeline
from tsfm_public.toolkit.util import select_by_index
dataset = "ETTh1"
dataset_path = f"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/{dataset}.csv"
timestamp_column = "date"
id_columns = []
# forecast_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
forecast_columns = ["HUFL"]
df = pd.read_csv(
    dataset_path,
    parse_dates=[timestamp_column],
)
context_length = 512
forecast_length = 96
fewshot_fraction = 0.05
model_path = "ibm-granite/granite-timeseries-patchtst"
model = PatchTSTForPrediction.from_pretrained(model_path, num_input_channels=len(forecast_columns))
forecast_pipeline = TimeSeriesForecastingPipeline(
    model=model,
    timestamp_column=timestamp_column,
    id_columns=id_columns,
    target_columns=forecast_columns,
    freq="1h",
)
# Use the first context_length (512) rows as the model's input window
forecasts = forecast_pipeline(df[:context_length])
forecasts.head()
You may find the "transfer" notebooks helpful: https://github.com/ibm-granite/granite-tsfm/blob/main/notebooks/hfdemo/patch_tst_transfer.ipynb
Thank you, this is useful and seems to have worked.
I opened this issue because the transformers documentation for specific applications often omits nuances, which severely hampers use and understanding. In this case, even if I had known to check the PatchTST page on Hugging Face, I would have seen that the default num_input_channels is already set to 1. So even if I had then known to pass this argument to from_pretrained, why would I need to specify the default value again?
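One likely explanation, offered as a sketch rather than a description of the library's internals: the documented default of 1 applies only when a config is built from scratch. With from_pretrained, the config saved alongside the checkpoint is loaded first, and the error message in this thread suggests that saved config has 7 input channels. A keyword argument passed to from_pretrained then overrides the saved value. The function and dictionaries below are hypothetical stand-ins that only illustrate this order of precedence; they are not transformers code.

```python
# Hypothetical stand-ins illustrating precedence only; not the transformers API.
CLASS_DEFAULTS = {"num_input_channels": 1}     # documented default for a fresh config
CHECKPOINT_CONFIG = {"num_input_channels": 7}  # value implied by the error message


def resolve_config(class_defaults, checkpoint_config, **overrides):
    """Merge settings so that later sources win:
    class defaults < saved checkpoint config < explicit keyword overrides."""
    resolved = dict(class_defaults)
    resolved.update(checkpoint_config)
    resolved.update(overrides)
    return resolved


# No override: the checkpoint's saved value replaces the class default.
without_override = resolve_config(CLASS_DEFAULTS, CHECKPOINT_CONFIG)

# Explicit override, analogous to passing num_input_channels=1 to from_pretrained.
with_override = resolve_config(CLASS_DEFAULTS, CHECKPOINT_CONFIG, num_input_channels=1)

print(without_override["num_input_channels"], with_override["num_input_channels"])
```

Under this reading, passing num_input_channels=1 does not restate a default; it replaces the value stored with the pretrained checkpoint.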
I'm following the test_forecasting_pipeline_forecasts function here to run a prediction pipeline on my dataset. On the last line, I get the following error:
ValueError: The defined number of input channels (7) in the config has to be the same as the number of channels in the batch input (1)
I'm unable to share my dataset, but it's univariate with 2 columns, ['Timestamp', 'y'], at 15-minute intervals.