spsanderson / healthyR.ts

A time-series companion package to healthyR
https://www.spsanderson.com/healthyR.ts/

Automatic TS Recipe, Model, Tuning Grid and Workflow creation `glmnet` #241

Closed spsanderson closed 2 years ago

spsanderson commented 2 years ago

https://usemodels.tidymodels.org/reference/templates.html
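For reference, this is the style of template generator being imitated; a minimal sketch, assuming the usemodels package is installed (use_glmnet() prints recipe, spec, and workflow boilerplate to the console rather than returning objects):

library(usemodels)
use_glmnet(mpg ~ ., data = mtcars, prefix = "ts_glmnet", tune = TRUE)

The three helpers below appear to be adapted from the usemodels internals, which explains the argument names they strip from the captured call.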

Start

library(healthyverse) # attaches the healthyR family, including healthyR.ts
library(dplyr)
library(recipes)
library(timetk)
library(rsample)
library(dials)
# parsnip, workflows, tune, modeltime, tidyr, ggplot2 and plotly are
# called with explicit namespaces below

get_recipe_call <- function(.rec_call){
  # Capture the matched call and drop the usemodels-style template
  # arguments (note: the dot-prefixed .tune/.prefix arguments of
  # ts_auto_glmnet are not matched by these names, so they survive
  # in the echoed call), then rewrite the call head to recipe(...)
  cl <- .rec_call
  cl$tune <- NULL
  cl$verbose <- NULL
  cl$colors <- NULL
  cl$prefix <- NULL
  rec_cl <- cl
  rec_cl[[1]] <- rlang::expr(recipe)
  rec_cl
}

assign_value <- function(name, value, cr = TRUE) {
  # Deparse an unevaluated expression to text (wrapped at 74 chars)
  # and hand it off as the right-hand side of an assignment string
  value <- rlang::enexpr(value)
  value <- rlang::expr_text(value, width = 74L)
  chr_assign(name, value, cr)
}

chr_assign <- function(name, value, cr = TRUE) {
  # Build "name <- value" as character, optionally breaking after "<-"
  name <- paste(name, "<-")
  if (cr) {
    res <- c(name, paste0("\n  ", value))
  } else {
    res <- paste(name, value)
  }
  res
}
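
A quick illustration of the helpers on a made-up call object (the call and names here are hypothetical, for demonstration only):

cl <- quote(ts_auto_glmnet(.data = data, .formula = value ~ ., prefix = "x"))
get_recipe_call(cl)
#> recipe(.data = data, .formula = value ~ .)

chr_assign("ts_glmnet_recipe", "recipe(value ~ ., data = data_tbl)", cr = FALSE)
#> [1] "ts_glmnet_recipe <- recipe(value ~ ., data = data_tbl)"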

ts_auto_glmnet <- function(.data, .date_col, .value_col, .formula, .rsamp_obj,
                           .prefix = "ts_glmnet", .tune = TRUE, .grid_size = 10,
                           .num_cores = 1, .cv_assess = 12, .cv_skip = 3, 
                           .cv_slice_limit = 6, .best_metric = "rmse", 
                           .bootstrap_final = FALSE){

  # Tidyeval ----
  date_col_var_expr  <- rlang::enquo(.date_col)
  value_col_var_expr <- rlang::enquo(.value_col)

  # Cross validation settings
  cv_assess <- as.numeric(.cv_assess)
  cv_skip   <- as.numeric(.cv_skip)
  cv_slice  <- as.numeric(.cv_slice_limit)

  # Tuning grid settings
  grid_size   <- as.numeric(.grid_size)
  num_cores   <- as.numeric(.num_cores)
  best_metric <- as.character(.best_metric)

  # Data and splits
  splits   <- .rsamp_obj
  data_tbl <- dplyr::as_tibble(.data)

  # Checks ----
  if (rlang::quo_is_missing(date_col_var_expr)){
    rlang::abort(
      message = "'.date_col' must be supplied.",
      use_cli_format = TRUE
    )
  }

  if (rlang::quo_is_missing(value_col_var_expr)){
    rlang::abort(
      message = "'.value_col' must be supplied.",
      use_cli_format = TRUE
    )
  }

  if (!inherits(x = splits, what = "rsplit")){
    rlang::abort(
      message = "'.rsamp_obj' must be have class rsplit, use the rsample package.",
      use_cli_format = TRUE
    )
  }

  # Recipe ----
  # Get the initial recipe call
  recipe_call <- get_recipe_call(match.call())

  rec_syntax <- paste0(.prefix, "_recipe") %>%
    assign_value(!!recipe_call)

  rec_obj <- recipes::recipe(formula = .formula, data = data_tbl)

  rec_obj <- rec_obj %>%
    timetk::step_timeseries_signature({{date_col_var_expr}}) %>%
    timetk::step_holiday_signature({{date_col_var_expr}}) %>%
    recipes::step_novel(recipes::all_nominal_predictors()) %>%
    recipes::step_mutate_at(
      tidyselect::vars_select_helpers$where(is.character),
      fn = ~ as.factor(.)
    ) %>%
    recipes::step_rm({{date_col_var_expr}}) %>%
    recipes::step_dummy(recipes::all_nominal(), one_hot = TRUE) %>%
    # keep the numeric trend index out of the zv/normalize steps;
    # NOTE: "date_col_index.num" hard-codes a date column named date_col
    recipes::step_zv(recipes::all_predictors(), -date_col_index.num) %>%
    recipes::step_normalize(recipes::all_numeric_predictors(), -date_col_index.num)

  # Tune/Spec ----
  if (.tune){
    model_spec <- parsnip::linear_reg(
      penalty = tune::tune(),
      mixture = tune::tune(),
      mode    = "regression",
      engine  = "glmnet"
    )
  } else {
    model_spec <- parsnip::linear_reg(
      mode   = "regression",
      engine = "glmnet"
    )
  }

  # Workflow ----
  wflw <- workflows::workflow() %>%
    workflows::add_recipe(rec_obj) %>%
    workflows::add_model(model_spec) 

  # Tuning Grid ----
  if (.tune){

    # Start parallel backend
    modeltime::parallel_start(num_cores)

    # 20 penalty values x 6 mixture values = 120 candidate combinations,
    # randomly sampled down to grid_size rows
    tuning_grid_spec <- tidyr::crossing(
      penalty = 10^seq(-6, -1, length.out = 20),
      mixture = c(0.05, 0.2, 0.4, 0.6, 0.8, 1)
    ) %>%
      dplyr::slice_sample(n = grid_size)

    # Make TS CV ----
    tscv <- timetk::time_series_cv(
      data        = rsample::training(splits),
      date_var    = {{date_col_var_expr}},
      cumulative  = TRUE,
      assess      = cv_assess,
      skip        = cv_skip,
      slice_limit = cv_slice
    )

    # Tune the workflow
    tuned_results <- wflw %>%
      tune::tune_grid(
        resamples = tscv,
        grid      = tuning_grid_spec,
        metrics   = modeltime::default_forecast_accuracy_metric_set()
      )

    # Get the best result set by a specified metric
    best_result_set <- tuned_results %>%
      tune::show_best(metric = best_metric, n = 1)

    # Plot results
    tune_results_plt <- tuned_results %>%
      tune::autoplot() +
      ggplot2::theme_minimal() + 
      ggplot2::geom_smooth(se = FALSE)

    # Make the final workflow: show_best(n = Inf) sorts by the metric,
    # then slice(1) takes the single best row (equivalent to select_best)
    wflw_fit <- wflw %>%
      tune::finalize_workflow(
        tuned_results %>%
          tune::show_best(metric = best_metric, n = Inf) %>%
          dplyr::slice(1)
      ) %>%
      parsnip::fit(rsample::training(splits))

    # Stop parallel backend
    modeltime::parallel_stop()

  } else {
    wflw_fit <- wflw %>%
      parsnip::fit(rsample::training(splits))
  }

  # Calibrate and Plot ----
  cap <- healthyR.ts::calibrate_and_plot(
    wflw_fit,
    .splits_obj  = splits,
    .data        = data_tbl,
    .interactive = TRUE,
    .print_info = FALSE
  )

  # Return ----
  output <- list(
    recipe_info = list(
      recipe_call   = recipe_call,
      recipe_syntax = rec_syntax,
      rec_obj       = rec_obj
    ),
    model_info = list(
      model_spec  = model_spec,
      wflw        = wflw,
      fitted_wflw = wflw_fit,
      was_tuned   = ifelse(.tune, "tuned", "not_tuned")
    ),
    model_calibration = list(
      plot = cap$plot,
      calibration_tbl = cap$calibration_tbl,
      model_accuracy = cap$model_accuracy
    )
  )

  if (.tune){
    output$tuned_info <- list(
      tuning_grid      = tuning_grid_spec,
      tscv             = tscv,
      tuned_results    = tuned_results,
      grid_size        = grid_size,
      best_metric      = best_metric,
      best_result_set  = best_result_set,
      tuning_grid_plot = tune_results_plt,
      plotly_grid_plot = plotly::ggplotly(tune_results_plt)
    )
  }

  return(invisible(output))
}
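
As an aside, dials is attached at the top but never used; the manual crossing() grid could instead be drawn as a space-filling design. A sketch, not what the function above does, assuming a current dials release (penalty() takes its range on the log10 scale):

tuning_grid_spec <- dials::grid_latin_hypercube(
  dials::penalty(range = c(-6, -1)), # matches 10^seq(-6, -1) above
  dials::mixture(),
  size = 10
)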

Example

data <- AirPassengers %>%
  ts_to_tbl() %>% # returns a tibble with index, date_col and value columns
  select(-index)

splits <- time_series_split(
  data,
  date_col,
  assess     = 12,
  skip       = 3,
  cumulative = TRUE
)

tst <- ts_auto_glmnet(
  .data      = data,
  .date_col  = date_col,
  .value_col = value,
  .formula   = value ~ .,
  .rsamp_obj = splits,
  .num_cores = 2
)

tst
$recipe_info
$recipe_info$recipe_call
recipe(.data = data, .date_col = date_col, .value_col = value, 
    .formula = value ~ ., .rsamp_obj = splits, .num_cores = 2)

$recipe_info$recipe_syntax
[1] "ts_glmnet_recipe <-"                                                                                                                
[2] "\n  recipe(.data = data, .date_col = date_col, .value_col = value, .formula = value ~ \n    ., .rsamp_obj = splits, .num_cores = 2)"

$recipe_info$rec_obj
Recipe

Inputs:

      role #variables
   outcome          1
 predictor          1

Operations:

Timeseries signature features from date_col
Holiday signature features from date_col
Novel factor level assignment for recipes::all_nominal_predictors()
Variable mutation for tidyselect::vars_select_helpers$where(is.character)
Variables removed date_col
Dummy variables from recipes::all_nominal()
Zero variance filter on recipes::all_predictors(), -date_col_index.num
Centering and scaling for recipes::all_numeric_predictors(), -date_col_index.num

$model_info
$model_info$model_spec
Linear Regression Model Specification (regression)

Main Arguments:
  penalty = tune::tune()
  mixture = tune::tune()

Computational engine: glmnet 

$model_info$wflw
== Workflow ===============================================================================
Preprocessor: Recipe
Model: linear_reg()

-- Preprocessor ---------------------------------------------------------------------------
8 Recipe Steps

* step_timeseries_signature()
* step_holiday_signature()
* step_novel()
* step_mutate_at()
* step_rm()
* step_dummy()
* step_zv()
* step_normalize()

-- Model ----------------------------------------------------------------------------------
Linear Regression Model Specification (regression)

Main Arguments:
  penalty = tune::tune()
  mixture = tune::tune()

Computational engine: glmnet 

$model_info$fitted_wflw
== Workflow [trained] =====================================================================
Preprocessor: Recipe
Model: linear_reg()

-- Preprocessor ---------------------------------------------------------------------------
8 Recipe Steps

* step_timeseries_signature()
* step_holiday_signature()
* step_novel()
* step_mutate_at()
* step_rm()
* step_dummy()
* step_zv()
* step_normalize()

-- Model ----------------------------------------------------------------------------------

Call:  glmnet::glmnet(x = maybe_matrix(x), y = y, family = "gaussian",      alpha = ~0.6) 

   Df  %Dev  Lambda
1   0  0.00 162.800
2   3 12.10 148.300
3   3 22.69 135.200
4   3 31.85 123.200
5   3 39.75 112.200
6   3 46.54 102.200
7   3 52.36  93.160
8   3 57.33  84.890
9   3 61.57  77.350
10  3 65.18  70.480
11  3 68.25  64.210
12  3 70.85  58.510
13  3 73.05  53.310
14  3 74.91  48.580
15  4 76.77  44.260
16  5 79.11  40.330
17  5 81.24  36.750
18  5 83.03  33.480
19  5 84.54  30.510
20  6 85.85  27.800
21  6 87.05  25.330
22  7 88.13  23.080
23  9 89.08  21.030
24  9 89.96  19.160
25 10 90.71  17.460
26 10 91.40  15.910
27 10 91.98  14.490
28 10 92.47  13.210
29 11 92.90  12.030
30 11 93.27  10.960
31 11 93.58   9.990
32 11 93.84   9.102
33 11 94.06   8.294
34 11 94.24   7.557
35 11 94.39   6.886
36 14 94.55   6.274
37 14 94.71   5.716
38 14 94.84   5.209
39 15 94.96   4.746
40 17 95.07   4.324
41 17 95.16   3.940
42 22 95.24   3.590
43 22 95.30   3.271
44 22 95.36   2.981
45 24 95.41   2.716
46 25 95.45   2.475

...
and 30 more lines.

$model_info$was_tuned
[1] "tuned"

$model_calibration
$model_calibration$plot

$model_calibration$calibration_tbl
# Modeltime Table
# A tibble: 1 x 5
  .model_id .model     .model_desc .type .calibration_data
      <int> <list>     <chr>       <chr> <list>           
1         1 <workflow> GLMNET      Test  <tibble [12 x 4]>

$model_calibration$model_accuracy
# A tibble: 1 x 9
  .model_id .model_desc .type   mae  mape  mase smape  rmse   rsq
      <int> <chr>       <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1         1 GLMNET      Test   39.4  7.49 0.815  7.90  52.1 0.914

$tuned_info
$tuned_info$tuning_grid
# A tibble: 10 x 2
      penalty mixture
        <dbl>   <dbl>
 1 0.0162         0.4
 2 0.00886        1  
 3 0.0000113      0.4
 4 0.00000336     1  
 5 0.0162         0.6
 6 0.00000183     0.6
 7 0.0298         0.4
 8 0.0000113      0.6
 9 0.00000616     0.4
10 0.1            0.2

$tuned_info$tscv
# Time Series Cross Validation Plan 
# A tibble: 6 x 2
  splits           id    
  <list>           <chr> 
1 <split [120/12]> Slice1
2 <split [117/12]> Slice2
3 <split [114/12]> Slice3
4 <split [111/12]> Slice4
5 <split [108/12]> Slice5
6 <split [105/12]> Slice6

$tuned_info$tuned_results
# Tuning results
# NA 
# A tibble: 6 x 4
  splits           id     .metrics          .notes          
  <list>           <chr>  <list>            <list>          
1 <split [120/12]> Slice1 <tibble [60 x 6]> <tibble [0 x 3]>
2 <split [117/12]> Slice2 <tibble [60 x 6]> <tibble [0 x 3]>
3 <split [114/12]> Slice3 <tibble [60 x 6]> <tibble [0 x 3]>
4 <split [111/12]> Slice4 <tibble [60 x 6]> <tibble [0 x 3]>
5 <split [108/12]> Slice5 <tibble [60 x 6]> <tibble [0 x 3]>
6 <split [105/12]> Slice6 <tibble [60 x 6]> <tibble [0 x 3]>

$tuned_info$grid_size
[1] 10

$tuned_info$best_metric
[1] "rmse"

$tuned_info$best_result_set
# A tibble: 1 x 8
     penalty mixture .metric .estimator  mean     n std_err .config              
       <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
1 0.00000183     0.6 rmse    standard    38.0     6    2.02 Preprocessor1_Model05

$tuned_info$tuning_grid_plot
`geom_smooth()` using method = 'loess' and formula 'y ~ x'

$tuned_info$plotly_grid_plot

There were 26 warnings (use warnings() to see them)
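
The output is returned invisibly, so individual components are pulled from the list, for example:

tst$model_calibration$plot      # calibration plot for the test split
tst$tuned_info$tuning_grid_plot # static ggplot of the tuning results
tst$tuned_info$best_result_set  # best hyperparameters by rmse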


spsanderson commented 2 years ago

Change the function name to `ts_auto_glmnet`.