alkaline-ml / pmdarima

A statistical library designed to fill the void in Python's time series analysis capabilities, including the equivalent of R's auto.arima function.
https://www.alkaline-ml.com/pmdarima
MIT License

`auto_arima` fails to fit 414 seasonal time series in parallel #483

Status: Open. Opened by AzulGarza 2 years ago.

I'm trying to fit `auto_arima` to each series in the M4-Hourly dataset. You can download it here. I'm using the following code:

import time
from functools import partial
from multiprocessing import Pool, cpu_count

import numpy as np
import pandas as pd
from pmdarima.arima import auto_arima

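def fit_and_predict(index, ts, horizon, freq, seasonality):
    # Fit a seasonal auto_arima to one series and forecast `horizon` steps.
    # `freq` is accepted but not used.
    x = ts['y'].values
    try:
        mod = auto_arima(x, m=seasonality,
                         with_intercept=False,
                         error_action='ignore')
        forecast = mod.predict(horizon)
    except Exception:
        # Fall back to a naive forecast: repeat the last observation.
        forecast = np.repeat(x[-1], horizon)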

    forecast = pd.DataFrame({
        'ds': np.arange(ts['ds'].max() + 1, ts['ds'].max() + horizon + 1),
        'ypred': forecast
    })
    forecast['unique_id'] = index

    return forecast[['unique_id', 'ds', 'ypred']]

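def main():
    train = pd.read_csv('M4-Hourly.csv')
    horizon = 48
    freq = 'H'
    seasonality = 24
    # `dataset` and `group` are used in the output paths below but are not
    # defined anywhere else in this snippet; these values are assumed.
    dataset = 'M4'
    group = 'Hourly'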

    partial_fit_and_predict = partial(fit_and_predict, 
                                      horizon=horizon, freq=freq, seasonality=seasonality)
    start = time.time()
    print(f'Parallelism on {cpu_count()} CPU')
    with Pool(cpu_count()) as pool:
        results = pool.starmap(partial_fit_and_predict, train.groupby('unique_id'))
    end = time.time()
    print(end - start)

    forecasts = pd.concat(results)
    forecasts.columns = ['unique_id', 'ds', 'auto_arima_pmdarima']
    forecasts.to_csv(f'data/pmdarima-forecasts-{dataset}-{group}.csv', index=False)

    time_df = pd.DataFrame({'time': [end - start], 'model': ['auto_arima_pmdarima']})
    time_df.to_csv(f'data/pmdarima-time-{dataset}-{group}.csv', index=False)

if __name__ == '__main__':
    main()

But after several hours (around 20), the process gets stuck. I'm using a 96-core AWS instance.
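When a pool of many workers stalls, `pool.starmap` gives no indication of which series is responsible. A minimal diagnostic sketch using only the standard library: `imap_unordered` yields results as they complete, and `IMapIterator.next(timeout=...)` raises `multiprocessing.TimeoutError` if nothing finishes in time, so the ids still pending at that point include the hung series. `fit_one` and `run_with_stall_detection` are hypothetical names, and `fit_one` is only a placeholder for the real `auto_arima` fit.

```python
import multiprocessing as mp
import time


def fit_one(item):
    """Hypothetical stand-in for fit_and_predict; returns (unique_id, result).

    In the real script this would run auto_arima on one series.
    """
    uid, values = item
    return uid, sum(values) / len(values)


def run_with_stall_detection(items, processes=2, stall_timeout=600.0):
    """Map fit_one over (unique_id, values) pairs, logging each finished id.

    If no task finishes within `stall_timeout` seconds, stop waiting and
    return the ids that never produced a result (any hung series are among
    them). Leaving the `with` block terminates the pool, killing stuck workers.
    """
    pending = {uid for uid, _ in items}
    results = {}
    start = time.time()
    with mp.Pool(processes) as pool:
        it = pool.imap_unordered(fit_one, items)
        while pending:
            try:
                uid, res = it.next(timeout=stall_timeout)
            except mp.TimeoutError:
                print(f'No result for {stall_timeout}s; still pending: {sorted(pending)}',
                      flush=True)
                break
            results[uid] = res
            pending.discard(uid)
            print(f'{len(results)}/{len(items)} done at {time.time() - start:.0f}s: {uid}',
                  flush=True)
    return results, pending
```

In the script above, `items` could be built as `[(uid, df['y'].values) for uid, df in train.groupby('unique_id')]`, with the call kept under the `if __name__ == '__main__':` guard. Refitting the reported ids one at a time in a single process would then show whether those series hang on their own or only inside the pool. If a fork-after-threads deadlock is suspected, `mp.get_context('forkserver').Pool(processes)` is one alternative to try in place of `mp.Pool(processes)`; whether it helps in this case is untested.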

This is the conda environment I'm using:

name: arima
channels:
  - conda-forge
  - defaults
  - anaconda
dependencies:
  - python=3.7
  - pip==20.3.3
  - numpy==1.21.4
  - scikit-learn
  - pmdarima
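Since `scikit-learn` and `pmdarima` are unpinned here (and `statsmodels`, whose SARIMAX model pmdarima fits, comes in transitively), the versions that were actually resolved are worth reporting alongside the YAML. A small sketch; the helper name `report_versions` is hypothetical:

```python
import importlib
import platform


def report_versions(names=('numpy', 'scipy', 'statsmodels', 'sklearn', 'pmdarima')):
    """Return {package: version} for the packages relevant to this issue.

    A package that is not installed maps to None; one that fails to import
    for another reason (e.g. a binary incompatibility) maps to an error string.
    """
    versions = {'python': platform.python_version()}
    for name in names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            versions[name] = None
        except Exception as exc:
            versions[name] = f'error: {type(exc).__name__}'
        else:
            versions[name] = getattr(module, '__version__', 'unknown')
    return versions


for name, version in report_versions().items():
    print(f'{name}: {version}')
```

`numpy.show_config()` additionally prints the BLAS/LAPACK libraries numpy was built against, which indicates whether the MKL or the OpenBLAS thread variables apply.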

I tried to restrict the number of threads used by the numerical libraries in each process with:

import os
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

But the same problem arises.
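Two common reasons such settings have no effect, offered as assumptions rather than a diagnosis of this hang: native thread pools generally read these variables when the libraries are first loaded, so setting them after `import numpy` typically does nothing; and conda-forge's numpy is linked against OpenBLAS by default, which honors `OPENBLAS_NUM_THREADS` (falling back to `OMP_NUM_THREADS`) rather than `MKL_NUM_THREADS`. A minimal sketch that sets all of them before the numerical stack is imported; `limit_native_threads` and `THREAD_VARS` are hypothetical names:

```python
import os

# Thread-count variables read by common native backends. Which one matters
# depends on how numpy/scipy were built, so this sets all of them.
THREAD_VARS = (
    "OMP_NUM_THREADS",       # OpenMP runtimes (and OpenBLAS as a fallback)
    "OPENBLAS_NUM_THREADS",  # OpenBLAS, the default BLAS for conda-forge numpy
    "MKL_NUM_THREADS",       # Intel MKL
    "NUMEXPR_NUM_THREADS",   # numexpr
)


def limit_native_threads(n=1):
    """Set each thread-count variable to n and return the resulting values.

    Only effective if called before numpy/scipy/pmdarima are imported in this
    process; child processes started afterwards inherit the environment.
    """
    for var in THREAD_VARS:
        os.environ[var] = str(n)
    return {var: os.environ[var] for var in THREAD_VARS}


# Call this first, at the very top of the script...
limit_native_threads(1)

# ...and only then import the numerical stack, e.g.:
# import numpy as np
# from pmdarima.arima import auto_arima
```

If changing the import order is awkward, `threadpoolctl.threadpool_limits(limits=1)` can cap BLAS/OpenMP pools at runtime instead (recent scikit-learn releases install `threadpoolctl` as a dependency); calling it from a `Pool(initializer=...)` function should apply the limit in each worker.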