tidyverts / fable

Tidy time series forecasting
https://fable.tidyverts.org
GNU General Public License v3.0

`forecast.NNETAR()` much slower than methods for other model types #171

Closed (mpjashby closed this issue 5 years ago)

mpjashby commented 5 years ago

I'm comparing forecasts for the various models included in fable. forecast.NNETAR() appears to be much slower than the forecast methods for the other model types. For example, running the code below on a 2.8GHz MacBook Pro with 16GB of memory, forecast() takes about 7.5 minutes on the neural network model, while the other methods all take less than one second.

I suspect (although not based on much) that this is because the prediction intervals for the neural network have to be simulated. If so, might it be useful to add an argument to forecast() that returns only the point forecast, rather than the forecast and the distribution? This would also help in other situations where users require only the point forecast.

library(fable)
library(fasster)
library(lubridate)
library(rbenchmark)
library(tsibble)
library(tidyverse)

test_data <- tsibbledata::PBS %>% 
  filter(
    Concession == "General", Type == "Safety net", 
    ATC2 %in% c("A01", "A02", "A03", "A04", "A05"),
    between(Month, ymd("2005-01-01"), ymd("2007-12-31"))
  ) %>% 
  update_tsibble(key = c("ATC2"))

test_models <- list(
  tslm = model(test_data, TSLM(Scripts ~ trend() + season())),
  stl = model(test_data, decomposition_model(STL, Scripts ~ trend() + season(), 
                                             NAIVE(season_adjust))),
  ets = model(test_data, ETS(Scripts ~ trend() + season() + error())),
  arima = model(test_data, ARIMA(Scripts ~ trend() + season())),
  var = model(test_data, VAR(Scripts ~ trend() + season() + AR())),
  neural = model(test_data, NNETAR(Scripts ~ trend() + season() + AR())),
  fasster = model(test_data, FASSTER(Scripts ~ poly(1) + trig(12) + ARMA()))
)

test_forecast_times <- map_dfr(
  test_models, 
  ~ benchmark(forecast(., h = "3 years"), replications = 1), 
  .id = "model"
) %>% 
  select(model, execution_time = elapsed)

mitchelloharawild commented 5 years ago

For models that require simulated forecast intervals, you should be able to produce only point forecasts by setting times = 0. More details are at https://fable.tidyverts.org/reference/forecast.NNETAR.html
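
As a minimal sketch of that suggestion (not the code from the original post), the example below fits NNETAR() to a single built-in monthly series and requests point forecasts only. The USAccDeaths dataset, the `deaths` and `fc_point` names, the seed, and h = 12 are assumptions chosen for illustration. Behaviour may differ between fable versions, and the call may emit a warning about bias adjustment.

```r
library(fable)
library(tsibble)

# USAccDeaths is a monthly ts object from the datasets package;
# as_tsibble() converts it to a tsibble with columns `index` and `value`.
deaths <- as_tsibble(USAccDeaths)

# NNETAR starts from random network weights, so fix the seed for repeatability.
set.seed(1)
fit <- model(deaths, neural = NNETAR(value))

# times = 0 requests no simulated sample paths, so only point forecasts
# are computed and the slow simulation step is skipped.
fc_point <- forecast(fit, h = 12, times = 0)
fc_point
```

For the list of models in the original post, the equivalent call would presumably be forecast(test_models$neural, h = "3 years", times = 0).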

(Sent as an email reply to Matt Ashby's post of Thu., 20 Jun. 2019, 7:53 pm.)

mpjashby commented 5 years ago

Thanks! The neural network forecast now takes about the same time as the others. It also generates a warning that the forecasts could not be bias adjusted because the standard deviation is unknown, which makes sense.

Might it be useful to add a line to this effect on the man page for forecast.NNETAR()? The man page says that the times argument is only relevant when bootstrap = TRUE and that the default is bootstrap = FALSE, so I didn't think times would have any effect unless bootstrap = TRUE was set.
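
One way to check that times matters even with the default bootstrap = FALSE is to time two forecast() calls that differ only in times. This is a rough sketch under assumptions: the USAccDeaths series, the variable names, and times = 200 are illustrative choices, and the timings will vary by machine and fable version.

```r
library(fable)
library(tsibble)

# Illustrative monthly series from the datasets package (columns `index`, `value`).
deaths <- as_tsibble(USAccDeaths)

set.seed(1)  # NNETAR uses random starting weights
fit <- model(deaths, neural = NNETAR(value))

# bootstrap is left at its default in both calls; only times changes.
time_none <- system.time(fc_none <- forecast(fit, h = 12, times = 0))
time_sim  <- system.time(fc_sim  <- forecast(fit, h = 12, times = 200))

# Elapsed seconds for each call, side by side.
c(times_0 = time_none[["elapsed"]], times_200 = time_sim[["elapsed"]])
```

If the second call is noticeably slower, the simulation is running regardless of the bootstrap setting, which matches the behaviour described in this thread.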

mitchelloharawild commented 5 years ago

Sounds reasonable, thanks for the suggestion. Glad that the solution works for you.

(Sent as an email reply to Matt Ashby's comment of Thu., 20 Jun. 2019, 8:23 pm.)

mpjashby commented 5 years ago

Thanks for the ultra-fast reply!

mitchelloharawild commented 5 years ago

Hopefully it's a bit better now.