ibm-granite / granite-tsfm

Foundation Models for Time Series
Apache License 2.0

Support prediction_filter_length when plotting #137

Open lycheesodaa opened 6 days ago

lycheesodaa commented 6 days ago

Thanks for the amazing work so far.

It's a minor issue, but plot_predictions() raises an error after I pass prediction_filter_length=48 when instantiating the model. I'm following the example notebook and using exactly the same zero-shot forecasting code, yet I get this:

Traceback (most recent call last):
  File "home/granite-tsfm/run_demand.py", line 86, in <module>
    zeroshot_eval(
  File "home/granite-tsfm/run_demand.py", line 78, in zeroshot_eval
    plot_predictions(
  File "home/granite-tsfm/tsfm_public/toolkit/visualization.py", line 379, in plot_predictions
    axs[i].plot(ts_y, y, label="True", linestyle="-", color="blue", linewidth=2)
  File "home/granite-tsfm/venv/lib/python3.11/site-packages/matplotlib/axes/_axes.py", line 1779, in plot
    lines = [*self._get_lines(self, *args, data=data, **kwargs)]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "home/granite-tsfm/venv/lib/python3.11/site-packages/matplotlib/axes/_base.py", line 296, in __call__
    yield from self._plot_args(
               ^^^^^^^^^^^^^^^^
  File "home/granite-tsfm/venv/lib/python3.11/site-packages/matplotlib/axes/_base.py", line 486, in _plot_args
    raise ValueError(f"x and y must have same first dimension, but "
ValueError: x and y must have same first dimension, but have shapes (144,) and (192,)
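The mismatch can be reproduced with plain NumPy. This is a sketch, not the library code: it assumes plot_context=96 and an unfiltered future_values window of 96 steps, values inferred from the reported shapes (144 = 96 + 48, 192 = 96 + 96) rather than stated in the issue. The 512-step context array is likewise hypothetical.

```python
import numpy as np

# Assumed values, inferred from the shapes in the traceback (144 = 96 + 48, 192 = 96 + 96).
plot_context = 96        # context steps drawn before the forecast
prediction_length = 48   # prediction_filter_length passed to the model
full_future_length = 96  # unfiltered length of future_values in the dataset

rng = np.random.default_rng(0)
past_values = rng.random((512, 1))  # hypothetical context window: (time, channels)
future_values = rng.random((full_future_length, 1))
channel = 0

# The x-axis only spans the filtered horizon...
ts_y = np.arange(plot_context + prediction_length)
# ...but the "True" series is built from the whole future window.
x = past_values[-plot_context:, channel]
y = np.concatenate((x, future_values[:, channel]), axis=0)

print(ts_y.shape, y.shape)  # (144,) (192,)
```

Passing these two arrays to matplotlib's plot() raises the same ValueError, since x and y must share their first dimension.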

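It seems plot_predictions() does not yet support a reduced forecast horizon: the x-axis only covers prediction_length future steps, while the true values are taken from the full future_values window. Editing the following line seems to fix the issue: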

        else:
            batch = dset[index]
            ts_y_hat = np.arange(plot_context, plot_context + prediction_length)
            y_hat = predictions_subset[i]

            ts_y = np.arange(plot_context + prediction_length)
            y = batch["future_values"][:prediction_length, channel].squeeze().numpy()  # <- edited line 369: truncate to prediction_length
            x = batch["past_values"][-plot_context:, channel].squeeze().numpy()
            y = np.concatenate((x, y), axis=0)
            border = plot_context
            plot_title = f"Example {indices[i]}"
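A self-contained sketch of the edited branch, factored into a hypothetical helper (build_true_series is not part of tsfm_public). It uses NumPy arrays in place of the dataset's tensors and slices future_values to prediction_length, so the x and y lengths agree.

```python
import numpy as np


def build_true_series(past_values, future_values, plot_context, prediction_length, channel):
    """Return (ts_y, y) for the "True" line, truncating the future window.

    Hypothetical helper mirroring the edited branch of plot_predictions();
    past_values and future_values are (time, channels) arrays.
    """
    ts_y = np.arange(plot_context + prediction_length)
    # Keep only the filtered horizon so it matches the model's reduced output.
    y_future = future_values[:prediction_length, channel]
    x = past_values[-plot_context:, channel]
    y = np.concatenate((x, y_future), axis=0)
    return ts_y, y


rng = np.random.default_rng(0)
past = rng.random((512, 3))    # hypothetical context window
future = rng.random((96, 3))   # hypothetical unfiltered future window
ts_y, y = build_true_series(past, future, plot_context=96, prediction_length=48, channel=1)
print(ts_y.shape, y.shape)  # (144,) (144,)
```

Assuming prediction_length equals the full future window when no filter is set, the [:prediction_length] slice is then a no-op, so the unfiltered case should behave as before.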

I have only tested the case where dset and model are passed to the function; the other cases may already work fine.


Edit: passing a string for the channel argument also doesn't seem to work in this case, but I presume that's a feature that hasn't been implemented yet?
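Because channel is used directly as a tensor index (future_values[:prediction_length, channel]), a name would have to be resolved to an integer position first. The sketch below is hypothetical: resolve_channel, the target_columns list, and the column names are assumptions for illustration, not existing tsfm_public parameters.

```python
from typing import Sequence, Union


def resolve_channel(channel: Union[int, str], target_columns: Sequence[str]) -> int:
    """Map a channel given as a column name or an index to an integer index.

    Hypothetical helper; only non-negative integer indices are accepted here.
    """
    columns = list(target_columns)
    if isinstance(channel, str):
        if channel not in columns:
            raise ValueError(f"Unknown channel {channel!r}; expected one of {columns}")
        return columns.index(channel)
    if not 0 <= channel < len(columns):
        raise IndexError(f"Channel index {channel} out of range for {len(columns)} columns")
    return channel


columns = ["demand", "temperature", "humidity"]  # hypothetical column names
print(resolve_channel("temperature", columns))  # 1
print(resolve_channel(2, columns))              # 2
```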

wgifford commented 4 days ago

@lycheesodaa Thanks for the issue -- we will look into it.