unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift. #1121

Closed gdevos010 closed 1 year ago

gdevos010 commented 2 years ago

I've seen pretty good results using RevIN with larger datasets. I'm also waiting on a license.

dennisbader commented 1 year ago

Reopening as discussed here.

alexcolpitts96 commented 1 year ago

@gdevos010 I will move it to PLForecastingModule as you suggested to make it easier to work with.

@dennisbader I already have an initial proof-of-concept working for probabilistic forecasts. I still need to add tests for it, but it should work properly.

Edit: I may actually want to hold firm on keeping RINorm in layer_norm_variants.py, unless there is some design pattern I am not aware of that would make it easier to implement across all models. The authors themselves call it a layer normalization technique:

RevIN is symmetrically structured to return the original distribution information to the model output by scaling and shifting the output in the denormalization layer in an amount equivalent to the shifting and scaling of the input data in the normalization layer.
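The normalize/denormalize symmetry in the quote can be illustrated with a minimal sketch in plain Python (no PyTorch). It assumes a single univariate input window and fixed affine parameters `gamma` and `beta`, which are learnable in RevIN. `ReversibleInstanceNorm` is a hypothetical name; this is not darts' `RINorm`.

```python
import math
from statistics import fmean, pvariance


class ReversibleInstanceNorm:
    """Minimal RevIN-style sketch for one univariate instance (list of floats).

    Normalizes with the instance's own mean/std, applies an affine map
    (gamma, beta), and inverts both steps on the model output.
    """

    def __init__(self, gamma: float = 1.0, beta: float = 0.0, eps: float = 1e-5):
        self.gamma = gamma  # learnable scale in RevIN; fixed here for illustration
        self.beta = beta    # learnable shift in RevIN; fixed here for illustration
        self.eps = eps      # keeps the std strictly positive
        self.mean = None
        self.std = None

    def normalize(self, x):
        # Store per-instance statistics computed over the input window.
        self.mean = fmean(x)
        self.std = math.sqrt(pvariance(x) + self.eps)
        return [self.gamma * (v - self.mean) / self.std + self.beta for v in x]

    def denormalize(self, y):
        # Undo the affine map, then restore the stored scale and location.
        return [((v - self.beta) / self.gamma) * self.std + self.mean for v in y]


revin = ReversibleInstanceNorm(gamma=2.0, beta=0.5)
window = [10.0, 12.0, 14.0, 16.0]
z = revin.normalize(window)

# Toy "model" operating in normalized space: repeat the last normalized value.
forecast_norm = [z[-1]] * 2
forecast = revin.denormalize(forecast_norm)
```

Because the denormalization reuses the statistics stored during normalization, the model itself only ever sees inputs with roughly fixed location and scale, while its outputs are mapped back to the original level of each instance.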

Adding it to PLForecastingModule isn't going to be a silver bullet. When PLForecastingModule.__init__ is called, the model hasn't been initialized yet and is still missing key dimensional parameters; _create_model isn't called until much later. While it would be more work, I think it is best to implement it on a model-by-model basis so that we aren't hiding anything outside of the models.

I will add it to TiDE and see if a better way to integrate it into all models pops up.

dennisbader commented 1 year ago

Sounds great @alexcolpitts96. I assume the current implementation of TiDE is the proof-of-concept?

PLForecastingModule.__init__ is actually called when calling _create_model. So we could add it there (at that time we should also know the input_dim required to instantiate the norm).
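A rough sketch of that suggestion, assuming the norm is built inside the module's `__init__` because that constructor only runs once the data dimensions are known. All names here (`BaseForecastingModule`, `InstanceNorm`, `create_model`, `use_reversible_instance_norm`, `_forward_core`) are hypothetical stand-ins for darts' `PLForecastingModule` and `_create_model` and are not meant to mirror their actual signatures. Plain Python replaces PyTorch to keep it self-contained, and the affine terms are omitted for brevity.

```python
import math
from statistics import fmean, pvariance


class InstanceNorm:
    """Per-channel reversible normalization over the time axis (no affine terms)."""

    def __init__(self, input_dim: int, eps: float = 1e-5):
        self.input_dim = input_dim  # number of channels; needed at construction time
        self.eps = eps
        self.means, self.stds = [], []

    def normalize(self, x):
        # x: list of time steps, each a list of `input_dim` channel values.
        cols = list(zip(*x))
        if len(cols) != self.input_dim:
            raise ValueError("unexpected number of channels")
        self.means = [fmean(c) for c in cols]
        self.stds = [math.sqrt(pvariance(c) + self.eps) for c in cols]
        return [[(v - m) / s for v, m, s in zip(row, self.means, self.stds)] for row in x]

    def denormalize(self, y):
        # Restore each channel's stored scale and location.
        return [[v * s + m for v, m, s in zip(row, self.means, self.stds)] for row in y]


class BaseForecastingModule:
    """Hypothetical base class that owns the optional reversible norm."""

    def __init__(self, input_dim: int, use_reversible_instance_norm: bool = False):
        # input_dim is available here because the module is only
        # instantiated from create_model, after the data has been inspected.
        self.rin = InstanceNorm(input_dim) if use_reversible_instance_norm else None

    def forward(self, x):
        if self.rin is not None:
            x = self.rin.normalize(x)
        y = self._forward_core(x)
        return self.rin.denormalize(y) if self.rin is not None else y

    def _forward_core(self, x):
        raise NotImplementedError


class LastValueModule(BaseForecastingModule):
    """Toy model: repeat the last observed time step `horizon` times."""

    def __init__(self, input_dim: int, horizon: int, **kwargs):
        super().__init__(input_dim, **kwargs)
        self.horizon = horizon

    def _forward_core(self, x):
        return [list(x[-1]) for _ in range(self.horizon)]


def create_model(train_series, horizon, use_reversible_instance_norm=True):
    # Hypothetical stand-in for _create_model: dimensions are inferred
    # from the data here, then passed to the module's __init__.
    input_dim = len(train_series[0])
    return LastValueModule(
        input_dim, horizon, use_reversible_instance_norm=use_reversible_instance_norm
    )


series = [[1.0, 100.0], [2.0, 110.0], [3.0, 120.0]]
module = create_model(series, horizon=2)
forecast = module.forward(series)
```

Keeping the normalize/denormalize wrapping in the base class means each concrete model only implements `_forward_core`; the trade-off is that part of each model's behavior then lives outside the model itself.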