spsanderson / healthyR.ts

A time-series companion package to healthyR
https://www.spsanderson.com/healthyR.ts/
Other
19 stars 3 forks source link

lm #280

Closed spsanderson closed 2 years ago

spsanderson commented 2 years ago

Function:

#' Boilerplate Workflow
#'
#' @family Boiler_Plate
#' @family lm
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @details This uses `parsnip::linear_reg()` and sets the `engine` to `lm`.
#' The workflow is fit on the training portion of `.rsamp_obj` and the fitted
#' workflow is then passed to [healthyR.ts::calibrate_and_plot()].
#'
#' The recipe adds time series and holiday signature features built from
#' `.date_col`, removes the raw date column, one-hot encodes the nominal
#' columns and finally applies near-zero variance, normalization and
#' correlation filters. The `<.date_col>_index.num` feature is excluded from
#' the near-zero variance and normalization steps.
#'
#' @seealso \url{https://parsnip.tidymodels.org/reference/linear_reg.html}
#'
#' @description This is a boilerplate function to create automatically the following:
#' -  recipe
#' -  model specification
#' -  workflow
#' -  calibration tibble and plot
#'
#' @param .data The data being passed to the function. The time-series object.
#' @param .date_col The column that holds the datetime. Supply the bare
#' column name.
#' @param .value_col The column that has the value
#' @param .formula The formula that is passed to the recipe like `value ~ .`
#' @param .rsamp_obj The rsample splits object. Must inherit from `rsplit`.
#' @param .prefix Default is `ts_lm`
#' @param .cv_assess How many observations for assess. See [timetk::time_series_cv()].
#' Not currently used by this function.
#' @param .cv_skip How many observations to skip. See [timetk::time_series_cv()].
#' Not currently used by this function.
#' @param .cv_slice_limit How many slices to return. See [timetk::time_series_cv()].
#' Not currently used by this function.
#' @param .best_metric Default is "rmse". See [modeltime::default_forecast_accuracy_metric_set()].
#' Not currently used by this function.
#' @param .bootstrap_final Not yet implemented.
#'
#' @examples
#' \dontrun{
#' library(dplyr)
#' library(timetk)
#'
#' data <- AirPassengers %>%
#'   ts_to_tbl() %>%
#'   select(-index)
#'
#' splits <- time_series_split(
#'   data
#'   , date_col
#'   , assess = 12
#'   , skip = 3
#'   , cumulative = TRUE
#' )
#'
#' ts_lm <- ts_auto_lm(
#'   .data = data,
#'   .date_col = date_col,
#'   .value_col = value,
#'   .rsamp_obj = splits,
#'   .formula = value ~ .
#' )
#'
#' ts_lm$recipe_info
#' }
#'
#' @return
#' An invisible list with three elements:
#' -  `recipe_info`: the recipe call, the recipe syntax and the recipe object
#' -  `model_info`: the model specification, the workflow, the fitted workflow
#'    and a `was_tuned` flag (always `"not_tuned"`)
#' -  `model_calibration`: the plot, calibration tibble and model accuracy
#'    returned by [healthyR.ts::calibrate_and_plot()]
#'
#' @export
#'

ts_auto_lm <- function(.data, .date_col, .value_col, .formula, .rsamp_obj,
                       .prefix = "ts_lm", .cv_assess = 12, .cv_skip = 3,
                       .cv_slice_limit = 6, .best_metric = "rmse",
                       .bootstrap_final = FALSE) {

  # Tidyeval ----
  date_col_var_expr <- rlang::enquo(.date_col)
  value_col_var_expr <- rlang::enquo(.value_col)

  # Checks ----
  # Validate the arguments before doing any work on the data.
  if (rlang::quo_is_missing(date_col_var_expr)) {
    rlang::abort(
      message = "'.date_col' must be supplied.",
      use_cli_format = TRUE
    )
  }

  if (rlang::quo_is_missing(value_col_var_expr)) {
    rlang::abort(
      message = "'.value_col' must be supplied.",
      use_cli_format = TRUE
    )
  }

  if (!inherits(x = .rsamp_obj, what = "rsplit")) {
    rlang::abort(
      message = "'.rsamp_obj' must have class rsplit, use the rsample package.",
      use_cli_format = TRUE
    )
  }

  # Data and splits ----
  splits <- .rsamp_obj
  data_tbl <- dplyr::as_tibble(.data)

  # Name of the numeric index feature derived from the date column. It is
  # built from the supplied column name so the exclusions in the recipe below
  # do not depend on the date column being called `date_col`.
  index_num_col <- paste0(rlang::as_name(date_col_var_expr), "_index.num")

  # Recipe ----
  # Get the initial recipe call. match.call() must be evaluated in this
  # function's own frame, so it stays here rather than in a helper.
  recipe_call <- get_recipe_call(match.call())

  rec_syntax <- paste0(.prefix, "_recipe") %>%
    assign_value(!!recipe_call)

  rec_obj <- recipes::recipe(formula = .formula, data = data_tbl)

  rec_obj <- rec_obj %>%
    timetk::step_timeseries_signature({{date_col_var_expr}}) %>%
    timetk::step_holiday_signature({{date_col_var_expr}}) %>%
    recipes::step_novel(recipes::all_nominal_predictors()) %>%
    recipes::step_mutate_at(tidyselect::vars_select_helpers$where(is.character)
                            , fn = ~ as.factor(.)) %>%
    recipes::step_mutate({{date_col_var_expr}} := as.numeric({{date_col_var_expr}})) %>%
    recipes::step_rm({{date_col_var_expr}}) %>%
    recipes::step_dummy(recipes::all_nominal(), one_hot = TRUE) %>%
    # `!!` inlines the column name so the stored selection does not rely on a
    # variable from this function's environment when the recipe is prepped.
    recipes::step_nzv(
      recipes::all_predictors(),
      -tidyselect::any_of(!!index_num_col)
    ) %>%
    recipes::step_normalize(
      recipes::all_numeric_predictors(),
      -tidyselect::any_of(!!index_num_col)
    ) %>%
    recipes::step_corr(recipes::all_numeric_predictors())

  # Model Specification ----
  model_spec <- parsnip::linear_reg(
    mode   = "regression",
    engine = "lm"
  )

  # Workflow ----
  wflw <- workflows::workflow() %>%
    workflows::add_recipe(rec_obj) %>%
    workflows::add_model(model_spec)

  # Fit on the training portion of the split only.
  wflw_fit <- wflw %>%
    parsnip::fit(rsample::training(splits))

  # Calibrate and Plot ----
  cap <- healthyR.ts::calibrate_and_plot(
    wflw_fit,
    .splits_obj  = splits,
    .data        = data_tbl,
    .interactive = TRUE,
    .print_info  = FALSE
  )

  # Return ----
  output <- list(
    recipe_info = list(
      recipe_call   = recipe_call,
      recipe_syntax = rec_syntax,
      rec_obj       = rec_obj
    ),
    model_info = list(
      model_spec  = model_spec,
      wflw        = wflw,
      fitted_wflw = wflw_fit,
      was_tuned   = "not_tuned"
    ),
    model_calibration = list(
      plot            = cap$plot,
      calibration_tbl = cap$calibration_tbl,
      model_accuracy  = cap$model_accuracy
    )
  )

  invisible(output)
}

Example:

> ts_lm
$recipe_info
$recipe_info$recipe_call
recipe(.data = data, .date_col = date_col, .value_col = value, 
    .formula = value ~ ., .rsamp_obj = splits)

$recipe_info$recipe_syntax
[1] "ts_lm_recipe <-"                                                                                                    
[2] "\n  recipe(.data = data, .date_col = date_col, .value_col = value, .formula = value ~ \n    ., .rsamp_obj = splits)"

$recipe_info$rec_obj
Recipe

Inputs:

      role #variables
   outcome          1
 predictor          1

Operations:

Timeseries signature features from date_col
Holiday signature features from date_col
Novel factor level assignment for recipes::all_nominal_predictors()
Variable mutation for tidyselect::vars_select_helpers$where(is.character)
Variable mutation for as.numeric(^date_col)
Variables removed date_col
Dummy variables from recipes::all_nominal()
Sparse, unbalanced variable filter on recipes::all_predictors(), -date_col_index.num
Centering and scaling for recipes::all_numeric_predictors(), -date_col_index.num
Correlation filter on recipes::all_numeric_predictors()

$model_info
$model_info$model_spec
Linear Regression Model Specification (regression)

Computational engine: lm 

$model_info$wflw
== Workflow ===============================================================================
Preprocessor: Recipe
Model: linear_reg()

-- Preprocessor ---------------------------------------------------------------------------
10 Recipe Steps

* step_timeseries_signature()
* step_holiday_signature()
* step_novel()
* step_mutate_at()
* step_mutate()
* step_rm()
* step_dummy()
* step_nzv()
* step_normalize()
* step_corr()

-- Model ----------------------------------------------------------------------------------
Linear Regression Model Specification (regression)

Computational engine: lm 

$model_info$fitted_wflw
== Workflow [trained] =====================================================================
Preprocessor: Recipe
Model: linear_reg()

-- Preprocessor ---------------------------------------------------------------------------
10 Recipe Steps

* step_timeseries_signature()
* step_holiday_signature()
* step_novel()
* step_mutate_at()
* step_mutate()
* step_rm()
* step_dummy()
* step_nzv()
* step_normalize()
* step_corr()

-- Model ----------------------------------------------------------------------------------

Call:
stats::lm(formula = ..y ~ ., data = data)

Coefficients:
          (Intercept)          date_col_year          date_col_half  
             262.4924                97.9066               220.7605  
     date_col_quarter      date_col_wday.xts          date_col_qday  
            -329.4337                -5.5716               -49.3639  
       date_col_mweek      date_col_week.iso         date_col_week2  
              -1.4767                -0.9463               -35.5370  
       date_col_week3         date_col_week4     date_col_exch_NYSE  
             -36.3834               -29.8878                -3.2820  
   date_col_exch_NERC      date_col_exch_TSX   date_col_exch_ZURICH  
               3.5721                 0.1303                -2.4830  
   date_col_locale_US  date_col_locale_World     date_col_locale_CA  
              -2.5847                 3.7221                -3.2770  
date_col_month.lbl_01  date_col_month.lbl_02  date_col_month.lbl_03  
            -120.3818               -92.7923               -93.3405  
date_col_month.lbl_04  date_col_month.lbl_05  date_col_month.lbl_06  
             -33.1109               -41.3499                     NA  
date_col_month.lbl_07  date_col_month.lbl_08  date_col_month.lbl_09  
             -47.9036               -17.1078                     NA  
date_col_month.lbl_10  date_col_month.lbl_11  date_col_month.lbl_12  
             -19.6599                     NA                     NA  
  date_col_wday.lbl_1    date_col_wday.lbl_2    date_col_wday.lbl_3  
              -7.2139                -3.6770                -2.9149  
  date_col_wday.lbl_4    date_col_wday.lbl_5    date_col_wday.lbl_6  
              -2.4463                -1.4975                     NA  
  date_col_wday.lbl_7  
                   NA  

$model_info$was_tuned
[1] "not_tuned"

$model_calibration
$model_calibration$plot

$model_calibration$calibration_tbl
# Modeltime Table
# A tibble: 1 x 5
  .model_id .model     .model_desc .type .calibration_data
      <int> <list>     <chr>       <chr> <list>           
1         1 <workflow> LM          Test  <tibble [12 x 4]>

$model_calibration$model_accuracy
# A tibble: 1 x 9
  .model_id .model_desc .type   mae  mape  mase smape  rmse   rsq
      <int> <chr>       <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1         1 LM          Test   38.3  7.26 0.793  7.65  51.4 0.932

image