unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0
7.57k stars, 830 forks

Passing kwargs to the underlying fitting function #2438

Open DavidKleindienst opened 1 week ago

DavidKleindienst commented 1 week ago

Is your feature request related to a current problem? Please describe. I wanted to pass some additional arguments to Prophet's fit function (from now on called fit_kwargs), but darts does not currently support this. The only way for a user to achieve it today is to subclass Prophet and override the ._fit method.
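To make the workaround concrete, here is a minimal sketch of the subclass-and-override pattern. It uses toy stand-in classes rather than the real darts or Prophet code: `ToyBackend`, `ToyWrapper`, `PatchedWrapper`, and the `iter` keyword are all hypothetical names chosen for illustration, and the real `._fit` method may have a different signature and do more work.

```python
# Toy stand-ins only: these classes are hypothetical and are NOT darts or Prophet code.
class ToyBackend:
    """Plays the role of the wrapped third-party model."""

    def fit(self, data, **kwargs):
        # Record the keyword arguments so their arrival can be inspected.
        self.fit_kwargs_seen = dict(kwargs)
        self.mean = sum(data) / len(data)
        return self


class ToyWrapper:
    """Plays the role of a wrapper whose fit() cannot forward extra arguments."""

    def __init__(self):
        self.model = ToyBackend()

    def fit(self, series):
        self._fit(series)
        return self

    def _fit(self, series):
        # No extra keyword arguments reach the backend here.
        self.model.fit(series)


class PatchedWrapper(ToyWrapper):
    """Workaround: override the private hook to inject extra fit arguments."""

    def __init__(self, fit_kwargs=None):
        super().__init__()
        self._extra_fit_kwargs = dict(fit_kwargs or {})

    def _fit(self, series):
        self.model.fit(series, **self._extra_fit_kwargs)


plain = ToyWrapper().fit([1.0, 2.0, 3.0])
patched = PatchedWrapper(fit_kwargs={"iter": 500}).fit([1.0, 2.0, 3.0])
print(plain.model.fit_kwargs_seen)    # {}
print(patched.model.fit_kwargs_seen)  # {'iter': 500}
```

The drawback is visible even in the toy version: the subclass depends on a private method, so it can break whenever the wrapper's internals change.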

In general, darts currently handles the pass-through of fit_kwargs in several different ways, depending on the model. Looking through the code, I would summarize the current situation as follows (I hope I haven't missed anything):

Describe proposed solution I propose to unify the behavior (except for TorchForecastingModels, which it makes sense to treat differently) so that fit_kwargs can be passed through the .fit function, i.e.:

I think passing these arguments through the .fit function rather than the constructor is better for two reasons: 1) Models often also support kwargs that are passed to the underlying model's constructor, making a distinction between constructor_kwargs and fit_kwargs necessary. If both went through the constructor, at least one of them would have to be passed as a dict, which feels unintuitive. 2) I think the .fit method is the more obvious place for users to look for a way to pass such kwargs.
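The separation argued for above can be sketched with a toy wrapper. This is only an illustration of the general pattern, not the darts implementation: `ToyEstimator`, `ToyForecaster`, and the `seasonality` and `max_iter` keywords are hypothetical names. Constructor kwargs go to the underlying model's constructor, while kwargs given to `.fit` are forwarded to the underlying fit call.

```python
# Hypothetical toy classes illustrating the pattern; not darts code.
class ToyEstimator:
    """Stands in for the underlying third-party model."""

    def __init__(self, **constructor_kwargs):
        self.constructor_kwargs = constructor_kwargs
        self.fit_kwargs = None

    def fit(self, data, **fit_kwargs):
        self.fit_kwargs = fit_kwargs
        return self


class ToyForecaster:
    """Wrapper that keeps the two kinds of kwargs apart."""

    def __init__(self, **constructor_kwargs):
        # Stored now, handed to the underlying constructor when fitting.
        self._constructor_kwargs = constructor_kwargs
        self.model = None

    def fit(self, series, **fit_kwargs):
        self.model = ToyEstimator(**self._constructor_kwargs)
        # Extra keyword arguments are forwarded unchanged to the underlying fit.
        self.model.fit(series, **fit_kwargs)
        return self


forecaster = ToyForecaster(seasonality="weekly").fit([1, 2, 3], max_iter=200)
print(forecaster.model.constructor_kwargs)  # {'seasonality': 'weekly'}
print(forecaster.model.fit_kwargs)          # {'max_iter': 200}
```

With this layout neither group of arguments needs to be wrapped in a dict, because each has its own call site.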

Describe potential alternatives

Additional context I'm happy to prepare a PR for this issue once it is decided which of these solutions should be implemented.

dennisbader commented 4 days ago

Hi @DavidKleindienst, and thanks for this issue. It's indeed true that we should unify this more. I like your proposed solution of opening up the fit kwargs in fit() (and removing them from the ExponentialSmoothing constructor). As you say, this concerns all models except TorchForecastingModel.

You can go ahead with the PR 🚀 :)