Nixtla / statsforecast

Lightning ⚡️ fast forecasting with statistical and econometric models.
https://nixtlaverse.nixtla.io/statsforecast
Apache License 2.0

remove numba dependency #753

Open jmoralez opened 8 months ago

jmoralez commented 8 months ago

Description

We've relied heavily on numba to speed up our models. However, we don't need its JIT compilation: all the code that uses it is defined inside the library, so it could be compiled ahead of time instead.

Replacing numba jitted code with a compiled alternative (C++ or Rust for example) would provide the following benefits:

And the following drawbacks:

Use case

This would benefit the development process, since even when using the cache it can take a couple of seconds to run jitted functions for the first time. Deployments would also be smoother, because either:

  1. People that were copying over the cache won't have to do it anymore.
  2. People that weren't copying the cache won't experience cold starts anymore.
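
To put a number on the cold start described above, a minimal sketch could time the first call of a jitted function against its fastest warm call. The helper `first_vs_warm_call` and the `running_sum` stand-in are hypothetical and not part of statsforecast; the demo only runs when numba and NumPy are installed.

```python
import time
from typing import Any, Callable, Tuple


def first_vs_warm_call(
    func: Callable[..., Any], *args: Any, repeats: int = 5
) -> Tuple[float, float]:
    """Return (first_call_seconds, fastest_warm_call_seconds) for func(*args).

    For a numba dispatcher the first call also pays for compilation (or for
    loading the compiled code from the on-disk cache), so the gap between the
    two numbers approximates the cold-start cost.
    """
    start = time.perf_counter()
    func(*args)
    first = time.perf_counter() - start

    # Keep the fastest of several warm calls to reduce timing noise.
    warm = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        warm = min(warm, time.perf_counter() - start)
    return first, warm


if __name__ == "__main__":
    try:
        import numpy as np
        from numba import njit
    except ImportError:
        print("numba/numpy not installed; skipping the demo")
    else:
        @njit
        def running_sum(y):
            # Small stand-in for a jitted model routine.
            total = 0.0
            for v in y:
                total += v
            return total

        first, warm = first_vs_warm_call(running_sum, np.arange(1_000, dtype=np.float64))
        print(f"first call: {first * 1e3:.1f} ms, warm call: {warm * 1e3:.3f} ms")
```

Inside a library module one would typically pass `cache=True` to `njit` to compare cached and uncached first calls; the demo leaves it out because numba's on-disk cache needs the function to live in a source file it can locate.
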

AzulGarza commented 8 months ago

I love this. I think we should move forward with the enhancement.

Regarding the second drawback, I agree with the nightly wheels option. Also, adopting a compiled alternative will force us to make releases more often. 🙌

Regarding the first drawback, we could still have some models rely on numba as an optional dependency. For example, if we release a new model written in numba, it would only be available when installing statsforecast[numba]. This approach could help us iterate on new models faster, and once they are stable we could migrate them to their compiled versions. :)

jmoralez commented 8 months ago

I used the following to profile the current compilation times:

```python
import datetime
import operator
import os

# Unset before importing statsforecast so jitted functions are compiled
# instead of being loaded from numba's on-disk cache.
os.environ.pop('NIXTLA_NUMBA_CACHE', None)

from collections import defaultdict
from pprint import pprint

from numba.core.event import install_recorder

from statsforecast.core import StatsForecast
from statsforecast.models import *
from statsforecast.utils import generate_series

data = generate_series(2)
models = [
    AutoARIMA(season_length=7),
    AutoCES(season_length=7),
    AutoETS(season_length=7),
    AutoTheta(season_length=7),
    SimpleExponentialSmoothing(alpha=0.1),
    GARCH(),
    TBATS(seasonal_periods=7),
]
sf = StatsForecast(models=models, freq='D')

# Record every "numba:compile" event fired while forecasting.
with install_recorder("numba:compile") as rec:
    forecast = sf.forecast(df=data, h=7)

# Keep one start/end timestamp per jitted function (dispatcher).
events = defaultdict(dict)
for ts, event in rec.buffer:
    stage = 'start' if event.is_start else 'end'
    events[event.data['dispatcher']][stage] = ts

comp_times_ms = []
for fn, times in events.items():
    module = fn.py_func.__module__
    if not module.startswith('statsforecast'):
        continue
    name = f'{module}.{fn.__name__}'
    start = datetime.datetime.fromtimestamp(times['start'])
    end = datetime.datetime.fromtimestamp(times['end'])
    # total_seconds() keeps whole seconds; timedelta.microseconds is only the
    # sub-second component and would drop them.
    time_in_ms = round((end - start).total_seconds() * 1000)
    comp_times_ms.append((name, time_in_ms))

top_fns = sorted(comp_times_ms, key=operator.itemgetter(1), reverse=True)

# Aggregate by module: 'statsforecast.theta.initstate' -> 'theta'.
times_by_module = defaultdict(int)
for fn, time in top_fns:
    times_by_module[fn.split('.')[1]] += time
top_modules = sorted(times_by_module.items(), key=operator.itemgetter(1), reverse=True)

pprint(top_fns)
pprint(top_modules)
```
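
The script above keeps a single start/end pair per dispatcher, so a function compiled for several signatures only keeps its last compilation. Assuming numba reports a callee's compilation between its caller's start and end events (which happens when jitting one function triggers compiling another), per-module sums may also count nested work more than once. A minimal sketch that pairs events with a stack and reports both inclusive and self (exclusive) time could look like this; `compile_times` and the `(timestamp, name, is_start)` tuple shape are assumptions for illustration, not numba APIs.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

# (timestamp in seconds, function name, True for a start event)
Event = Tuple[float, str, bool]


def compile_times(events: Iterable[Event]) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Return (inclusive, self) compile time per name from chronological events.

    Inclusive time covers everything between a start and its matching end;
    self time subtracts the nested compilations that happened in between.
    Repeated compilations of the same name (e.g. new signatures) are summed.
    """
    inclusive: Dict[str, float] = defaultdict(float)
    self_time: Dict[str, float] = defaultdict(float)
    # Each frame is [name, start timestamp, time spent in nested compilations].
    stack: List[list] = []
    for ts, name, is_start in events:
        if is_start:
            stack.append([name, ts, 0.0])
            continue
        if not stack:
            raise ValueError(f"end event for {name!r} without a matching start")
        open_name, start, nested = stack.pop()
        if open_name != name:
            raise ValueError(f"expected end of {open_name!r}, got {name!r}")
        elapsed = ts - start
        inclusive[name] += elapsed
        self_time[name] += elapsed - nested
        if stack:
            # Charge this compilation to the one that enclosed it.
            stack[-1][2] += elapsed
    if stack:
        raise ValueError("some compilations never finished")
    return dict(inclusive), dict(self_time)
```

To feed it from the recorder, each `(ts, event)` pair in `rec.buffer` could be mapped to `(ts, name, event.is_start)`, with `name` built from `event.data['dispatcher']` as in the script; summing self times per module would then avoid counting a callee both on its own and inside its caller.
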

And got the following results:

Times in milliseconds by function:

```
[('statsforecast.theta.initstate', 936),
 ('statsforecast.ces.switch_ces', 907),
 ('statsforecast.theta.initparamtheta', 881),
 ('statsforecast.arima._make_arima', 849),
 ('statsforecast.ces.cesfcst', 828),
 ('statsforecast.ces.ces_target_fn', 823),
 ('statsforecast.tbats.makeTBATSFMatrix', 793),
 ('statsforecast.garch.garch_sigma2', 739),
 ('statsforecast.ets.ets_target_fn', 734),
 ('statsforecast.theta.pegelsresid_theta', 696),
 ('statsforecast.ets.pegelsresid_C', 690),
 ('statsforecast.tbats.calcTBATSFaster', 683),
 ('statsforecast.ets.nelder_mead_ets', 664),
 ('statsforecast.theta.theta_target_fn', 605),
 ('statsforecast.arima.arima_gradtrans', 564),
 ('statsforecast.ces.initparamces', 557),
 ('statsforecast.arima.arima_css', 546),
 ('statsforecast.theta.thetafcst', 513),
 ('statsforecast.theta.nelder_mead_theta', 503),
 ('statsforecast.ces.nelder_mead_ces', 490),
 ('statsforecast.ces.cescalc', 486),
 ('statsforecast.theta.thetacalc', 483),
 ('statsforecast.arima.getQ0', 477),
 ('statsforecast.ets.etscalc', 476),
 ('statsforecast.ets.etsforecast', 453),
 ('statsforecast.arima.diff1d', 435),
 ('statsforecast.theta.switch_theta', 432),
 ('statsforecast.ces.cesupdate', 379),
 ('statsforecast.arima.partrans', 370),
 ('statsforecast.arima.ARIMA_invtrans', 351),
 ('statsforecast.ets.update', 327),
 ('statsforecast.ets.restrict_to_bounds', 325),
 ('statsforecast.garch.garch_loglik', 288),
 ('statsforecast.models._ses_fcst_mse', 286),
 ('statsforecast.ets.switch', 274),
 ('statsforecast.arima.inclu2', 256),
 ('statsforecast.arima.arima_undopars', 256),
 ('statsforecast.theta.thetaupdate', 255),
 ('statsforecast.ets.forecast', 252),
 ('statsforecast.arima.tsconv', 240),
 ('statsforecast.ces.cesforecast', 230),
 ('statsforecast.arima.arima_like', 209),
 ('statsforecast.theta.thetaforecast', 193),
 ('statsforecast.arima.invpartrans', 189),
 ('statsforecast.theta.is_constant', 176),
 ('statsforecast.arima.arima_transpar', 174),
 ('statsforecast.ets.is_constant', 161),
 ('statsforecast.ets.initparam', 151),
 ('statsforecast.ces.pegelsresid_ces', 128),
 ('statsforecast.garch.garch_cons', 72),
 ('statsforecast.arima.kalman_forecast', 22)]
```

Times in milliseconds by module:

```
[('theta', 5673),
 ('arima', 4938),
 ('ces', 4828),
 ('ets', 4507),
 ('tbats', 1476),
 ('garch', 1099),
 ('models', 286)]
```

So I think we can migrate them in that order (I already migrated ETS in #757 because I profiled this incorrectly xD), continuing with Theta next.