Closed: spsanderson closed this issue 2 years ago.
https://usemodels.tidymodels.org/reference/templates.html
Start
library(healthyverse)
library(dplyr)
library(recipes)
library(timetk)
library(rsample)
library(dials)

# Turn a call captured with `match.call()` into a `recipe()` call for display.
#
# When the captured call uses the dotted argument names of `ts_auto_glmnet()`
# (`.formula`, `.data`), only those two arguments are kept and renamed to
# `formula` / `data`, so the result is a valid `recipe(formula, data)` call.
# Modeling arguments such as `.rsamp_obj` or `.num_cores` are dropped.
# Otherwise the usemodels-style cleanup is applied: `tune`, `verbose`,
# `colors` and `prefix` are removed and the called function is replaced by
# `recipe`.
get_recipe_call <- function(.rec_call) {
  call_args <- as.list(.rec_call)[-1]
  formula_arg <- call_args[[".formula"]]
  data_arg <- call_args[[".data"]]

  if (!is.null(formula_arg) || !is.null(data_arg)) {
    rec_args <- Filter(
      Negate(is.null),
      list(formula = formula_arg, data = data_arg)
    )
    return(as.call(c(list(as.name("recipe")), rec_args)))
  }

  cl <- .rec_call
  cl$tune <- NULL
  cl$verbose <- NULL
  cl$colors <- NULL
  cl$prefix <- NULL
  cl[[1]] <- rlang::expr(recipe)
  cl
}

# Render `name <- <expression>` as character text.
# `value` is captured unevaluated (so `!!` injection works) and deparsed
# with `rlang::expr_text()` at a width of 74 characters.
assign_value <- function(name, value, cr = TRUE) {
  value <- rlang::enexpr(value)
  value <- rlang::expr_text(value, width = 74L)
  chr_assign(name, value, cr)
}

# Combine `name` and an already-deparsed `value` into assignment text.
# With `cr = TRUE` the result is a length-2 character vector:
# `"name <-"` and the value prefixed by "\n ". Otherwise it is a single
# string "name <- value".
chr_assign <- function(name, value, cr = TRUE) {
  name <- paste(name, "<-")
  if (cr) {
    res <- c(name, paste0("\n ", value))
  } else {
    res <- paste(name, value)
  }
  res
}

# Build a glmnet linear regression for a time series, optionally tune it,
# then calibrate it.
#
# Arguments:
#   .data             Data frame holding the series; coerced to a tibble.
#   .date_col         Date column (unquoted). Required.
#   .value_col        Value column (unquoted). Required.
#   .formula          Formula passed to `recipes::recipe()`, e.g. `value ~ .`.
#   .rsamp_obj        An `rsplit` object (e.g. from
#                     `timetk::time_series_split()`). Its training set is
#                     used for cross validation and for the final fit.
#   .prefix           Prefix for the recipe variable name in `recipe_syntax`.
#   .tune             If TRUE, grid-search `penalty` and `mixture` over a
#                     time series cross validation plan.
#   .grid_size        Number of rows randomly sampled from the 120-row
#                     (20 penalties x 6 mixtures) candidate grid.
#   .num_cores        Cores passed to `modeltime::parallel_start()` while
#                     tuning.
#   .cv_assess, .cv_skip, .cv_slice_limit
#                     `assess`, `skip` and `slice_limit` for
#                     `timetk::time_series_cv()`.
#   .best_metric      Metric used to pick the best tuning result.
#   .bootstrap_final  Currently unused; kept for interface compatibility.
#
# Returns (invisibly) a list with `recipe_info`, `model_info` and
# `model_calibration`, plus `tuned_info` when `.tune = TRUE`.
ts_auto_glmnet <- function(.data, .date_col, .value_col, .formula, .rsamp_obj,
                           .prefix = "ts_glmnet", .tune = TRUE,
                           .grid_size = 10, .num_cores = 1,
                           .cv_assess = 12, .cv_skip = 3,
                           .cv_slice_limit = 6, .best_metric = "rmse",
                           .bootstrap_final = FALSE) {

  # Tidyeval ----
  date_col_var_expr <- rlang::enquo(.date_col)
  value_col_var_expr <- rlang::enquo(.value_col)

  # Cross Validation
  cv_assess <- as.numeric(.cv_assess)
  cv_skip <- as.numeric(.cv_skip)
  cv_slice <- as.numeric(.cv_slice_limit)

  # Tuning Grid
  grid_size <- as.numeric(.grid_size)
  num_cores <- as.numeric(.num_cores)
  best_metric <- as.character(.best_metric)

  # Data and splits
  splits <- .rsamp_obj
  data_tbl <- dplyr::as_tibble(.data)

  # Checks ----
  if (rlang::quo_is_missing(date_col_var_expr)) {
    rlang::abort(
      message = "'.date_col' must be supplied.",
      use_cli_format = TRUE
    )
  }

  if (rlang::quo_is_missing(value_col_var_expr)) {
    rlang::abort(
      message = "'.value_col' must be supplied.",
      use_cli_format = TRUE
    )
  }

  if (!inherits(x = splits, what = "rsplit")) {
    rlang::abort(
      message = "'.rsamp_obj' must have class rsplit, use the rsample package.",
      use_cli_format = TRUE
    )
  }

  # step_timeseries_signature() adds a numeric time index named
  # "<date col>_index.num". It is kept out of the zero-variance filter and
  # the normalization. The name is derived from the actual date column so
  # any column name works, not only one literally called `date_col`.
  date_index_sym <- rlang::sym(
    paste0(rlang::as_name(date_col_var_expr), "_index.num")
  )

  # Recipe ----
  # Record the equivalent recipe() call and its text form for the output.
  recipe_call <- get_recipe_call(match.call())
  rec_syntax <- assign_value(paste0(.prefix, "_recipe"), !!recipe_call)

  rec_obj <- recipes::recipe(formula = .formula, data = data_tbl) %>%
    timetk::step_timeseries_signature({{date_col_var_expr}}) %>%
    timetk::step_holiday_signature({{date_col_var_expr}}) %>%
    recipes::step_novel(recipes::all_nominal_predictors()) %>%
    recipes::step_mutate_at(
      tidyselect::vars_select_helpers$where(is.character),
      fn = ~ as.factor(.)
    ) %>%
    recipes::step_rm({{date_col_var_expr}}) %>%
    recipes::step_dummy(recipes::all_nominal(), one_hot = TRUE) %>%
    recipes::step_zv(recipes::all_predictors(), -!!date_index_sym) %>%
    recipes::step_normalize(
      recipes::all_numeric_predictors(),
      -!!date_index_sym
    )

  # Tune/Spec ----
  if (.tune) {
    model_spec <- parsnip::linear_reg(
      penalty = tune::tune(),
      mixture = tune::tune(),
      mode = "regression",
      engine = "glmnet"
    )
  } else {
    model_spec <- parsnip::linear_reg(
      mode = "regression",
      engine = "glmnet"
    )
  }

  # Workflow ----
  wflw <- workflows::workflow() %>%
    workflows::add_recipe(rec_obj) %>%
    workflows::add_model(model_spec)

  # Tuning Grid ----
  if (.tune) {
    # Start the parallel backend. The on.exit() hook stops it if anything
    # below errors before the explicit stop is reached.
    modeltime::parallel_start(num_cores)
    parallel_running <- TRUE
    on.exit(if (parallel_running) modeltime::parallel_stop(), add = TRUE)

    # Random subset of a fixed penalty x mixture grid.
    tuning_grid_spec <- tidyr::crossing(
      penalty = 10^seq(-6, -1, length.out = 20),
      mixture = c(0.05, 0.2, 0.4, 0.6, 0.8, 1)
    ) %>%
      dplyr::slice_sample(n = grid_size)

    # Make TS CV ----
    tscv <- timetk::time_series_cv(
      data = rsample::training(splits),
      date_var = {{date_col_var_expr}},
      cumulative = TRUE,
      assess = cv_assess,
      skip = cv_skip,
      slice_limit = cv_slice
    )

    # Tune the workflow
    tuned_results <- wflw %>%
      tune::tune_grid(
        resamples = tscv,
        grid = tuning_grid_spec,
        metrics = modeltime::default_forecast_accuracy_metric_set()
      )

    # Best result set by the requested metric
    best_result_set <- tuned_results %>%
      tune::show_best(metric = best_metric, n = 1)

    # Plot results
    tune_results_plt <- tuned_results %>%
      tune::autoplot() +
      ggplot2::theme_minimal() +
      ggplot2::geom_smooth(se = FALSE)

    # Finalize with the best parameters and fit on the training set
    wflw_fit <- wflw %>%
      tune::finalize_workflow(
        tune::select_best(tuned_results, metric = best_metric)
      ) %>%
      parsnip::fit(rsample::training(splits))

    # Stop parallel backend
    modeltime::parallel_stop()
    parallel_running <- FALSE
  } else {
    wflw_fit <- wflw %>%
      parsnip::fit(rsample::training(splits))
  }

  # Calibrate and Plot ----
  cap <- healthyR.ts::calibrate_and_plot(
    wflw_fit,
    .splits_obj = splits,
    .data = data_tbl,
    .interactive = TRUE,
    .print_info = FALSE
  )

  # Return ----
  output <- list(
    recipe_info = list(
      recipe_call = recipe_call,
      recipe_syntax = rec_syntax,
      rec_obj = rec_obj
    ),
    model_info = list(
      model_spec = model_spec,
      wflw = wflw,
      fitted_wflw = wflw_fit,
      was_tuned = if (.tune) "tuned" else "not_tuned"
    ),
    model_calibration = list(
      plot = cap$plot,
      calibration_tbl = cap$calibration_tbl,
      model_accuracy = cap$model_accuracy
    )
  )

  if (.tune) {
    output$tuned_info <- list(
      tuning_grid = tuning_grid_spec,
      tscv = tscv,
      tuned_results = tuned_results,
      grid_size = grid_size,
      best_metric = best_metric,
      best_result_set = best_result_set,
      tuning_grid_plot = tune_results_plt,
      plotly_grid_plot = plotly::ggplotly(tune_results_plt)
    )
  }

  invisible(output)
}
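Example

The `ts_auto_glmnet()` call that produced `tst` is not shown in the transcript below. Based on the recorded `recipe_call`, after building `data` and `splits` as shown, `tst` was created with:

`tst <- ts_auto_glmnet(.data = data, .date_col = date_col, .value_col = value, .formula = value ~ ., .rsamp_obj = splits, .num_cores = 2)`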
data <- AirPassengers %>% ts_to_tbl() %>% select(-index) splits <- time_series_split( data , date_col , assess = 12 , skip = 3 , cumulative = TRUE ) tst $recipe_info $recipe_info$recipe_call recipe(.data = data, .date_col = date_col, .value_col = value, .formula = value ~ ., .rsamp_obj = splits, .num_cores = 2) $recipe_info$recipe_syntax [1] "ts_glmnet_recipe <-" [2] "\n recipe(.data = data, .date_col = date_col, .value_col = value, .formula = value ~ \n ., .rsamp_obj = splits, .num_cores = 2)" $recipe_info$rec_obj Recipe Inputs: role #variables outcome 1 predictor 1 Operations: Timeseries signature features from date_col Holiday signature features from date_col Novel factor level assignment for recipes::all_nominal_predictors() Variable mutation for tidyselect::vars_select_helpers$where(is.character) Variables removed date_col Dummy variables from recipes::all_nominal() Zero variance filter on recipes::all_predictors(), -date_col_index.num Centering and scaling for recipes::all_numeric_predictors(), -date_col_index.num $model_info $model_info$model_spec Linear Regression Model Specification (regression) Main Arguments: penalty = tune::tune() mixture = tune::tune() Computational engine: glmnet $model_info$wflw == Workflow =============================================================================== Preprocessor: Recipe Model: linear_reg() -- Preprocessor --------------------------------------------------------------------------- 8 Recipe Steps * step_timeseries_signature() * step_holiday_signature() * step_novel() * step_mutate_at() * step_rm() * step_dummy() * step_zv() * step_normalize() -- Model ---------------------------------------------------------------------------------- Linear Regression Model Specification (regression) Main Arguments: penalty = tune::tune() mixture = tune::tune() Computational engine: glmnet $model_info$fitted_wflw == Workflow [trained] ===================================================================== Preprocessor: Recipe Model: 
linear_reg() -- Preprocessor --------------------------------------------------------------------------- 8 Recipe Steps * step_timeseries_signature() * step_holiday_signature() * step_novel() * step_mutate_at() * step_rm() * step_dummy() * step_zv() * step_normalize() -- Model ---------------------------------------------------------------------------------- Call: glmnet::glmnet(x = maybe_matrix(x), y = y, family = "gaussian", alpha = ~0.6) Df %Dev Lambda 1 0 0.00 162.800 2 3 12.10 148.300 3 3 22.69 135.200 4 3 31.85 123.200 5 3 39.75 112.200 6 3 46.54 102.200 7 3 52.36 93.160 8 3 57.33 84.890 9 3 61.57 77.350 10 3 65.18 70.480 11 3 68.25 64.210 12 3 70.85 58.510 13 3 73.05 53.310 14 3 74.91 48.580 15 4 76.77 44.260 16 5 79.11 40.330 17 5 81.24 36.750 18 5 83.03 33.480 19 5 84.54 30.510 20 6 85.85 27.800 21 6 87.05 25.330 22 7 88.13 23.080 23 9 89.08 21.030 24 9 89.96 19.160 25 10 90.71 17.460 26 10 91.40 15.910 27 10 91.98 14.490 28 10 92.47 13.210 29 11 92.90 12.030 30 11 93.27 10.960 31 11 93.58 9.990 32 11 93.84 9.102 33 11 94.06 8.294 34 11 94.24 7.557 35 11 94.39 6.886 36 14 94.55 6.274 37 14 94.71 5.716 38 14 94.84 5.209 39 15 94.96 4.746 40 17 95.07 4.324 41 17 95.16 3.940 42 22 95.24 3.590 43 22 95.30 3.271 44 22 95.36 2.981 45 24 95.41 2.716 46 25 95.45 2.475 ... and 30 more lines. 
$model_info$was_tuned [1] "tuned" $model_calibration $model_calibration$plot $model_calibration$calibration_tbl # Modeltime Table # A tibble: 1 x 5 .model_id .model .model_desc .type .calibration_data <int> <list> <chr> <chr> <list> 1 1 <workflow> GLMNET Test <tibble [12 x 4]> $model_calibration$model_accuracy # A tibble: 1 x 9 .model_id .model_desc .type mae mape mase smape rmse rsq <int> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> 1 1 GLMNET Test 39.4 7.49 0.815 7.90 52.1 0.914 $tuned_info $tuned_info$tuning_grid # A tibble: 10 x 2 penalty mixture <dbl> <dbl> 1 0.0162 0.4 2 0.00886 1 3 0.0000113 0.4 4 0.00000336 1 5 0.0162 0.6 6 0.00000183 0.6 7 0.0298 0.4 8 0.0000113 0.6 9 0.00000616 0.4 10 0.1 0.2 $tuned_info$tscv # Time Series Cross Validation Plan # A tibble: 6 x 2 splits id <list> <chr> 1 <split [120/12]> Slice1 2 <split [117/12]> Slice2 3 <split [114/12]> Slice3 4 <split [111/12]> Slice4 5 <split [108/12]> Slice5 6 <split [105/12]> Slice6 $tuned_info$tuned_results # Tuning results # NA # A tibble: 6 x 4 splits id .metrics .notes <list> <chr> <list> <list> 1 <split [120/12]> Slice1 <tibble [60 x 6]> <tibble [0 x 3]> 2 <split [117/12]> Slice2 <tibble [60 x 6]> <tibble [0 x 3]> 3 <split [114/12]> Slice3 <tibble [60 x 6]> <tibble [0 x 3]> 4 <split [111/12]> Slice4 <tibble [60 x 6]> <tibble [0 x 3]> 5 <split [108/12]> Slice5 <tibble [60 x 6]> <tibble [0 x 3]> 6 <split [105/12]> Slice6 <tibble [60 x 6]> <tibble [0 x 3]> $tuned_info$grid_size [1] 10 $tuned_info$best_metric [1] "rmse" $tuned_info$best_result_set # A tibble: 1 x 8 penalty mixture .metric .estimator mean n std_err .config <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr> 1 0.00000183 0.6 rmse standard 38.0 6 2.02 Preprocessor1_Model05 $tuned_info$tuning_grid_plot `geom_smooth()` using method = 'loess' and formula 'y ~ x' $tuned_info$plotly_grid_plot There were 26 warnings (use warnings() to see them)
Change the function name to `ts_auto_glmnet`.
https://usemodels.tidymodels.org/reference/templates.html
Start
Example