spsanderson / healthyR.ai

healthyR.ai - AI package for the healthyverse
http://www.spsanderson.com/healthyR.ai/
Other
16 stars 6 forks source link

hai_auto_xgboost() #276

Closed spsanderson closed 2 years ago

spsanderson commented 2 years ago

Function:

#' Boilerplate Workflow
#'
#' @family Boiler_Plate
#' @family XGBoost
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @details
#' This uses `parsnip::boost_tree()` with the `engine` set to `xgboost`.
#'
#' When `.tune = TRUE` the `trees`, `min_n`, `tree_depth`, `learn_rate`,
#' `loss_reduction` and `sample_size` arguments are marked for tuning, a Latin
#' hypercube grid of `.grid_size` candidates is built, the workflow is tuned
#' over the resamples, and the workflow is finalized with the best candidate
#' according to `.best_metric`. In both the tuned and untuned case the final
#' workflow is fit on `rsample::training()` of the split object.
#'
#' @description This is a boilerplate function to automatically create the
#' following:
#' -  recipe
#' -  model specification
#' -  workflow
#' -  tuned model (grid etc.)
#'
#' @seealso \url{https://parsnip.tidymodels.org/reference/details_boost_tree_xgboost.html}
#'
#' @param .data The data being passed to the function. It is coerced to a
#' tibble and is only used to create the initial split when `.splits_obj` is
#' NULL.
#' @param .rec_obj This is the recipe object you want to use. You can use
#' `hai_xgboost_data_prepper()` to create an automatic recipe object.
#' @param .splits_obj NULL is the default, when NULL then one will be created
#' with [rsample::initial_split()]. Otherwise an object of class `rsplit`.
#' @param .rsamp_obj NULL is the default, when NULL then one will be created. It
#' will default to creating an [rsample::mc_cv()] object from the training
#' portion of the split. Otherwise an object of class `rset`.
#' @param .tune Default is TRUE, this will create a tuning grid and tuned workflow
#' @param .grid_size Default is 10. The number of candidates in the tuning grid.
#' @param .num_cores Default is 1. Passed to `modeltime::parallel_start()` while
#' the grid is being tuned.
#' @param .best_metric Default is "f_meas". You can choose a metric depending on the
#' model_type used. If `regression` then see [healthyR.ai::hai_default_regression_metric_set()],
#' if `classification` then see [healthyR.ai::hai_default_classification_metric_set()].
#' Note that the default of "f_meas" is a classification metric, so a
#' different value must be supplied when `.model_type = "regression"` and
#' `.tune = TRUE`.
#' @param .model_type Default is `classification`, can also be `regression`.
#'
#' @examples
#' \dontrun{
#' data <- iris
#'
#' rec_obj <- hai_xgboost_data_prepper(data, Species ~ .)
#'
#' auto_xgb <- hai_auto_xgboost(
#'   .data = data,
#'   .rec_obj = rec_obj,
#'   .best_metric = "f_meas"
#' )
#'
#' auto_xgb$recipe_info
#' }
#'
#' @return
#' A list, returned invisibly, with the elements:
#' -  `recipe_info`: the recipe that was supplied.
#' -  `model_info`: a list holding `model_spec`, `wflw`, `fitted_wflw` and
#'    `was_tuned` (either "tuned" or "not_tuned").
#' -  `tuned_info`: only present when `.tune = TRUE`; a list holding
#'    `tuning_grid`, `cv_obj`, `tuned_results`, `grid_size`, `best_metric`,
#'    `best_result_set`, `tuning_grid_plot` and `plotly_grid_plot`.
#'
#' The list also carries the attributes `function_type`, `.grid_size`, `.tune`,
#' `.best_metric`, `.model_type` and `.engine`.
#'
#' @export
#'
hai_auto_xgboost <- function(.data, .rec_obj, .splits_obj = NULL, .rsamp_obj = NULL,
                         .tune = TRUE, .grid_size = 10, .num_cores = 1,
                         .best_metric = "f_meas", .model_type = "classification"){

  # Tidyeval ----
  grid_size <- as.numeric(.grid_size)
  num_cores <- as.numeric(.num_cores)
  best_metric <- as.character(.best_metric)

  data_tbl <- dplyr::as_tibble(.data)

  splits <- .splits_obj
  rec_obj <- .rec_obj
  rsamp_obj <- .rsamp_obj
  model_type <- as.character(.model_type)

  # Checks ----
  if (!is.null(splits) && !inherits(x = splits, what = "rsplit")){
    rlang::abort(
      message = "'.splits_obj' must have a class of 'rsplit', use the rsample package.",
      use_cli_format = TRUE
    )
  }

  if (!inherits(x = rec_obj, what = "recipe")){
    rlang::abort(
      message = "'.rec_obj' must have a class of 'recipe'."
    )
  }

  if (!model_type %in% c("regression","classification")){
    rlang::abort(
      message = paste0(
        "You chose a mode of: '",
        model_type,
        "' this is unsupported. Choose from either 'regression' or 'classification'."
      ),
      use_cli_format = TRUE
    )
  }

  if (!is.null(rsamp_obj) && !inherits(x = rsamp_obj, what = "rset")){
    rlang::abort(
      message = paste0(
        "The '.rsamp_obj' argument must either be NULL or an object of ",
        "class 'rset'."
      ),
      use_cli_format = TRUE
    )
  }

  # Set default metric set ----
  if (model_type == "classification"){
    ms <- healthyR.ai::hai_default_classification_metric_set()
  } else {
    ms <- healthyR.ai::hai_default_regression_metric_set()
  }

  # Fail fast on a metric that is not in the metric set. Without this the
  # mismatch (e.g. the default "f_meas" with a regression model) only surfaces
  # in tune::show_best() after the whole grid has been tuned.
  # NOTE(review): assumes the metric set stores its metrics as a named list in
  # the "metrics" attribute; the check is skipped when that attribute is absent.
  if (.tune){
    available_metrics <- names(attr(ms, "metrics"))
    if (!is.null(available_metrics) && !all(best_metric %in% available_metrics)){
      rlang::abort(
        message = paste0(
          "The '.best_metric' of '",
          paste(best_metric, collapse = ", "),
          "' is not part of the default ",
          model_type,
          " metric set. Choose one of: ",
          paste(available_metrics, collapse = ", "),
          "."
        ),
        use_cli_format = TRUE
      )
    }
  }

  # Get splits if not then create
  if (is.null(splits)){
    splits <- rsample::initial_split(data = data_tbl)
  }

  # Tune/Spec ----
  if (.tune){
    # Model Specification with every main argument marked for tuning
    model_spec <- parsnip::boost_tree(
      trees = tune::tune(),
      min_n = tune::tune(),
      tree_depth = tune::tune(),
      learn_rate = tune::tune(),
      loss_reduction = tune::tune(),
      sample_size = tune::tune()
    )
  } else {
    model_spec <- parsnip::boost_tree()
  }

  # Model Specification ----
  model_spec <- model_spec %>%
    parsnip::set_mode(mode = model_type) %>%
    parsnip::set_engine(engine = "xgboost")

  # Workflow ----
  wflw <- workflows::workflow() %>%
    workflows::add_recipe(rec_obj) %>%
    workflows::add_model(model_spec)

  # Tuning Grid ---
  if (.tune){

    # Make tuning grid
    tuning_grid_spec <- dials::grid_latin_hypercube(
      hardhat::extract_parameter_set_dials(model_spec),
      size = grid_size
    )

    # Cross validation object
    if (is.null(rsamp_obj)){
      cv_obj <- rsample::mc_cv(
        data = rsample::training(splits)
      )
    } else {
      cv_obj <- rsamp_obj
    }

    # Tune the workflow
    # Start the parallel backend; `finally` guarantees it is stopped again
    # even when tuning errors out, so no workers are left behind.
    modeltime::parallel_start(num_cores)

    tuned_results <- tryCatch(
      wflw %>%
        tune::tune_grid(
          resamples = cv_obj,
          grid      = tuning_grid_spec,
          metrics   = ms
        ),
      finally = modeltime::parallel_stop()
    )

    # Get the best result set by a specified metric
    best_result_set <- tuned_results %>%
      tune::show_best(metric = best_metric, n = 1)

    # Plot results
    tune_results_plt <- tuned_results %>%
      tune::autoplot() +
      ggplot2::theme_minimal() +
      ggplot2::geom_smooth(se = FALSE) +
      ggplot2::theme(legend.position = "bottom")

    # Make final workflow, reusing the best result set computed above
    wflw_fit <- wflw %>%
      tune::finalize_workflow(best_result_set) %>%
      parsnip::fit(rsample::training(splits))

  } else {
    wflw_fit <- wflw %>%
      parsnip::fit(rsample::training(splits))
  }

  # Return ----
  output <- list(
    recipe_info = rec_obj,
    model_info = list(
      model_spec  = model_spec,
      wflw        = wflw,
      fitted_wflw = wflw_fit,
      was_tuned   = if (.tune) "tuned" else "not_tuned"
    )
  )

  if (.tune){
    output$tuned_info <- list(
      tuning_grid      = tuning_grid_spec,
      cv_obj           = cv_obj,
      tuned_results    = tuned_results,
      grid_size        = grid_size,
      best_metric      = best_metric,
      best_result_set  = best_result_set,
      tuning_grid_plot = tune_results_plt,
      plotly_grid_plot = plotly::ggplotly(tune_results_plt)
    )
  }

  attr(output, "function_type") <- "boilerplate"
  attr(output, ".grid_size") <- .grid_size
  attr(output, ".tune") <- .tune
  attr(output, ".best_metric") <- .best_metric
  attr(output, ".model_type") <- model_type
  attr(output, ".engine") <- "xgboost"

  invisible(output)

}

Example:

> output
$recipe_info
Recipe

Inputs:

      role #variables
   outcome          1
 predictor          4

Operations:

Factor variables from tidyselect::vars_select_helpers$where(is.character)
Novel factor level assignment for recipes::all_nominal_predictors()
Dummy variables from recipes::all_nominal_predictors()
Zero variance filter on recipes::all_predictors()

$model_info
$model_info$model_spec
Boosted Tree Model Specification (classification)

Main Arguments:
  trees = tune::tune()
  min_n = tune::tune()
  tree_depth = tune::tune()
  learn_rate = tune::tune()
  loss_reduction = tune::tune()
  sample_size = tune::tune()

Computational engine: xgboost 

$model_info$wflw
== Workflow ===============================================================================
Preprocessor: Recipe
Model: boost_tree()

-- Preprocessor ---------------------------------------------------------------------------
4 Recipe Steps

* step_string2factor()
* step_novel()
* step_dummy()
* step_zv()

-- Model ----------------------------------------------------------------------------------
Boosted Tree Model Specification (classification)

Main Arguments:
  trees = tune::tune()
  min_n = tune::tune()
  tree_depth = tune::tune()
  learn_rate = tune::tune()
  loss_reduction = tune::tune()
  sample_size = tune::tune()

Computational engine: xgboost 

$model_info$fitted_wflw
== Workflow [trained] =====================================================================
Preprocessor: Recipe
Model: boost_tree()

-- Preprocessor ---------------------------------------------------------------------------
4 Recipe Steps

* step_string2factor()
* step_novel()
* step_dummy()
* step_zv()

-- Model ----------------------------------------------------------------------------------
##### xgb.Booster
raw: 2 Mb 
call:
  xgboost::xgb.train(params = list(eta = 0.00393124784309238, max_depth = 13L, 
    gamma = 1.26613999938156e-10, colsample_bytree = 1, colsample_bynode = 1, 
    min_child_weight = 5L, subsample = 0.567503290886525, objective = "multi:softprob"), 
    data = x$data, nrounds = 893L, watchlist = x$watchlist, verbose = 0, 
    num_class = 3L, nthread = 1)
params (as set within xgb.train):
  eta = "0.00393124784309238", max_depth = "13", gamma = "1.26613999938156e-10", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "5", subsample = "0.567503290886525", objective = "multi:softprob", num_class = "3", nthread = "1", validate_parameters = "TRUE"
xgb.attributes:
  niter
callbacks:
  cb.evaluation.log()
# of features: 4 
niter: 893
nfeatures : 4 
evaluation_log:
    iter training_mlogloss
       1         1.0938820
       2         1.0891530
---                       
     892         0.1821332
     893         0.1821265

$model_info$was_tuned
[1] "tuned"

$tuned_info
$tuned_info$tuning_grid
# A tibble: 10 x 6
   trees min_n tree_depth learn_rate loss_reduction sample_size
   <int> <int>      <int>      <dbl>          <dbl>       <dbl>
 1  1565    22         12    0.104         5.23e- 7       0.976
 2   893     5         13    0.00393       1.27e-10       0.568
 3   310    30          7    0.00964       1.48e- 8       0.907
 4   166    27          4    0.0107        4.36e- 8       0.399
 5   547    12         15    0.0351        2.21e+ 1       0.510
 6  1146    39          5    0.0620        2.69e- 1       0.321
 7  1932    14         10    0.287         7.36e- 6       0.239
 8  1800    33          9    0.00206       8.36e- 5       0.643
 9  1270     6          2    0.0217        2.28e- 2       0.775
10   688    20          6    0.00143       1.31e- 3       0.158

$tuned_info$cv_obj
# Monte Carlo cross-validation (0.75/0.25) with 25 resamples  
# A tibble: 25 x 2
   splits          id        
   <list>          <chr>     
 1 <split [84/28]> Resample01
 2 <split [84/28]> Resample02
 3 <split [84/28]> Resample03
 4 <split [84/28]> Resample04
 5 <split [84/28]> Resample05
 6 <split [84/28]> Resample06
 7 <split [84/28]> Resample07
 8 <split [84/28]> Resample08
 9 <split [84/28]> Resample09
10 <split [84/28]> Resample10
# ... with 15 more rows

$tuned_info$tuned_results
# Tuning results
# Monte Carlo cross-validation (0.75/0.25) with 25 resamples  
# A tibble: 25 x 4
   splits          id         .metrics            .notes          
   <list>          <chr>      <list>              <list>          
 1 <split [84/28]> Resample01 <tibble [110 x 10]> <tibble [1 x 3]>
 2 <split [84/28]> Resample02 <tibble [110 x 10]> <tibble [1 x 3]>
 3 <split [84/28]> Resample03 <tibble [110 x 10]> <tibble [1 x 3]>
 4 <split [84/28]> Resample04 <tibble [110 x 10]> <tibble [1 x 3]>
 5 <split [84/28]> Resample05 <tibble [110 x 10]> <tibble [1 x 3]>
 6 <split [84/28]> Resample06 <tibble [110 x 10]> <tibble [1 x 3]>
 7 <split [84/28]> Resample07 <tibble [110 x 10]> <tibble [1 x 3]>
 8 <split [84/28]> Resample08 <tibble [110 x 10]> <tibble [1 x 3]>
 9 <split [84/28]> Resample09 <tibble [110 x 10]> <tibble [1 x 3]>
10 <split [84/28]> Resample10 <tibble [110 x 10]> <tibble [1 x 3]>
# ... with 15 more rows

There were issues with some computations:

  - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x3: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x2: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   
- Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...   - Warning(s) x1: While computing multiclass `precision()`, some levels had no predicted eve...

Use `collect_notes(object)` for more information.

$tuned_info$grid_size
[1] 10

$tuned_info$best_metric
[1] "f_meas"

$tuned_info$best_result_set
# A tibble: 1 x 12
  trees min_n tree_depth learn_rate loss_reduction sample_size .metric .estimator  mean
  <int> <int>      <int>      <dbl>          <dbl>       <dbl> <chr>   <chr>      <dbl>
1   893     5         13    0.00393       1.27e-10       0.568 f_meas  macro      0.933
# ... with 3 more variables: n <int>, std_err <dbl>, .config <chr>

$tuned_info$tuning_grid_plot
`geom_smooth()` using method = 'loess' and formula 'y ~ x'

$tuned_info$plotly_grid_plot

attr(,"function_type")
[1] "boilerplate"
attr(,".grid_size")
[1] 10
attr(,".tune")
[1] TRUE
attr(,".best_metric")
[1] "f_meas"
attr(,".model_type")
[1] "classification"
attr(,".engine")
[1] "xgboost"