This is a large refactoring PR and is open for discussion. Its main goal is to unify the API across model types and to unify loss functions across loss types.
Refactoring:
Fuses BaseWindows, BaseMultivariate and BaseRecurrent into BaseModel, removing the need for separate classes and unifying the model API across model types. Instead, this PR introduces two model attributes, RECURRENT (True/False) and MULTIVARIATE (True/False), yielding four possible model options. We currently have a model for every combination except a recurrent multivariate model (e.g. a multivariate LSTM), but this is now relatively simple to add. In addition, this change allows a model to be recurrent or not, or multivariate or not, on the fly based on the user's input. This also simplifies modelling going forward.
Unifies model API across all models, adding missing input variables to all model types.
Moves loss.domain_map out of the individual models and into BaseModel
Moves RevINMultivariate used by TSMixer, TSMixerx and RMoK to common.modules
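To make the two-attribute design concrete, here is a minimal sketch of how class-level flags could replace the three former base classes. It is an illustration only, not the code in this PR: the `Toy*` classes, the `describe` method and the `recurrent` constructor argument are hypothetical, and only the names BaseModel, RECURRENT and MULTIVARIATE come from the description above.

```python
class BaseModel:
    """Hypothetical stand-in for the unified base class (not the PR's code)."""

    # Class-level flags; subclasses override the ones that apply to them.
    RECURRENT = False
    MULTIVARIATE = False

    def __init__(self, h, input_size):
        self.h = h                    # forecast horizon
        self.input_size = input_size  # length of the input window

    def describe(self):
        kind = "recurrent" if self.RECURRENT else "direct"
        scope = "multivariate" if self.MULTIVARIATE else "univariate"
        return f"{type(self).__name__}: {kind}, {scope}"


class ToyMLP(BaseModel):
    pass  # direct, univariate (the defaults)


class ToyLSTM(BaseModel):
    RECURRENT = True


class ToyMixer(BaseModel):
    MULTIVARIATE = True


class ToyFlexible(BaseModel):
    """A model whose recurrence is chosen per instance from user input."""

    def __init__(self, h, input_size, recurrent=False):
        super().__init__(h, input_size)
        # The instance attribute shadows the class default without changing it.
        self.RECURRENT = recurrent


for model in (ToyMLP(12, 24), ToyLSTM(12, 24), ToyMixer(12, 24)):
    print(model.describe())
```

With this layout, shared code in the base class can branch on the two flags instead of relying on which base class a model inherits from, and a recurrent multivariate model only needs both flags set to True.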
Features:
All losses are now compatible with all types of models (e.g. univariate/multivariate, direct/recurrent); where a combination is not supported, appropriate protection has been added.
DistributionLoss now supports the use of quantile in predict, allowing for easy quantile retrieval for all DistributionLosses.
Mixture losses (GMM, PMM and NBMM) now support learned weights for weighted mixture distribution outputs.
Mixture losses now support the use of quantile in predict, allowing for easy quantile retrieval.
Improved stability of ISQF by adding softplus protection around some parameters instead of using .abs
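As an illustration of what learned mixture weights and quantile retrieval mean for a mixture distribution, the sketch below builds a Gaussian mixture whose weights come from a softmax over unconstrained logits and reads off a quantile by inverting the mixture CDF with bisection. This PR description does not say how the losses compute quantiles internally, so the approach and every function name here are assumptions made for the example.

```python
import math


def softmax(logits):
    """Map unconstrained weight logits to mixture weights that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def normal_cdf(x, mu, sigma):
    """CDF of a normal distribution with mean mu and std sigma."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


def mixture_cdf(x, weights, mus, sigmas):
    """Weighted sum of the component CDFs."""
    return sum(w * normal_cdf(x, m, s) for w, m, s in zip(weights, mus, sigmas))


def mixture_quantile(q, weights, mus, sigmas, tol=1e-9):
    """Invert the mixture CDF by bisection; q must lie strictly in (0, 1)."""
    if not 0.0 < q < 1.0:
        raise ValueError("q must be strictly between 0 and 1")
    # Bracket wide enough for any quantile that is not astronomically extreme.
    lo = min(m - 10.0 * s for m, s in zip(mus, sigmas))
    hi = max(m + 10.0 * s for m, s in zip(mus, sigmas))
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if mixture_cdf(mid, weights, mus, sigmas) < q:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


weights = softmax([0.0, 0.0])  # equal logits give equal weights
median = mixture_quantile(0.5, weights, mus=[-1.0, 1.0], sigmas=[1.0, 1.0])
print(weights, median)
```

Because the mixture CDF is monotone, bisection is enough here; with fixed (equal) weights the same code applies, only the logits stop being learned.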
Bug fixes:
MASE loss now works.
Added various protections around parameter combinations that are invalid (e.g. regarding losses)
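Increased the default degrees of freedom (DoF) of StudentT to 3 to reduce unbounded-variance issues (the variance of a Student's t distribution is finite only for DoF > 2).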
All models are now included in the tests. For most models, the examples were marked eval: false and there were no other tests, so most models were effectively not being tested.
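For readers less familiar with MASE, the sketch below implements its standard textbook definition: the forecast's mean absolute error divided by the in-sample mean absolute error of a seasonal naive forecast. It is not the loss implementation touched by this PR, and the function name and arguments are chosen for the example. It does show why MASE needs more inputs than most losses (the in-sample history and a seasonality), which is one place where guarding against invalid inputs matters.

```python
def mase(y, y_hat, y_insample, seasonality=1):
    """Mean Absolute Scaled Error (standard definition, illustrative only)."""
    if len(y) != len(y_hat):
        raise ValueError("y and y_hat must have the same length")
    if len(y_insample) <= seasonality:
        raise ValueError("y_insample must be longer than the seasonality")

    # Mean absolute error of the forecast over the horizon.
    mae = sum(abs(a - b) for a, b in zip(y, y_hat)) / len(y)

    # In-sample mean absolute error of the seasonal naive forecast.
    naive_errors = [
        abs(y_insample[t] - y_insample[t - seasonality])
        for t in range(seasonality, len(y_insample))
    ]
    scale = sum(naive_errors) / len(naive_errors)
    if scale == 0:
        raise ValueError("in-sample series is constant at this seasonality")

    return mae / scale


history = [1.0, 2.0, 3.0, 4.0, 5.0]  # naive one-step errors are all 1.0
print(mase(y=[6.0, 7.0], y_hat=[5.0, 9.0], y_insample=history))
```

A value below 1 means the forecast beats the in-sample seasonal naive baseline on average; a value above 1 means it does worse.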
Breaking changes:
All recurrent models have been rewritten to remove the space complexity that was quadratic in the sequence dimension. As a result, a recurrent model saved with a previous version cannot be loaded in this version.
Recurrent models now require an input_size to be given.
TCN and DRNN are now windows models, not recurrent models.
Tests:
Added common._model_checks.py, which includes a model testing function.
Todo:
[ ] Benchmark model speed and scaling against the current implementation across a set of datasets.
[x] Make sure the docstrings of all multivariate models are updated to reflect the additional inputs