alkaline-ml / pmdarima

A statistical library designed to fill the void in Python's time series analysis capabilities, including the equivalent of R's auto.arima function.
https://www.alkaline-ml.com/pmdarima
MIT License

Model history is kept during auto_arima training, unnecessarily causing OOMs #574

Open tRosenflanz opened 8 months ago

tRosenflanz commented 8 months ago

Describe the bug

auto_arima keeps every fitted model in memory during training/solving, even when return_valid_fits is set to False. This causes memory consumption to grow with each candidate fit. Unless return_valid_fits is True, only the best model needs to be kept while solving and worse models can be discarded; doing so drastically reduces the memory footprint, which is especially useful when training many models in parallel.

To Reproduce

Train any auto_arima with a search space large enough to require many candidate fits and observe the process's memory growth.
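
For example, a minimal reproduction sketch (the synthetic series and the search-space bounds here are arbitrary illustrations, not taken from the original report):

import numpy as np
import pmdarima as pm

# Synthetic series long enough that each candidate SARIMAX fit holds real memory
y = np.random.RandomState(0).normal(size=5000).cumsum()

# A wide search space forces many candidate fits; every fitted model is retained
# internally even though return_valid_fits=False, so resident memory keeps growing
model = pm.auto_arima(
    y,
    start_p=0, max_p=5,
    start_q=0, max_q=5,
    seasonal=False,
    stepwise=True,
    return_valid_fits=False,
    trace=True,
)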

Versions

System:
    python: 3.8.0 | packaged by conda-forge | (default, Nov 22 2019, 19:11:38)  [GCC 7.3.0]
executable: /home/ubuntu/anaconda3/envs/new_torch/bin/python
   machine: Linux-5.4.0-1063-aws-x86_64-with-glibc2.10

Python dependencies:
 setuptools: 67.6.1
        pip: 23.0.1
    sklearn: 1.3.0
statsmodels: 0.14.0
      numpy: 1.24.2
      scipy: 1.10.1
     Cython: 3.0.8
     pandas: 2.0.3
     joblib: 1.3.2
   pmdarima: 2.0.4
Linux-5.4.0-1063-aws-x86_64-with-glibc2.10
Python 3.8.0 | packaged by conda-forge | (default, Nov 22 2019, 19:11:38) 
[GCC 7.3.0]
pmdarima 2.0.4
NumPy 1.24.4
SciPy 1.10.1
Scikit-Learn 1.3.0

Expected Behavior

Memory consumption of auto_arima stays near constant or grows very slowly.

Actual Behavior

Memory consumption grows quickly and substantially, eventually causing OOMs.

Additional Context

This can be fixed by only saving worse models conditionally and discarding the previous best fit once a better one is found. Monkey-patching this locally drastically reduces memory consumption.

e.g. for the stepwise solver, adding a self.return_all_fits attribute and modifying the _do_fit method:

def _do_fit(self, order, seasonal_order, constant=None):
    """Do a fit and determine whether the model is better"""
    if not self.seasonal:
        seasonal_order = (0, 0, 0, 0)
    seasonal_order = sm_compat.check_seasonal_order(seasonal_order)

    # we might be fitting without a constant
    if constant is None:
        constant = self.with_intercept

    if (order, seasonal_order, constant) not in self.ic_dict:

        # increment the number of fits
        self.k += 1

        fit, fit_time, new_ic = self._fit_arima(
            order=order,
            seasonal_order=seasonal_order,
            with_intercept=constant)

        # use the orders as a key to be hashed for the dictionary
        # (pointing to fit); only retain the fit object itself when
        # all fits were requested
        if self.return_all_fits:
            self.results_dict[(order, seasonal_order, constant)] = fit

        # cache these so we can look up the best model's IC downstream
        self.ic_dict[(order, seasonal_order, constant)] = new_ic
        self.fit_time_dict[(order, seasonal_order, constant)] = fit_time

        # Determine if the new fit is better than the existing fit
        if fit is None or np.isinf(new_ic):
            return False

        # no benchmark model
        if self.bestfit is None:
            self.bestfit = fit
            self.bestfit_key = (order, seasonal_order, constant)
            # always keep the current best fit so the solver can return it,
            # even when worse fits are being discarded
            self.results_dict[self.bestfit_key] = fit

            if self.trace > 1:
                print("First viable model found (%.3f)" % new_ic)
            return True

        # otherwise there's a current best
        current_ic = self.ic_dict[self.bestfit_key]
        if new_ic < current_ic:

            if self.trace > 1:
                print("New best model found (%.3f < %.3f)"
                        % (new_ic, current_ic))
            if not self.return_all_fits:
                # drop the reference to the previous best fit so it can be
                # garbage collected
                self.results_dict[self.bestfit_key] = None
            self.bestfit = fit
            self.bestfit_key = (order, seasonal_order, constant)
            # keep the new best fit (redundant when return_all_fits is True)
            self.results_dict[self.bestfit_key] = fit
            return True

    # we've either seen this model before, or it wasn't an improvement
    return False
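
For completeness, a rough sketch of wiring the patch in. The pmdarima.arima._auto_solvers module path, the _StepwiseFitWrapper class, and the sm_compat import are assumptions about where the stepwise solver lives in pmdarima; return_all_fits is a new attribute introduced by this patch, not an existing library option:

import numpy as np  # used by the patched _do_fit above
from pmdarima.compat import statsmodels as sm_compat  # assumed home of check_seasonal_order
from pmdarima.arima import _auto_solvers  # assumed location of the stepwise solver

# Attach the patched method and the new flag to the stepwise solver class
_auto_solvers._StepwiseFitWrapper._do_fit = _do_fit
_auto_solvers._StepwiseFitWrapper.return_all_fits = False  # keep only the best fit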