tidymodels / workflowsets

Create a collection of modeling workflows
https://workflowsets.tidymodels.org/
Other
92 stars 10 forks source link

feature request: add finalization method for workflow sets #164

Open jkylearmstrong opened 2 days ago

jkylearmstrong commented 2 days ago

I think it would be helpful to add a function that finalizes all the workflows in a workflow set. So far, something like this seems to work pretty well:

#' Finalize the parameter set of a single workflow in a workflow set
#'
#' Extracts the dials parameter set of one workflow, finalizes it against
#' \code{data}, and stores the result as that workflow's \code{param_info}
#' option.
#'
#' @param workflow_id Character scalar; the \code{wflow_id} to finalize.
#' @param workflow_sets A workflow set that contains \code{workflow_id}.
#' @param data Data frame passed to \code{dials::finalize()}. dials derives
#'   data-dependent ranges (e.g. the upper bound of \code{mtry}) from the
#'   columns of this object, so it should normally hold predictors only.
#' @param ... Not used.
#'
#' @return A one-row workflow set for \code{workflow_id}. If the workflow has
#'   tuning parameters, its \code{param_info} option holds the finalized set;
#'   otherwise the row is returned unchanged (it is no longer dropped).
.finalize_workflow_set <- function(workflow_id, workflow_sets, data, ...) {
  # `.env` makes explicit that `workflow_id` is this function's argument and
  # not a column of the workflow set.
  single_wflow <- workflow_sets %>%
    dplyr::filter(wflow_id == .env$workflow_id)

  param_set <- workflowsets::extract_parameter_set_dials(
    workflow_sets,
    id = workflow_id
  )

  # Nothing to finalize: keep the workflow. Returning an empty tibble here
  # would silently remove it when the per-workflow rows are re-bound.
  if (nrow(param_set) == 0) {
    return(single_wflow)
  }

  finalized_param_set <- dials::finalize(param_set, data)

  workflowsets::option_add(
    single_wflow,
    param_info = finalized_param_set,
    id = workflow_id
  )
}
#' Finalize every workflow in a workflow set
#'
#' Applies \code{.finalize_workflow_set()} to each workflow in \code{x} and
#' row-binds the results back together in \code{wflow_id} order.
#'
#' @param x A workflow set.
#' @param data Data frame used to finalize parameter ranges; passed through to
#'   \code{.finalize_workflow_set()}.
#' @param ... Forwarded to \code{.finalize_workflow_set()}, which currently
#'   ignores them.
#'
#' @return The workflow set with a finalized parameter set stored in the
#'   \code{param_info} option of each workflow. An empty \code{x} is returned
#'   unchanged.
finalize_workflow_set <- function(x, data, ...) {
  # list_rbind() on an empty list cannot reconstruct a workflow set, so
  # short-circuit the degenerate case.
  if (nrow(x) == 0) {
    return(x)
  }

  x$wflow_id %>%
    purrr::map(
      \(id) .finalize_workflow_set(
        workflow_id = id,
        workflow_sets = x,
        data = data,
        ...
      )
    ) %>%
    purrr::list_rbind()
}
simonpcouch commented 2 days ago

Related to https://github.com/tidymodels/workflowsets/issues/45.

What would you like to do with the list of finalized models? The use cases I imagine here all involve some sort of model selection, which should ideally be carried out using resampled models.

That is, the flow we recommend is:

1) Resample models
2a) Choose one workflow configuration to finalize
3a) Fit it on the entire training set

Rather than:

1) Resample models
2b) Choose a configuration of each workflow to finalize
3b) Presumably fit all of them on the entire training set
4b) ...

The second flow is problematic because there's no way to evaluate the list of models resulting from 3b): the test set can only be used once.

jkylearmstrong commented 1 day ago
library("tidymodels")
library("workflowsets")
tidymodels_prefer()

# `parabolic` ships with modeldata (attached by tidymodels). Naming the
# package avoids searching every attached package. The former
# `parabolic <- parabolic` self-assignment was a no-op and has been removed.
data(parabolic, package = "modeldata")
str(parabolic)
#> tibble [500 × 3] (S3: tbl_df/tbl/data.frame)
#>  $ X1   : num [1:500] 3.29 1.47 1.66 1.6 2.17 ...
#>  $ X2   : num [1:500] 1.661 0.414 0.791 0.276 3.166 ...
#>  $ class: Factor w/ 2 levels "Class1","Class2": 1 2 2 2 1 1 2 1 2 1 ...
set.seed(1)
split <- initial_split(parabolic)

# Base recipe: the raw predictors plus an X1:X2 interaction column.
rec <- recipe(class ~ ., data = training(split)) %>%
  step_interact(terms = ~ X1:X2)

# Preview the preprocessed training set.
rec %>%
  prep() %>%
  bake(new_data = training(split))
#> # A tibble: 375 × 4
#>         X1     X2 class  X1_x_X2
#>      <dbl>  <dbl> <fct>    <dbl>
#>  1  1.17    0.627 Class2  0.733 
#>  2 -0.769  -1.29  Class2  0.993 
#>  3  1.17    1.02  Class1  1.20  
#>  4  0.510  -2.10  Class2 -1.07  
#>  5  1.38   -0.974 Class2 -1.34  
#>  6 -0.0549 -1.77  Class2  0.0972
#>  7  0.703   1.24  Class1  0.868 
#>  8  1.50    0.418 Class2  0.625 
#>  9 -0.219  -3.08  Class2  0.675 
#> 10  0.606  -0.960 Class2 -0.582 
#> # ℹ 365 more rows
# Preprocessing variants built on top of the base recipe.
rec_norm <- rec %>%
  step_normalize(all_numeric_predictors())

# PCA variant: the number of components is tuned.
rec_pca <- rec_norm %>%
  step_pca(all_numeric_predictors(), num_comp = tune())

library("embed")

# UMAP variant: embedding size and neighbourhood settings are tuned.
rec_umap <- rec_norm %>%
  step_umap(
    all_numeric_predictors(),
    outcome = "class",
    num_comp = tune(),
    neighbors = tune(),
    min_dist = tune()
  )

library("discrim")

# Model specifications (all with tuned hyperparameters).
mars_disc_spec <-
  discrim_flexible(prod_degree = tune()) %>%
  set_engine("earth")

# Renamed from the misspelled `reg_disc_sepc`.
reg_disc_spec <-
  discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) %>%
  set_engine("klaR")

cart_spec <-
  decision_tree(cost_complexity = tune(), min_n = tune()) %>%
  set_engine("rpart") %>%
  set_mode("classification")

xgboost_spec <-
  boost_tree(
    mtry = tune(),
    trees = tune(),
    min_n = tune(),
    tree_depth = tune(),
    learn_rate = tune(),
    loss_reduction = tune(),
    sample_size = tune(),
    stop_iter = tune()
  ) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

set.seed(2)
folds <- vfold_cv(training(split), v = 5)

# Cross every preprocessor with every model (4 x 4 = 16 workflows).
all_workflows <-
  workflow_set(
    preproc = list(
      rec = rec,
      rec_norm = rec_norm,
      rec_pca = rec_pca,
      rec_umap = rec_umap
    ),
    models = list(
      regularized = reg_disc_spec,
      mars = mars_disc_spec,
      cart = cart_spec,
      xgboost_spec = xgboost_spec
    )
  )

@simonpcouch, see the error below:

# Tune every workflow on the same resamples, keeping predictions and the
# fitted workflows for later inspection.
all_workflows_res <- workflow_map(
  all_workflows,
  resamples = folds,
  verbose = TRUE,
  control = control_grid(
    save_pred = TRUE,
    parallel_over = "everything",
    save_workflow = TRUE
  )
)
#> i  1 of 16 tuning:     rec_regularized
#> ✔  1 of 16 tuning:     rec_regularized (9.8s)
#> i  2 of 16 tuning:     rec_mars
#> ✔  2 of 16 tuning:     rec_mars (1s)
#> i  3 of 16 tuning:     rec_cart
#> ✔  3 of 16 tuning:     rec_cart (3.1s)
#> i  4 of 16 tuning:     rec_xgboost_spec
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔  4 of 16 tuning:     rec_xgboost_spec (26.2s)
#> i  5 of 16 tuning:     rec_norm_regularized
#> ✔  5 of 16 tuning:     rec_norm_regularized (10.1s)
#> i  6 of 16 tuning:     rec_norm_mars
#> ✔  6 of 16 tuning:     rec_norm_mars (821ms)
#> i  7 of 16 tuning:     rec_norm_cart
#> ✔  7 of 16 tuning:     rec_norm_cart (3.5s)
#> i  8 of 16 tuning:     rec_norm_xgboost_spec
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔  8 of 16 tuning:     rec_norm_xgboost_spec (26.6s)
#> i  9 of 16 tuning:     rec_pca_regularized
#> ✔  9 of 16 tuning:     rec_pca_regularized (10.3s)
#> i 10 of 16 tuning:     rec_pca_mars
#> ✔ 10 of 16 tuning:     rec_pca_mars (3.5s)
#> i 11 of 16 tuning:     rec_pca_cart
#> ✔ 11 of 16 tuning:     rec_pca_cart (5s)
#> i 12 of 16 tuning:     rec_pca_xgboost_spec
#> ✖ 12 of 16 tuning:     rec_pca_xgboost_spec failed with: Error in check_parameters(workflow, pset = pset, data = resamples$splits[[1]]$data,  :   Some model parameters require finalization but there are recipe parameters that require tuning. Please use  `extract_parameter_set_dials()` to set parameter ranges  manually and supply the output to the `param_info` argument.
#> i 13 of 16 tuning:     rec_umap_regularized
#> ✔ 13 of 16 tuning:     rec_umap_regularized (1m 32.6s)
#> i 14 of 16 tuning:     rec_umap_mars
#> ✔ 14 of 16 tuning:     rec_umap_mars (1m 28.6s)
#> i 15 of 16 tuning:     rec_umap_cart
#> ✔ 15 of 16 tuning:     rec_umap_cart (1m 27s)
#> i 16 of 16 tuning:     rec_umap_xgboost_spec
#> ✖ 16 of 16 tuning:     rec_umap_xgboost_spec failed with: Error in check_parameters(workflow, pset = pset, data = resamples$splits[[1]]$data,  :   Some model parameters require finalization but there are recipe parameters that require tuning. Please use  `extract_parameter_set_dials()` to set parameter ranges  manually and supply the output to the `param_info` argument.
# Plot the tuning results, then show the workflow set (failed workflows
# appear with a try-error result).
autoplot(all_workflows_res)
print(all_workflows_res)
#> # A workflow set/tibble: 16 × 4
#>    wflow_id              info             option    result        
#>    <chr>                 <list>           <list>    <list>        
#>  1 rec_regularized       <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  2 rec_mars              <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  3 rec_cart              <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  4 rec_xgboost_spec      <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  5 rec_norm_regularized  <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  6 rec_norm_mars         <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  7 rec_norm_cart         <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  8 rec_norm_xgboost_spec <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  9 rec_pca_regularized   <tibble [1 × 4]> <opts[2]> <tune[+]>     
#> 10 rec_pca_mars          <tibble [1 × 4]> <opts[2]> <tune[+]>     
#> 11 rec_pca_cart          <tibble [1 × 4]> <opts[2]> <tune[+]>     
#> 12 rec_pca_xgboost_spec  <tibble [1 × 4]> <opts[2]> <try-errr [1]>
#> 13 rec_umap_regularized  <tibble [1 × 4]> <opts[2]> <tune[+]>     
#> 14 rec_umap_mars         <tibble [1 × 4]> <opts[2]> <tune[+]>     
#> 15 rec_umap_cart         <tibble [1 × 4]> <opts[2]> <tune[+]>     
#> 16 rec_umap_xgboost_spec <tibble [1 × 4]> <opts[2]> <try-errr [1]>

@simonpcouch - The workflows that error are the ones that have tuning parameters in both the preprocessor (PCA / UMAP) and the model (xgboost / random forest). The preprocessor has to determine `num_comp`, the number of components. That number is the new number of columns, and it is what the range of the `mtry` tuning parameter depends on.

# Keep only workflows that produced tuning results: drop failed runs (stored
# as try-error objects) and workflows whose result is an empty list.
all_workflows_res <- all_workflows_res %>%
  filter(
    !map_lgl(result, \(res) inherits(res, "try-error")),
    !map_lgl(result, \(res) identical(res, list()))
  )

@simonpcouch - On a slightly different topic: one issue with `autoplot()` is that it defaults to using the data stored in the `info` column:

autoplot(all_workflows_res, metric = "accuracy")

@simonpcouch - So we have to remap the `info` column to distinguish the different preprocessor/model combinations:

# Overwrite the `preproc` label inside each workflow's `info` tibble with the
# recipe variant encoded in its `wflow_id`, so plots can tell them apart.
all_workflows_res <- all_workflows_res %>%
  mutate(info = map2(
    info,
    case_when(
      stringr::str_detect(wflow_id, "rec_norm_") ~ "norm",
      stringr::str_detect(wflow_id, "rec_pca_") ~ "PCA",
      stringr::str_detect(wflow_id, "rec_umap_") ~ "UMAP",
      .default = "recipe"
    ),
    \(wflow_info, preproc_label) mutate(wflow_info, preproc = preproc_label)
  ))

# Accuracy per workflow, one panel per preprocessor variant.
autoplot(all_workflows_res, metric = "accuracy") +
  facet_grid(~preprocessor) +
  theme(legend.position = "bottom") +
  guides(
    color = guide_legend(ncol = 2),
    shape = guide_legend(ncol = 2)
  )

Here are the functions that finalize the workflows and create a new workflow set:

#' Finalize the parameter set of a single workflow in a workflow set
#'
#' Extracts the dials parameter set of one workflow, finalizes it against
#' \code{data}, and stores the result as that workflow's \code{param_info}
#' option.
#'
#' @param workflow_id Character scalar; the \code{wflow_id} to finalize.
#' @param workflow_sets A workflow set that contains \code{workflow_id}.
#' @param data Data frame passed to \code{dials::finalize()}. dials derives
#'   data-dependent ranges (e.g. the upper bound of \code{mtry}) from the
#'   columns of this object, so it should normally hold predictors only.
#' @param ... Not used.
#'
#' @return A one-row workflow set for \code{workflow_id}. If the workflow has
#'   tuning parameters, its \code{param_info} option holds the finalized set;
#'   otherwise the row is returned unchanged (it is no longer dropped).
.finalize_workflow_set <- function(workflow_id, workflow_sets, data, ...) {
  # `.env` makes explicit that `workflow_id` is this function's argument and
  # not a column of the workflow set.
  single_wflow <- workflow_sets %>%
    dplyr::filter(wflow_id == .env$workflow_id)

  param_set <- workflowsets::extract_parameter_set_dials(
    workflow_sets,
    id = workflow_id
  )

  # Nothing to finalize: keep the workflow. Returning an empty tibble here
  # would silently remove it when the per-workflow rows are re-bound.
  if (nrow(param_set) == 0) {
    return(single_wflow)
  }

  finalized_param_set <- dials::finalize(param_set, data)

  workflowsets::option_add(
    single_wflow,
    param_info = finalized_param_set,
    id = workflow_id
  )
}
#' Finalize every workflow in a workflow set
#'
#' Applies \code{.finalize_workflow_set()} to each workflow in \code{x} and
#' row-binds the results back together in \code{wflow_id} order.
#'
#' @param x A workflow set.
#' @param data Data frame used to finalize parameter ranges; passed through to
#'   \code{.finalize_workflow_set()}.
#' @param ... Forwarded to \code{.finalize_workflow_set()}, which currently
#'   ignores them.
#'
#' @return The workflow set with a finalized parameter set stored in the
#'   \code{param_info} option of each workflow. An empty \code{x} is returned
#'   unchanged.
finalize_workflow_set <- function(x, data, ...) {
  # list_rbind() on an empty list cannot reconstruct a workflow set, so
  # short-circuit the degenerate case.
  if (nrow(x) == 0) {
    return(x)
  }

  x$wflow_id %>%
    purrr::map(
      \(id) .finalize_workflow_set(
        workflow_id = id,
        workflow_sets = x,
        data = data,
        ...
      )
    ) %>%
    purrr::list_rbind()
}
# Store a finalized parameter set as the `param_info` option of each workflow.
# NOTE(review): training(split) still contains the outcome column `class`, so
# column-count-based ranges (e.g. the upper bound of `mtry`) are computed from
# 3 columns. That happens to equal the 3 predictors produced by `rec`
# (X1, X2, X1_x_X2), but not necessarily the PCA/UMAP variants -- confirm.
all_workflows <- finalize_workflow_set(all_workflows, training(split))

Now we can run over all the model combinations:

# Re-tune every (now finalized) workflow on the same resamples, keeping
# predictions and the fitted workflows.
all_workflows_res <- workflow_map(
  all_workflows,
  resamples = folds,
  verbose = TRUE,
  control = control_grid(
    save_pred = TRUE,
    parallel_over = "everything",
    save_workflow = TRUE
  )
)
#> i  1 of 16 tuning:     rec_regularized
#> ✔  1 of 16 tuning:     rec_regularized (9.8s)
#> i  2 of 16 tuning:     rec_mars
#> ✔  2 of 16 tuning:     rec_mars (710ms)
#> i  3 of 16 tuning:     rec_cart
#> ✔  3 of 16 tuning:     rec_cart (3.1s)
#> i  4 of 16 tuning:     rec_xgboost_spec
#> ✔  4 of 16 tuning:     rec_xgboost_spec (34.7s)
#> i  5 of 16 tuning:     rec_norm_regularized
#> ✔  5 of 16 tuning:     rec_norm_regularized (10.7s)
#> i  6 of 16 tuning:     rec_norm_mars
#> ✔  6 of 16 tuning:     rec_norm_mars (938ms)
#> i  7 of 16 tuning:     rec_norm_cart
#> ✔  7 of 16 tuning:     rec_norm_cart (3.7s)
#> i  8 of 16 tuning:     rec_norm_xgboost_spec
#> ✔  8 of 16 tuning:     rec_norm_xgboost_spec (30.5s)
#> i  9 of 16 tuning:     rec_pca_regularized
#> ✔  9 of 16 tuning:     rec_pca_regularized (10.9s)
#> i 10 of 16 tuning:     rec_pca_mars
#> ✔ 10 of 16 tuning:     rec_pca_mars (4s)
#> i 11 of 16 tuning:     rec_pca_cart
#> ✔ 11 of 16 tuning:     rec_pca_cart (5s)
#> i 12 of 16 tuning:     rec_pca_xgboost_spec
#> ✔ 12 of 16 tuning:     rec_pca_xgboost_spec (36.4s)
#> i 13 of 16 tuning:     rec_umap_regularized
#> ✔ 13 of 16 tuning:     rec_umap_regularized (1m 35.8s)
#> i 14 of 16 tuning:     rec_umap_mars
#> ✔ 14 of 16 tuning:     rec_umap_mars (1m 31.9s)
#> i 15 of 16 tuning:     rec_umap_cart
#> ✔ 15 of 16 tuning:     rec_umap_cart (1m 31.6s)
#> i 16 of 16 tuning:     rec_umap_xgboost_spec
#> ✔ 16 of 16 tuning:     rec_umap_xgboost_spec (1m 53.4s)
# autoplot() takes preprocessor labels from the `info` column, so overwrite
# the `preproc` label of each workflow with the recipe variant encoded in its
# `wflow_id`; otherwise the norm/PCA/UMAP recipes cannot be told apart.
all_workflows_res <- all_workflows_res %>%
  mutate(info = map2(
    info,
    case_when(
      stringr::str_detect(wflow_id, "rec_norm_") ~ "norm",
      stringr::str_detect(wflow_id, "rec_pca_") ~ "PCA",
      stringr::str_detect(wflow_id, "rec_umap_") ~ "UMAP",
      .default = "recipe"
    ),
    \(wflow_info, preproc_label) mutate(wflow_info, preproc = preproc_label)
  ))

# Accuracy per workflow, one panel per preprocessor variant.
autoplot(all_workflows_res, metric = "accuracy") +
  facet_grid(~preprocessor) +
  theme(legend.position = "bottom") +
  guides(
    color = guide_legend(ncol = 2),
    shape = guide_legend(ncol = 2)
  )

# Record the R and package versions used to produce the output above.
sessionInfo()
#> R version 4.4.1 (2024-06-14)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 20.04.6 LTS
#> 
#> Matrix products: default
#> BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.8.so;  LAPACK version 3.9.0
#> 
#> locale:
#>  [1] LC_CTYPE=C.UTF-8       LC_NUMERIC=C           LC_TIME=C.UTF-8       
#>  [4] LC_COLLATE=C.UTF-8     LC_MONETARY=C.UTF-8    LC_MESSAGES=C.UTF-8   
#>  [7] LC_PAPER=C.UTF-8       LC_NAME=C              LC_ADDRESS=C          
#> [10] LC_TELEPHONE=C         LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C   
#> 
#> time zone: UTC
#> tzcode source: system (glibc)
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] uwot_0.2.2         Matrix_1.7-0       xgboost_1.7.8.1    rpart_4.1.23      
#>  [5] earth_5.3.4        plotmo_3.6.4       plotrix_3.8-4      Formula_1.2-5     
#>  [9] mda_0.5-4          class_7.3-22       klaR_1.7-3         MASS_7.3-60.2     
#> [13] discrim_1.0.1      embed_1.1.4        yardstick_1.3.1    workflowsets_1.1.0
#> [17] workflows_1.1.4    tune_1.2.1         tidyr_1.3.1        tibble_3.2.1      
#> [21] rsample_1.2.1      recipes_1.1.0      purrr_1.0.2        parsnip_1.2.1     
#> [25] modeldata_1.4.0    infer_1.0.7        ggplot2_3.5.1      dplyr_1.1.4       
#> [29] dials_1.3.0        scales_1.3.0       broom_1.0.7        tidymodels_1.2.0  
#> 
#> loaded via a namespace (and not attached):
#>  [1] conflicted_1.2.0    rlang_1.1.4         magrittr_2.0.3     
#>  [4] furrr_0.3.1         RcppAnnoy_0.0.22    compiler_4.4.1     
#>  [7] vctrs_0.6.5         combinat_0.0-8      stringr_1.5.1      
#> [10] lhs_1.2.0           pkgconfig_2.0.3     fastmap_1.2.0      
#> [13] backports_1.5.0     labeling_0.4.3      utf8_1.2.4         
#> [16] promises_1.3.0      rmarkdown_2.28      prodlim_2024.06.25 
#> [19] haven_2.5.4         xfun_0.48           reprex_2.1.1       
#> [22] cachem_1.1.0        labelled_2.13.0     jsonlite_1.8.9     
#> [25] highr_0.11          later_1.3.2         irlba_2.3.5.1      
#> [28] parallel_4.4.1      prettyunits_1.2.0   R6_2.5.1           
#> [31] stringi_1.8.4       parallelly_1.38.0   lubridate_1.9.3    
#> [34] Rcpp_1.0.13         iterators_1.0.14    knitr_1.48         
#> [37] future.apply_1.11.2 httpuv_1.6.15       splines_4.4.1      
#> [40] nnet_7.3-19         timechange_0.3.0    tidyselect_1.2.1   
#> [43] rstudioapi_0.16.0   yaml_2.3.10         timeDate_4041.110  
#> [46] codetools_0.2-20    miniUI_0.1.1.1      curl_5.2.3         
#> [49] listenv_0.9.1       lattice_0.22-6      shiny_1.9.1        
#> [52] withr_3.0.1         evaluate_1.0.0      future_1.34.0      
#> [55] survival_3.6-4      xml2_1.3.6          pillar_1.9.0       
#> [58] foreach_1.5.2       generics_0.1.3      hms_1.1.3          
#> [61] munsell_0.5.1       globals_0.16.3      xtable_1.8-4       
#> [64] glue_1.8.0          tools_4.4.1         data.table_1.16.0  
#> [67] gower_1.0.1         forcats_1.0.0       fs_1.6.4           
#> [70] grid_4.4.1          ipred_0.9-15        colorspace_2.1-1   
#> [73] cli_3.6.3           DiceDesign_1.10     fansi_1.0.6        
#> [76] lava_1.8.0          gtable_0.3.5        GPfit_1.0-8        
#> [79] digest_0.6.37       farver_2.1.2        memoise_2.0.1      
#> [82] htmltools_0.5.8.1   questionr_0.7.8     lifecycle_1.0.4    
#> [85] hardhat_1.4.0       mime_0.12

Created on 2024-10-08 with reprex v2.1.1