unit8co / darts

A Python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0
8.05k stars, 878 forks

[FEAT] Add tsmixer-basic #2510

Status: Open. eschibli opened this issue 2 months ago.

eschibli commented 2 months ago

TSMixer was originally reported as two different models: tsmixer-basic (which allows for past covariates and is simply called TSMixer in the paper) and tsmixer-ext, which allows for past, future, and static covariates. All results in the paper except those on the M5 dataset used tsmixer-basic. The Darts implementation is based on tsmixer-ext.

However, tsmixer-ext isn't equivalent to tsmixer-basic when there are no static or future covariates. The key difference is where the temporal projection happens: tsmixer-basic projects to output_chunk_length in the final layer, so the mixing layers encode the historical data while preserving its time dimension, whereas tsmixer-ext projects the historical and static data to output_chunk_length in the first layer. I don't think this is optimal, since it limits the usefulness of the residual connections. In my tests with the original google-research source code, moving the temporal projection step to the top of the model made MAE and MSE about 10% higher on the weather dataset.
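To make the ordering difference concrete, here is a toy, dependency-free sketch. It assumes a heavily simplified model: only time mixing with a residual connection (no feature mixing, normalization, dropout, or covariates), with random untrained weights. All function names are hypothetical and this is not the google-research or Darts code; it only shows where the L-to-H projection sits and therefore how many time steps the residual path carries.

```python
import random


def temporal_linear(x, weight):
    """Linear map along the time axis, applied to each channel independently.

    x is a (T_in, C) list of rows; weight is a (T_out, T_in) list of rows.
    Returns a (T_out, C) list of rows.
    """
    n_in, n_ch = len(x), len(x[0])
    return [
        [sum(row[t] * x[t][c] for t in range(n_in)) for c in range(n_ch)]
        for row in weight
    ]


def residual_time_mixing(x, weight):
    """Toy time-mixing block: x + relu(W @ x). W must be square (T, T) so shapes match."""
    mixed = temporal_linear(x, weight)
    return [
        [x[t][c] + max(0.0, mixed[t][c]) for c in range(len(x[0]))]
        for t in range(len(x))
    ]


def random_matrix(rows, cols, rng):
    """Small random weights standing in for trained parameters."""
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


def project_last(x, n_blocks, horizon, rng):
    """tsmixer-basic ordering: mix over all L input steps, project L -> H at the end."""
    length = len(x)
    for _ in range(n_blocks):
        x = residual_time_mixing(x, random_matrix(length, length, rng))
    return temporal_linear(x, random_matrix(horizon, length, rng))


def project_first(x, n_blocks, horizon, rng):
    """Project-first ordering (as described for tsmixer-ext): L -> H first, then mix over H steps."""
    length = len(x)
    x = temporal_linear(x, random_matrix(horizon, length, rng))
    for _ in range(n_blocks):
        x = residual_time_mixing(x, random_matrix(horizon, horizon, rng))
    return x


rng = random.Random(0)
L, H, C = 24, 6, 3  # input_chunk_length, output_chunk_length, number of channels
series = [[rng.gauss(0.0, 1.0) for _ in range(C)] for _ in range(L)]

out_last = project_last(series, n_blocks=2, horizon=H, rng=rng)
out_first = project_first(series, n_blocks=2, horizon=H, rng=rng)
print(len(out_last), len(out_last[0]))    # 6 3
print(len(out_first), len(out_first[0]))  # 6 3
```

Both orderings end with the same (output_chunk_length, C) shape. The difference is inside: in the project-last version every residual block carries all L historical steps, while in the project-first version the residual path starts from an already compressed H-step representation, which is the concern about residual connections raised in the issue.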

If the maintainers think this would be valuable, I can implement it. The most sensible approach would probably be to add a project_first keyword argument defaulting to True, which keeps the current behavior.

madtoinou commented 2 months ago

Hi @eschibli,

If you can cleanly expose the tsmixer-basic architecture through the TSMixerModel API/constructor, it would certainly be valuable to have a variant of this model that performs better when no future covariates are available, which is a common situation. I would probably name the argument first_layer_projection rather than project_first, but we can discuss that in your PR.

You will also need to add checks in the fit() method so that an error is raised if first_layer_projection=False and future/static covariates are provided.
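A minimal sketch of the kind of fit()-time check described above, assuming it receives the proposed first_layer_projection flag, the future covariates passed to fit(), and the static covariates attached to the target series (in Darts these travel with the TimeSeries). The function name and error message are hypothetical and not part of the Darts API.

```python
def check_covariates_supported(first_layer_projection, future_covariates=None, static_covariates=None):
    """Hypothetical validation for the proposed flag.

    With first_layer_projection=False (the tsmixer-basic ordering), only past
    covariates are supported, so future or static covariates are rejected.
    """
    if first_layer_projection:
        # Project-first (current tsmixer-ext behavior) accepts all covariate types.
        return
    unsupported = []
    if future_covariates is not None:
        unsupported.append("future_covariates")
    if static_covariates is not None:
        unsupported.append("static covariates")
    if unsupported:
        raise ValueError(
            "first_layer_projection=False only supports past covariates; got "
            + " and ".join(unsupported)
            + "."
        )


check_covariates_supported(True, future_covariates=[1.0], static_covariates={"id": 0})  # accepted
check_covariates_supported(False)  # accepted: no future or static covariates
try:
    check_covariates_supported(False, future_covariates=[1.0])
except ValueError as err:
    print(err)  # first_layer_projection=False only supports past covariates; got future_covariates.
```

Raising early in fit() keeps the failure close to the user's call instead of surfacing as a shape mismatch deep inside the network.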