unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

Add TSMixer model #1807

alexcolpitts96 opened this issue 1 year ago (status: open)

alexcolpitts96 commented 1 year ago

I recently found TSMixer (http://arxiv.org/abs/2303.06053).

It is very similar to TiDE (#1726) but with a few tweaks.

It should be pretty straightforward to implement based on the implementation of TiDE (#1727).

I will try to get started on it in the next few days.

alexcolpitts96 commented 1 year ago

Google Research implementation: https://github.com/google-research/google-research/blob/master/tsmixer/tsmixer_basic/models/tsmixer.py

The paper is light on details; however, the source code clears things up.
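For readers comparing against the Google code, here is a minimal NumPy sketch of the basic TSMixer forward pass, assuming the structure of the linked `tsmixer.py`: each mixer layer applies a time-mixing MLP (shared across features) and then a feature-mixing MLP (shared across time steps), both with residual connections, and a final temporal projection maps `input_chunk_length` to `output_chunk_length`. Dropout is omitted, and the parameter-free normalization over the time and feature axes is a simplified stand-in for the normalization layer in the reference code. `init_params` and the other helper names are hypothetical, not part of Darts.

```python
import numpy as np


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


def norm_time_feature(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Normalize each sample over its (time, feature) axes, no learned scale/shift.
    mean = x.mean(axis=(-2, -1), keepdims=True)
    var = x.var(axis=(-2, -1), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def init_params(rng, seq_len, n_features, ff_dim, pred_len, n_blocks):
    """Random weights for a toy forward pass (hypothetical helper)."""
    blocks = []
    for _ in range(n_blocks):
        blocks.append({
            "w_time": rng.normal(scale=0.1, size=(seq_len, seq_len)),
            "b_time": np.zeros(seq_len),
            "w_ff1": rng.normal(scale=0.1, size=(n_features, ff_dim)),
            "b_ff1": np.zeros(ff_dim),
            "w_ff2": rng.normal(scale=0.1, size=(ff_dim, n_features)),
            "b_ff2": np.zeros(n_features),
        })
    head = {
        "w_out": rng.normal(scale=0.1, size=(seq_len, pred_len)),
        "b_out": np.zeros(pred_len),
    }
    return blocks, head


def mixer_block(x: np.ndarray, p: dict) -> np.ndarray:
    """One mixer layer on x of shape (batch, seq_len, n_features)."""
    # Time mixing: dense layer along the time axis, same weights for every feature.
    h = np.swapaxes(norm_time_feature(x), 1, 2)  # (batch, n_features, seq_len)
    h = relu(h @ p["w_time"] + p["b_time"])
    x = x + np.swapaxes(h, 1, 2)  # residual connection
    # Feature mixing: two-layer MLP along the feature axis, same weights per step.
    h = norm_time_feature(x)
    h = relu(h @ p["w_ff1"] + p["b_ff1"])
    h = h @ p["w_ff2"] + p["b_ff2"]
    return x + h  # residual connection


def tsmixer_forward(x: np.ndarray, blocks: list, head: dict) -> np.ndarray:
    for p in blocks:
        x = mixer_block(x, p)
    # Temporal projection: seq_len -> pred_len, applied to every feature.
    x = np.swapaxes(x, 1, 2)  # (batch, n_features, seq_len)
    x = x @ head["w_out"] + head["b_out"]  # (batch, n_features, pred_len)
    return np.swapaxes(x, 1, 2)  # (batch, pred_len, n_features)
```

With `input_chunk_length=240` and `output_chunk_length=37`, a batch of shape `(8, 240, 3)` comes out as `(8, 37, 3)`. Unlike TiDE, there is no encoder/decoder split: the whole model is stacked mixing layers plus one linear projection over time.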

joshua-xia commented 12 months ago

@alexcolpitts96 did you implement tsmixer_extended from the paper? It seems to support past, static, and future covariate features.
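As a rough illustration of why the extended variant can handle all three covariate types, here is a simplified NumPy sketch of one way to bring them onto a common time axis before mixing. It assumes the history (target plus past covariates) is projected from `input_len` to `pred_len` steps, future covariates are already known over the horizon, and static covariates are repeated at every step. This is a simplification under those assumptions, not the paper's exact alignment step (which also mixes and conditions on these features inside the layers); `align_inputs` is a hypothetical helper.

```python
import numpy as np


def align_inputs(past_target, past_covs, future_covs, static_covs, w_time):
    """Build one horizon-aligned feature tensor (hypothetical helper).

    past_target: (batch, input_len, n_target)
    past_covs:   (batch, input_len, n_past)
    future_covs: (batch, pred_len, n_future), known over the forecast horizon
    static_covs: (batch, n_static), constant per series
    w_time:      (input_len, pred_len), temporal projection weights
    """
    # Historical part: stack target and past covariates along the feature axis.
    hist = np.concatenate([past_target, past_covs], axis=-1)
    # Project the history from input_len to pred_len time steps.
    hist = np.swapaxes(np.swapaxes(hist, 1, 2) @ w_time, 1, 2)
    # Repeat static covariates at every horizon step.
    pred_len = future_covs.shape[1]
    static = np.repeat(static_covs[:, None, :], pred_len, axis=1)
    # All three sources now share the (batch, pred_len, features) layout.
    return np.concatenate([hist, future_covs, static], axis=-1)
```

The output can then be fed to the same time- and feature-mixing layers as the basic model, with the horizon length as the time dimension.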

alexcolpitts96 commented 12 months ago

I have run into a few issues with the implementation and had some other PRs that I needed to clean up.

I managed to implement reversible instance normalization, but there is a bug in the tests that only shows up during the build process on GitHub.
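For context, reversible instance normalization (RevIN) normalizes each input series by its own mean and standard deviation over the time axis, runs the model on the normalized values, and then maps the forecast back with the same statistics. The NumPy sketch below assumes a `(batch, time, features)` layout and keeps the affine parameters fixed, whereas in a real model `gamma` and `beta` would be learned; the `RevIN` class here is illustrative, not the code from the branch.

```python
import numpy as np


class RevIN:
    """Reversible instance normalization (minimal NumPy sketch)."""

    def __init__(self, n_features: int, eps: float = 1e-5):
        self.eps = eps
        # Affine parameters; learnable in a real model, fixed in this sketch.
        self.gamma = np.ones(n_features)
        self.beta = np.zeros(n_features)

    def normalize(self, x: np.ndarray) -> np.ndarray:
        # x: (batch, time, features); statistics per sample and per feature.
        self.mean = x.mean(axis=1, keepdims=True)
        self.std = np.sqrt(x.var(axis=1, keepdims=True) + self.eps)
        return (x - self.mean) / self.std * self.gamma + self.beta

    def denormalize(self, y: np.ndarray) -> np.ndarray:
        # y: (batch, pred_len, features); undo the affine step, then the scaling.
        y = (y - self.beta) / (self.gamma + self.eps ** 2)
        return y * self.std + self.mean
```

Because the statistics are stored during `normalize`, the same instance must be used for the matching `denormalize` call. That per-call state is a common source of subtle bugs when the layer is wired into a training loop.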

The rest of the model is pretty straightforward; I just need to find the time. I just started a new job, so I have been a little short on time lately.

meteoDaniel commented 9 months ago

Recently Google published a paper and an article on TSMixer: https://blog.research.google/2023/09/tsmixer-all-mlp-architecture-for-time.html

@alexcolpitts96 have you started on a PyTorch implementation that could fit into Darts?

alexcolpitts96 commented 9 months ago

I started working on it roughly two months ago. I have been busy wrapping up school and starting a new job. I should have some time to clean it up over the next few weeks.

I managed to get the skeleton written, but I still need to add covariates and probabilistic forecasting.

https://github.com/alexcolpitts96/darts/blob/tsmixer/darts/models/forecasting/tsmixer_model.py

meteoDaniel commented 9 months ago

From my point of view, that looks good. Why do you think you need probabilistic forecasting? Does TSMixer provide it natively? In TFT, probabilistic forecasting is a result of the quantile loss function. Maybe I am wrong, but if you want to add this feature to TSMixer, I think you just need to train it with QuantileLoss.
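To make the quantile-loss idea concrete, here is a small NumPy sketch of the pinball loss, assuming the model emits one output per requested quantile. Training a point-forecast architecture such as TSMixer against this loss is enough to make it output quantile forecasts; the function below is illustrative and is not Darts' internal implementation.

```python
import numpy as np


def quantile_loss(y_true, y_pred, quantiles):
    """Pinball loss averaged over samples and quantiles.

    y_true: (n,) observed values
    y_pred: (n, len(quantiles)), one prediction per quantile
    """
    q = np.asarray(quantiles, dtype=float)[None, :]
    err = np.asarray(y_true, dtype=float)[:, None] - np.asarray(y_pred, dtype=float)
    # Under-prediction (err > 0) is weighted by q, over-prediction by (1 - q).
    return np.mean(np.maximum(q * err, (q - 1) * err))
```

For `q = 0.5` the loss reduces to half the mean absolute error, so the median output behaves like an MAE-trained point forecast. In Darts, torch-based models usually get this through `likelihood=QuantileRegression(quantiles=[...])` from `darts.utils.likelihood_models`, which is also what `TFTModel` uses by default.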

thijsjls commented 7 months ago

@alexcolpitts96 Did you have any time to work on this further? I would be interested in using this model, and I am also open to contributing.

StatMixedML commented 5 months ago

IBM has released its PatchTSMixer model on Hugging Face. Maybe this helps get it into Darts soon.

candalfigomoro commented 4 months ago

Note that there are apparently two different models named "TSMixer":

leoniewgnr commented 4 months ago

@alexcolpitts96 @meteoDaniel @thijsjls Hi everyone, I've looked into your code, @alexcolpitts96, and it looks really good! I tried it with the following code, which uses lists of time series, covariates, and encoders:

```python
from darts.dataprocessing.transformers import Scaler
from darts.models import TSMixerModel

# ts_train_scaled_list, ts_val_scaled_list and cov_list are lists of
# TimeSeries prepared beforehand; n_epochs is defined elsewhere.
model_params = {
    "input_chunk_length": 240,  # hist_len
    # not tuned
    "use_static_covariates": False,
    "output_chunk_length": 37,  # pred_len
    "n_epochs": n_epochs,
}

model = TSMixerModel(
    **model_params,
    pl_trainer_kwargs={
        "accelerator": "auto",
        "devices": "auto",
    },
    add_encoders={
        "datetime_attribute": {
            "past": ["hour", "day_of_week", "month"],
            "future": ["hour", "day_of_week", "month"],
        },
        # scale the encoder-generated covariates
        "transformer": Scaler(),
    },
    model_name="tsmixer",
    save_checkpoints=True,
    force_reset=True,
)

model.fit(
    ts_train_scaled_list,
    future_covariates=cov_list,
    val_series=ts_val_scaled_list,
    val_future_covariates=cov_list,
    verbose=False,
)

# load the best model on the validation set to avoid overfitting
model = TSMixerModel.load_from_checkpoint(model_name="tsmixer", best=True)
```

and it works great! The only thing I had to change in your code was an old scikit-learn import that has been removed from the current Darts version, so just merging in the newest Darts version should resolve it.

I would really appreciate it if you could move forward and push this, as I would really like to use it and the results from TSMixer are very good. Thank you very much! I'm also very happy to help!

cristof-r commented 3 months ago

I opened a PR, as the work above seems to have gone stale. Any feedback is welcome!