SalesforceAIResearch / uni2ts

[ICML2024] Unified Training of Universal Time Series Forecasting Transformers
Apache License 2.0

There are some questions about fine-tuning #12

Closed rainbownmm closed 2 months ago

rainbownmm commented 3 months ago

Thanks for the excellent code and model. I would like to ask some questions. As shown in the picture, the model has indeed improved a lot after fine-tuning, but some forecasts now look like the 5th and 8th plots. Before fine-tuning, the model was less accurate, but this did not happen. I think it is overfitting, but I have lowered the learning rate to the order of 1e-9 and it still happens. Do you have any good suggestions? [image]

gorold commented 3 months ago

Hi, having a smaller learning rate doesn't really help with reducing overfitting. You could try reducing the maximum number of epochs to a small number, perhaps 1 to 5, or increasing weight decay or dropout. You may also want to fix the patch size.

rainbownmm commented 2 months ago

Sorry to bother you again. I don't understand how to evaluate MSE and similar metrics in the code. For example, in the following code, how can I modify it to obtain the model's predicted values and the actual values, and compute the MSE and other metrics?

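from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from gluonts import maybe
from gluonts.model.forecast import Forecast


def plot_single(
    inp: dict,
    label: dict,
    forecast: Forecast,
    context_length: int,
    intervals: tuple[float, ...] = (0.5, 0.9),
    ax: Optional[plt.Axes] = None,
    dim: Optional[int] = None,
    name: Optional[str] = None,
    show_label: bool = False,
):
    # Fall back to the current axes when none is passed in.
    ax = maybe.unwrap_or_else(ax, plt.gca)

    # Join the context window and the ground-truth horizon into one series.
    target = np.concatenate([inp["target"], label["target"]], axis=-1)
    start = inp["start"]
    if dim is not None:
        # Select a single variate of a multivariate series.
        target = target[dim]
        forecast = forecast.copy_dim(dim)

    index = pd.period_range(start, periods=len(target), freq=start.freq)
    # Plot only the last context_length + prediction_length observations.
    ax.plot(
        index.to_timestamp()[-context_length - forecast.prediction_length :],
        target[-context_length - forecast.prediction_length :],
        label="target",
        color="black",
    )
    # Overlay the forecast with its prediction intervals.
    forecast.plot(
        intervals=intervals,
        ax=ax,
        color="blue",
        name=name,
        show_label=show_label,
    )
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
    ax.legend(loc="lower left")

rainbownmm commented 2 months ago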

I want to evaluate each prediction in detail. This is the code; how can I modify it? Thanks in advance for your answer.

gorold commented 2 months ago

Hey, please check out the evaluation scripts. This function is just to plot the forecasts. Alternatively, you can take the forecast object which contains the predictions to write your own evaluation script.
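
As a rough illustration of the second option, here is a minimal sketch of a hand-written evaluation, not the repository's evaluation script. It assumes the predictions are available as an array of sample paths with shape (num_samples, prediction_length) and the actual values as an array of shape (prediction_length,). The function name `point_metrics` and the toy numbers are hypothetical.

```python
import numpy as np


def point_metrics(samples: np.ndarray, actual: np.ndarray) -> dict[str, float]:
    """Compute simple point-forecast errors from sample paths.

    samples: array of shape (num_samples, prediction_length)
    actual:  array of shape (prediction_length,)
    """
    samples = np.asarray(samples, dtype=float)
    actual = np.asarray(actual, dtype=float)

    # Collapse the sample paths into point forecasts.
    mean_pred = samples.mean(axis=0)
    median_pred = np.median(samples, axis=0)

    # MSE/RMSE are computed against the mean, MAE against the median.
    mse = float(np.mean((mean_pred - actual) ** 2))
    return {
        "mse": mse,
        "rmse": float(np.sqrt(mse)),
        "mae": float(np.mean(np.abs(median_pred - actual))),
    }


# Toy data standing in for the model's sample paths and the true values.
samples = np.array([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
actual = np.array([2.0, 2.0, 6.0])
metrics = point_metrics(samples, actual)
print(metrics)
```

If the forecast is a GluonTS sample-based forecast for a univariate series, the call would presumably look like `point_metrics(forecast.samples, label["target"])`, assuming `label["target"]` holds only the prediction window. For multivariate series the arrays carry an extra dimension and would need to be sliced per variate first.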