spsanderson / healthyR.ts

A time-series companion package to healthyR
https://www.spsanderson.com/healthyR.ts/
Other
18 stars 3 forks source link

Update `ts_auto_arima()` to use "auto_arima" for model spec when `.tune = FALSE` #489

Closed spsanderson closed 11 months ago

spsanderson commented 11 months ago

Update the model spec to use the `auto_arima` engine when the `.tune` parameter is `FALSE`.

Function:

#' Boilerplate ARIMA workflow
#'
#' Builds a recipe from `.formula`, an ARIMA model specification and a
#' workflow, optionally tunes the specification over a grid using time series
#' cross validation, fits the workflow on `rsample::training(.rsamp_obj)` and
#' passes the fitted workflow to `healthyR.ts::calibrate_and_plot()`.
#'
#' @param .data The data; coerced with `dplyr::as_tibble()`.
#' @param .date_col Unquoted date column. Must be supplied; forwarded as
#'   `date_var` to `timetk::time_series_cv()` when `.tune = TRUE`.
#' @param .value_col Unquoted value column. Must be supplied; this function
#'   only checks that it is present.
#' @param .formula Formula handed to `recipes::recipe()`.
#' @param .rsamp_obj An object inheriting from class `rsplit`.
#' @param .prefix Prefix used to build the recipe name
#'   (`paste0(.prefix, "_recipe")`).
#' @param .tune If `TRUE`, every `modeltime::arima_reg()` argument is marked
#'   with `tune::tune()` and the `"arima"` engine is used. If `FALSE`, an
#'   untuned `modeltime::arima_reg()` with the `"auto_arima"` engine is used.
#' @param .grid_size `size` passed to `dials::grid_latin_hypercube()`.
#' @param .num_cores Passed to `modeltime::parallel_start()` when tuning.
#' @param .cv_assess `assess` passed to `timetk::time_series_cv()`.
#' @param .cv_skip `skip` passed to `timetk::time_series_cv()`.
#' @param .cv_slice_limit `slice_limit` passed to `timetk::time_series_cv()`.
#' @param .best_metric `metric` passed to `tune::show_best()`.
#' @param .bootstrap_final Not used in the computation; stored as an attribute
#'   on the returned object.
#'
#' @return Invisibly, a list with elements `recipe_info`, `model_info` and
#'   `model_calibration`, plus `tuned_info` when `.tune` is `TRUE`. The call
#'   settings are attached as attributes (`.tune`, `.grid_size`, `.cv_assess`,
#'   `.cv_skip`, `.cv_slice_limit`, `.best_metric`, `.bootstrap_final`,
#'   `.mode`, `.parsnip_engine`, `.function_family`).
ts_auto_arima <- function(.data, .date_col, .value_col, .formula, .rsamp_obj,
                          .prefix = "ts_arima", .tune = TRUE, .grid_size = 10,
                          .num_cores = 1, .cv_assess = 12, .cv_skip = 3,
                          .cv_slice_limit = 6, .best_metric = "rmse",
                          .bootstrap_final = FALSE) {

  # Tidyeval ----
  date_col_var_expr <- rlang::enquo(.date_col)
  value_col_var_expr <- rlang::enquo(.value_col)

  # Cross Validation
  cv_assess <- as.numeric(.cv_assess)
  cv_skip <- as.numeric(.cv_skip)
  cv_slice <- as.numeric(.cv_slice_limit)

  # Tuning Grid
  grid_size <- as.numeric(.grid_size)
  num_cores <- as.numeric(.num_cores)
  best_metric <- as.character(.best_metric)

  # Data and splits
  splits <- .rsamp_obj
  data_tbl <- dplyr::as_tibble(.data)

  # Checks ----
  if (rlang::quo_is_missing(date_col_var_expr)) {
    rlang::abort(
      message = "'.date_col' must be supplied.",
      use_cli_format = TRUE
    )
  }

  if (rlang::quo_is_missing(value_col_var_expr)) {
    rlang::abort(
      message = "'.value_col' must be supplied.",
      use_cli_format = TRUE
    )
  }

  if (!inherits(x = splits, what = "rsplit")) {
    rlang::abort(
      message = "'.rsamp_obj' must have class rsplit, use the rsample package.",
      use_cli_format = TRUE
    )
  }

  # Recipe ----
  # Get the initial recipe call
  recipe_call <- get_recipe_call(match.call())

  rec_syntax <- paste0(.prefix, "_recipe") %>%
    assign_value(!!recipe_call)

  rec_obj <- recipes::recipe(formula = .formula, data = data_tbl)

  # Tune/Spec ----
  # Tuned runs use the "arima" engine with every order marked for tuning;
  # untuned runs use the "auto_arima" engine.
  if (.tune) {
    model_spec <- modeltime::arima_reg(
      seasonal_period          = tune::tune(),
      non_seasonal_ar          = tune::tune(),
      non_seasonal_differences = tune::tune(),
      non_seasonal_ma          = tune::tune(),
      seasonal_ar              = tune::tune(),
      seasonal_differences     = tune::tune(),
      seasonal_ma              = tune::tune()
    ) %>%
      parsnip::set_mode(mode = "regression") %>%
      parsnip::set_engine("arima")
  } else {
    model_spec <- modeltime::arima_reg() %>%
      parsnip::set_mode(mode = "regression") %>%
      parsnip::set_engine("auto_arima")
  }

  # Workflow ----
  wflw <- workflows::workflow() %>%
    workflows::add_recipe(rec_obj) %>%
    workflows::add_model(model_spec)

  # Tuning Grid ----
  if (.tune) {

    # Start parallel backend. The exit handler guarantees the backend is
    # stopped even when tuning or fitting errors out; the flag prevents a
    # second stop after the normal-path stop below.
    modeltime::parallel_start(num_cores)
    parallel_running <- TRUE
    on.exit(
      if (parallel_running) modeltime::parallel_stop(),
      add = TRUE
    )

    tuning_grid_spec <- dials::grid_latin_hypercube(
      hardhat::extract_parameter_set_dials(model_spec),
      size = grid_size
    )

    # Make TS CV ----
    tscv <- timetk::time_series_cv(
      data        = rsample::training(splits),
      date_var    = {{date_col_var_expr}},
      cumulative  = TRUE,
      assess      = cv_assess,
      skip        = cv_skip,
      slice_limit = cv_slice
    )

    # Tune the workflow
    tuned_results <- wflw %>%
      tune::tune_grid(
        resamples = tscv,
        grid      = tuning_grid_spec,
        metrics   = modeltime::default_forecast_accuracy_metric_set()
      )

    # Get the best result set by a specified metric
    best_result_set <- tuned_results %>%
      tune::show_best(metric = best_metric, n = 1)

    # Plot results
    tune_results_plt <- tuned_results %>%
      tune::autoplot() +
      ggplot2::theme_minimal() +
      ggplot2::geom_smooth(se = FALSE)

    # Make final workflow, reusing the best result set computed above
    wflw_fit <- wflw %>%
      tune::finalize_workflow(best_result_set) %>%
      parsnip::fit(rsample::training(splits))

    # Stop parallel backend
    modeltime::parallel_stop()
    parallel_running <- FALSE

  } else {
    wflw_fit <- wflw %>%
      parsnip::fit(rsample::training(splits))
  }

  # Calibrate and Plot ----
  cap <- healthyR.ts::calibrate_and_plot(
    wflw_fit,
    .splits_obj  = splits,
    .data        = data_tbl,
    .interactive = TRUE,
    .print_info  = FALSE
  )

  # Return ----
  output <- list(
    recipe_info = list(
      recipe_call   = recipe_call,
      recipe_syntax = rec_syntax,
      rec_obj       = rec_obj
    ),
    model_info = list(
      model_spec  = model_spec,
      wflw        = wflw,
      fitted_wflw = wflw_fit,
      was_tuned   = if (.tune) "tuned" else "not_tuned"
    ),
    model_calibration = list(
      plot            = cap$plot,
      calibration_tbl = cap$calibration_tbl,
      model_accuracy  = cap$model_accuracy
    )
  )

  if (.tune) {
    output$tuned_info <- list(
      tuning_grid      = tuning_grid_spec,
      tscv             = tscv,
      tuned_results    = tuned_results,
      grid_size        = grid_size,
      best_metric      = best_metric,
      best_result_set  = best_result_set,
      tuning_grid_plot = tune_results_plt,
      plotly_grid_plot = plotly::ggplotly(tune_results_plt)
    )
  }

  # Add attributes
  attr(output, ".tune") <- .tune
  attr(output, ".grid_size") <- .grid_size
  attr(output, ".cv_assess") <- .cv_assess
  attr(output, ".cv_skip") <- .cv_skip
  attr(output, ".cv_slice_limit") <- .cv_slice_limit
  attr(output, ".best_metric") <- .best_metric
  attr(output, ".bootstrap_final") <- .bootstrap_final
  attr(output, ".mode") <- "regression"
  attr(output, ".parsnip_engine") <- if (.tune) "arima" else "auto_arima"
  attr(output, ".function_family") <- "boilerplate"

  return(invisible(output))
}

Examples:

# Packages used by the example ----
library(dplyr)
library(timetk)
library(modeltime)
library(healthyR.ts)

# Data ----
# AirPassengers converted with ts_to_tbl(), with the index column dropped
data <- select(ts_to_tbl(AirPassengers), -index)

# Splits ----
splits <- time_series_split(
  data,
  date_col,
  assess = 12,
  skip = 3,
  cumulative = TRUE
)

# Tuned run ----
tsa <- ts_auto_arima(
  .data = data,
  .date_col = date_col,
  .value_col = value,
  .formula = value ~ .,
  .rsamp_obj = splits,
  .tune = TRUE,
  .grid_size = 5,
  .num_cores = 2,
  .cv_slice_limit = 2
)

# Print the model specification stored in the result
tsa$model_info$model_spec
ARIMA Regression Model Specification (regression)

Main Arguments:
  seasonal_period = tune::tune()
  non_seasonal_ar = tune::tune()
  non_seasonal_differences = tune::tune()
  non_seasonal_ma = tune::tune()
  seasonal_ar = tune::tune()
  seasonal_differences = tune::tune()
  seasonal_ma = tune::tune()

Computational engine: arima 

# Untuned run ----
# Same call as the tuned run, but with .tune = FALSE
tsa <- ts_auto_arima(
  .data = data,
  .date_col = date_col,
  .value_col = value,
  .formula = value ~ .,
  .rsamp_obj = splits,
  .tune = FALSE,
  .grid_size = 5,
  .num_cores = 2,
  .cv_slice_limit = 2
)
frequency = 12 observations per 1 year

ARIMA Regression Model Specification (regression)

Computational engine: auto_arima