Closed — spsanderson closed this issue 2 years ago.
Function:
#' Boilerplate Workflow
#'
#' @family Boiler_Plate
#' @family XGBoost
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @details
#' This uses `parsnip::boost_tree()` with the `engine` set to `xgboost`.
#' When `.tune = TRUE` every tunable argument of `boost_tree()` (`trees`,
#' `min_n`, `tree_depth`, `learn_rate`, `loss_reduction`, `sample_size`) is
#' tuned over a Latin hypercube grid, and the workflow is finalized with the
#' single best parameter set according to `.best_metric`.
#'
#' @description This is a boilerplate function to automatically create the
#' following:
#' - recipe
#' - model specification
#' - workflow
#' - tuned model (grid, etc.)
#'
#' @seealso \url{https://parsnip.tidymodels.org/reference/details_boost_tree_xgboost.html}
#'
#' @param .data The data being passed to the function. It is coerced with
#' `dplyr::as_tibble()` and used to create the splits when `.splits_obj` is
#' NULL.
#' @param .rec_obj This is the recipe object you want to use. You can use
#' `hai_xgboost_data_prepper()` to create an automatic recipe object.
#' @param .splits_obj NULL is the default; when NULL one will be created with
#' [rsample::initial_split()]. Otherwise it must be an object of class `rsplit`.
#' @param .rsamp_obj NULL is the default; when NULL one will be created. It
#' will default to creating an [rsample::mc_cv()] object on the training data.
#' Otherwise it must be an object of class `rset`.
#' @param .tune Default is TRUE, this will create a tuning grid and tuned
#' workflow. Must be a single value coercible to TRUE/FALSE.
#' @param .grid_size Default is 10
#' @param .num_cores Default is 1
#' @param .best_metric Default is "f_meas". You can choose a metric depending on the
#' model_type used. If `regression` then see [healthyR.ai::hai_default_regression_metric_set()],
#' if `classification` then see [healthyR.ai::hai_default_classification_metric_set()].
#' Because the default is a classification metric, supply a different metric
#' when `.model_type = "regression"`.
#' @param .model_type Default is `classification`, can also be `regression`.
#'
#' @examples
#' \dontrun{
#' data <- iris
#'
#' rec_obj <- hai_xgboost_data_prepper(data, Species ~ .)
#'
#' auto_xgb <- hai_auto_xgboost(
#'   .data = data,
#'   .rec_obj = rec_obj,
#'   .best_metric = "f_meas"
#' )
#'
#' auto_xgb$recipe_info
#' }
#'
#' @return
#' An invisible list with:
#' - `recipe_info`: the recipe passed in `.rec_obj`.
#' - `model_info`: a list with `model_spec`, `wflw` (unfitted workflow),
#'   `fitted_wflw` (workflow fitted on the training split) and `was_tuned`
#'   (`"tuned"` or `"not_tuned"`).
#' - `tuned_info` (only when `.tune` is TRUE): the tuning grid, resampling
#'   object, tuning results, grid size, best metric, best result set, and a
#'   ggplot and plotly plot of the tuning results.
#'
#' The list carries the attributes `function_type`, `.grid_size`, `.tune`,
#' `.best_metric`, `.model_type` and `.engine`.
#'
#' @export
#'
hai_auto_xgboost <- function(.data, .rec_obj, .splits_obj = NULL,
                             .rsamp_obj = NULL, .tune = TRUE, .grid_size = 10,
                             .num_cores = 1, .best_metric = "f_meas",
                             .model_type = "classification") {

  # Arguments ----
  grid_size <- as.numeric(.grid_size)
  num_cores <- as.numeric(.num_cores)
  best_metric <- as.character(.best_metric)
  splits <- .splits_obj
  rec_obj <- .rec_obj
  rsamp_obj <- .rsamp_obj
  model_type <- as.character(.model_type)
  # Coerce once so every `if (tune)` below sees a clean scalar logical;
  # as.logical() keeps accepting the values `if (.tune)` used to accept.
  tune <- as.logical(.tune)

  # Checks ----
  # Validate everything before touching the data or any modelling package.
  if (!is.null(splits) && !inherits(x = splits, what = "rsplit")) {
    rlang::abort(
      message = "'.splits_obj' must have a class of 'rsplit', use the rsample package.",
      use_cli_format = TRUE
    )
  }

  if (!inherits(x = rec_obj, what = "recipe")) {
    rlang::abort(
      message = "'.rec_obj' must have a class of 'recipe'.",
      use_cli_format = TRUE
    )
  }

  # The length guard keeps `if` from failing on a zero-length or vector mode.
  if (length(model_type) != 1L ||
      !model_type %in% c("regression", "classification")) {
    rlang::abort(
      message = paste0(
        "You chose a mode of: '",
        toString(model_type),
        "' this is unsupported. Choose from either 'regression' or 'classification'."
      ),
      use_cli_format = TRUE
    )
  }

  if (!is.null(rsamp_obj) && !inherits(x = rsamp_obj, what = "rset")) {
    rlang::abort(
      message = "The '.rsamp_obj' argument must either be NULL or an object of class 'rset'.",
      use_cli_format = TRUE
    )
  }

  if (length(tune) != 1L || is.na(tune)) {
    rlang::abort(
      message = "The '.tune' argument must be a single TRUE or FALSE value.",
      use_cli_format = TRUE
    )
  }

  # Default metric set ----
  ms <- if (model_type == "classification") {
    healthyR.ai::hai_default_classification_metric_set()
  } else {
    healthyR.ai::hai_default_regression_metric_set()
  }

  # Data & splits ----
  data_tbl <- dplyr::as_tibble(.data)
  if (is.null(splits)) {
    splits <- rsample::initial_split(data = data_tbl)
  }

  # Model specification ----
  model_spec <- if (tune) {
    parsnip::boost_tree(
      trees          = tune::tune(),
      min_n          = tune::tune(),
      tree_depth     = tune::tune(),
      learn_rate     = tune::tune(),
      loss_reduction = tune::tune(),
      sample_size    = tune::tune()
    )
  } else {
    parsnip::boost_tree()
  }

  model_spec <- model_spec %>%
    parsnip::set_mode(mode = model_type) %>%
    parsnip::set_engine(engine = "xgboost")

  # Workflow ----
  wflw <- workflows::workflow() %>%
    workflows::add_recipe(rec_obj) %>%
    workflows::add_model(model_spec)

  # Tuning ----
  if (tune) {
    tuning_grid_spec <- dials::grid_latin_hypercube(
      hardhat::extract_parameter_set_dials(model_spec),
      size = grid_size
    )

    # Resampling object: caller-supplied, otherwise Monte Carlo CV on the
    # training split.
    cv_obj <- if (is.null(rsamp_obj)) {
      rsample::mc_cv(data = rsample::training(splits))
    } else {
      rsamp_obj
    }

    # Start the parallel backend; `finally` guarantees it is stopped even when
    # tuning errors, instead of leaving workers registered.
    modeltime::parallel_start(num_cores)
    tuned_results <- tryCatch(
      wflw %>%
        tune::tune_grid(
          resamples = cv_obj,
          grid = tuning_grid_spec,
          metrics = ms
        ),
      finally = modeltime::parallel_stop()
    )

    # Single best parameter set by the requested metric; reused below to
    # finalize the workflow.
    best_result_set <- tuned_results %>%
      tune::show_best(metric = best_metric, n = 1)

    tune_results_plt <- tuned_results %>%
      tune::autoplot() +
      ggplot2::theme_minimal() +
      ggplot2::geom_smooth(se = FALSE) +
      ggplot2::theme(legend.position = "bottom")

    wflw_fit <- wflw %>%
      tune::finalize_workflow(best_result_set) %>%
      parsnip::fit(rsample::training(splits))
  } else {
    wflw_fit <- wflw %>%
      parsnip::fit(rsample::training(splits))
  }

  # Return ----
  output <- list(
    recipe_info = rec_obj,
    model_info = list(
      model_spec  = model_spec,
      wflw        = wflw,
      fitted_wflw = wflw_fit,
      was_tuned   = if (tune) "tuned" else "not_tuned"
    )
  )

  if (tune) {
    output$tuned_info <- list(
      tuning_grid      = tuning_grid_spec,
      cv_obj           = cv_obj,
      tuned_results    = tuned_results,
      grid_size        = grid_size,
      best_metric      = best_metric,
      best_result_set  = best_result_set,
      tuning_grid_plot = tune_results_plt,
      plotly_grid_plot = plotly::ggplotly(tune_results_plt)
    )
  }

  attr(output, "function_type") <- "boilerplate"
  attr(output, ".grid_size") <- .grid_size
  attr(output, ".tune") <- .tune
  attr(output, ".best_metric") <- .best_metric
  attr(output, ".model_type") <- model_type
  attr(output, ".engine") <- "xgboost"

  invisible(output)
}
Example:
> output $recipe_info Recipe Inputs: role #variables outcome 1 predictor 4 Operations: Factor variables from tidyselect::vars_select_helpers$where(is.character) Novel factor level assignment for recipes::all_nominal_predictors() Dummy variables from recipes::all_nominal_predictors() Zero variance filter on recipes::all_predictors() $model_info $model_info$model_spec Boosted Tree Model Specification (classification) Main Arguments: trees = tune::tune() min_n = tune::tune() tree_depth = tune::tune() learn_rate = tune::tune() loss_reduction = tune::tune() sample_size = tune::tune() Computational engine: xgboost $model_info$wflw == Workflow =============================================================================== Preprocessor: Recipe Model: boost_tree() -- Preprocessor --------------------------------------------------------------------------- 4 Recipe Steps * step_string2factor() * step_novel() * step_dummy() * step_zv() -- Model ---------------------------------------------------------------------------------- Boosted Tree Model Specification (classification) Main Arguments: trees = tune::tune() min_n = tune::tune() tree_depth = tune::tune() learn_rate = tune::tune() loss_reduction = tune::tune() sample_size = tune::tune() Computational engine: xgboost $model_info$fitted_wflw == Workflow [trained] ===================================================================== Preprocessor: Recipe Model: boost_tree() -- Preprocessor --------------------------------------------------------------------------- 4 Recipe Steps * step_string2factor() * step_novel() * step_dummy() * step_zv() -- Model ---------------------------------------------------------------------------------- ##### xgb.Booster raw: 2 Mb call: xgboost::xgb.train(params = list(eta = 0.00393124784309238, max_depth = 13L, gamma = 1.26613999938156e-10, colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 5L, subsample = 0.567503290886525, objective = "multi:softprob"), data = x$data, nrounds = 893L, 
watchlist = x$watchlist, verbose = 0, num_class = 3L, nthread = 1) params (as set within xgb.train): eta = "0.00393124784309238", max_depth = "13", gamma = "1.26613999938156e-10", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "5", subsample = "0.567503290886525", objective = "multi:softprob", num_class = "3", nthread = "1", validate_parameters = "TRUE" xgb.attributes: niter callbacks: cb.evaluation.log() # of features: 4 niter: 893 nfeatures : 4 evaluation_log: iter training_mlogloss 1 1.0938820 2 1.0891530 --- 892 0.1821332 893 0.1821265 $model_info$was_tuned [1] "tuned" $tuned_info $tuned_info$tuning_grid # A tibble: 10 x 6 trees min_n tree_depth learn_rate loss_reduction sample_size <int> <int> <int> <dbl> <dbl> <dbl> 1 1565 22 12 0.104 5.23e- 7 0.976 2 893 5 13 0.00393 1.27e-10 0.568 3 310 30 7 0.00964 1.48e- 8 0.907 4 166 27 4 0.0107 4.36e- 8 0.399 5 547 12 15 0.0351 2.21e+ 1 0.510 6 1146 39 5 0.0620 2.69e- 1 0.321 7 1932 14 10 0.287 7.36e- 6 0.239 8 1800 33 9 0.00206 8.36e- 5 0.643 9 1270 6 2 0.0217 2.28e- 2 0.775 10 688 20 6 0.00143 1.31e- 3 0.158 $tuned_info$cv_obj # Monte Carlo cross-validation (0.75/0.25) with 25 resamples # A tibble: 25 x 2 splits id <list> <chr> 1 <split [84/28]> Resample01 2 <split [84/28]> Resample02 3 <split [84/28]> Resample03 4 <split [84/28]> Resample04 5 <split [84/28]> Resample05 6 <split [84/28]> Resample06 7 <split [84/28]> Resample07 8 <split [84/28]> Resample08 9 <split [84/28]> Resample09 10 <split [84/28]> Resample10 # ... 
with 15 more rows $tuned_info$tuned_results # Tuning results # Monte Carlo cross-validation (0.75/0.25) with 25 resamples # A tibble: 25 x 4 splits id .metrics .notes <list> <chr> <list> <list> 1 <split [84/28]> Resample01 <tibble [110 x 10]> <tibble [1 x 3]> 2 <split [84/28]> Resample02 <tibble [110 x 10]> <tibble [1 x 3]> 3 <split [84/28]> Resample03 <tibble [110 x 10]> <tibble [1 x 3]> 4 <split [84/28]> Resample04 <tibble [110 x 10]> <tibble [1 x 3]> 5 <split [84/28]> Resample05 <tibble [110 x 10]> <tibble [1 x 3]> 6 <split [84/28]> Resample06 <tibble [110 x 10]> <tibble [1 x 3]> 7 <split [84/28]> Resample07 <tibble [110 x 10]> <tibble [1 x 3]> 8 <split [84/28]> Resample08 <tibble [110 x 10]> <tibble [1 x 3]> 9 <split [84/28]> Resample09 <tibble [110 x 10]> <tibble [1 x 3]> 10 <split [84/28]> Resample10 <tibble [110 x 10]> <tibble [1 x 3]> # ... with 15 more rows There were issues with some computations: - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x3: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... 
- Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x2: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve... Use `collect_notes(object)` for more information. $tuned_info$grid_size [1] 10 $tuned_info$best_metric [1] "f_meas" $tuned_info$best_result_set # A tibble: 1 x 12 trees min_n tree_depth learn_rate loss_reduction sample_size .metric .estimator mean <int> <int> <int> <dbl> <dbl> <dbl> <chr> <chr> <dbl> 1 893 5 13 0.00393 1.27e-10 0.568 f_meas macro 0.933 # ... with 3 more variables: n <int>, std_err <dbl>, .config <chr> $tuned_info$tuning_grid_plot `geom_smooth()` using method = 'loess' and formula 'y ~ x' $tuned_info$plotly_grid_plot attr(,"function_type") [1] "boilerplate" attr(,".grid_size") [1] 10 attr(,".tune") [1] TRUE attr(,".best_metric") [1] "f_meas" attr(,".model_type") [1] "classification" attr(,".engine") [1] "xgboost"
Function:
Example: