tidymodels / workflowsets

Create a collection of modeling workflows
https://workflowsets.tidymodels.org/
Other
92 stars 10 forks source link

`num_comp` not updated with `option_add()` for `discrim_flexible` workflows #157

Open marioem opened 3 months ago

marioem commented 3 months ago

The problem

The usual method of finalizing tunable parameters in a workflow set is not working with num_comp for discrim_flexible workflows

Reproducible example

library(tidymodels)
library(tidyverse)
library(discrim)
#> 
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#> 
#>     smoothness
library(AppliedPredictiveModeling)
library(future)

set.seed(321)
data <- quadBoundaryFunc(2000) %>% select(A = X1, B = X2, class)

data_splits1 <- initial_split(data, prop = .85, strata = class)

train_datas1 <- training(data_splits1)
test_datas1  <- testing(data_splits1)

foldss1 <- vfold_cv(train_datas1, v = 5, repeats = 3, strata = class)

biv_rec <- 
  recipe(class ~ ., data = train_datas1) %>%
  step_normalize(all_predictors())

discrim_flexible_spec <-
  discrim_flexible(num_terms =  tune::tune(), 
                   prod_degree =  tune::tune(), 
                   prune_method =  tune::tune()) %>%
  set_engine('earth') %>%
  set_mode('classification')

normalizeds1 <- 
  workflow_set(
    preproc = list(norm = biv_rec),
    models = list(FD = discrim_flexible_spec)
  )

normalizeds1 %>% extract_workflow("norm_FD") %>% extract_parameter_set_dials()
#> Collection of 3 parameters for tuning
#> 
#>    identifier         type    object
#>     num_terms    num_terms nparam[?]
#>   prod_degree  prod_degree nparam[+]
#>  prune_method prune_method dparam[+]
#> 
#> Model parameters needing finalization:
#>    # Model Terms ('num_terms')
#> 
#> See `?dials::finalize` or `?dials::update.parameters` for more information.
pars <- normalizeds1 %>% extract_workflow("norm_FD") %>% extract_parameter_set_dials() %>% finalize(x = train_datas1 %>% select(-class))

pars
#> Collection of 3 parameters for tuning
#> 
#>    identifier         type    object
#>     num_terms    num_terms nparam[+]
#>   prod_degree  prod_degree nparam[+]
#>  prune_method prune_method dparam[+]

normalizeds1 <- normalizeds1 %>% 
  option_add(param_info = pars, id = "norm_FD")

normalizeds1 %>% extract_workflow("norm_FD") %>% extract_parameter_set_dials()
#> Collection of 3 parameters for tuning
#> 
#>    identifier         type    object
#>     num_terms    num_terms nparam[?]
#>   prod_degree  prod_degree nparam[+]
#>  prune_method prune_method dparam[+]
#> 
#> Model parameters needing finalization:
#>    # Model Terms ('num_terms')
#> 
#> See `?dials::finalize` or `?dials::update.parameters` for more information.
# 'num_comp' is not getting updated

bayes_ctrl <-
  control_bayes(
    save_pred = TRUE,
    parallel_over = "everything",
    save_workflow = TRUE,
    verbose = T
  )

plan(multisession)
tune_bayes(normalizeds1 %>% extract_workflow("norm_FD"), seed = 1503, resamples = foldss1, metrics = metric_set(roc_auc, brier_class, kap, accuracy),  iter = 25, verbose = T, initial = 11, control = bayes_ctrl)
#> Error in `dials::grid_latin_hypercube()`:
#> ✖ This argument contains unknowns: `num_terms`.
#> ℹ See the `dials::finalize()` function.
# Error: This argument contains unknowns: `num_terms`.
plan(sequential)

Created on 2024-06-20 with reprex v2.1.0

Session info ``` r sessionInfo() #> R version 4.4.0 (2024-04-24) #> Platform: aarch64-apple-darwin20 #> Running under: macOS Sonoma 14.5 #> #> Matrix products: default #> BLAS: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRblas.0.dylib #> LAPACK: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.12.0 #> #> locale: #> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8 #> #> time zone: UTC #> tzcode source: internal #> #> attached base packages: #> [1] stats graphics grDevices utils datasets methods base #> #> other attached packages: #> [1] future_1.33.2 AppliedPredictiveModeling_1.1-7 #> [3] discrim_1.0.1 lubridate_1.9.3 #> [5] forcats_1.0.0 stringr_1.5.1 #> [7] readr_2.1.5 tidyverse_2.0.0 #> [9] yardstick_1.3.1 workflowsets_1.1.0 #> [11] workflows_1.1.4 tune_1.2.1 #> [13] tidyr_1.3.1 tibble_3.2.1 #> [15] rsample_1.2.1 recipes_1.0.10 #> [17] purrr_1.0.2 parsnip_1.2.1 #> [19] modeldata_1.3.0 infer_1.0.7 #> [21] ggplot2_3.5.1 dplyr_1.1.4 #> [23] dials_1.2.1 scales_1.3.0 #> [25] broom_1.0.6 tidymodels_1.2.0 #> #> loaded via a namespace (and not attached): #> [1] rlang_1.1.4 magrittr_2.0.3 furrr_0.3.1 #> [4] rpart.plot_3.1.2 compiler_4.4.0 vctrs_0.6.5 #> [7] reshape2_1.4.4 lhs_1.1.6 pkgconfig_2.0.3 #> [10] fastmap_1.2.0 backports_1.5.0 utf8_1.2.4 #> [13] rmarkdown_2.27 prodlim_2023.08.28 tzdb_0.4.0 #> [16] xfun_0.44 reprex_2.1.0 styler_1.10.3 #> [19] parallel_4.4.0 cluster_2.1.6 R6_2.5.1 #> [22] CORElearn_1.57.3 stringi_1.8.4 parallelly_1.37.1 #> [25] rpart_4.1.23 Rcpp_1.0.12 iterators_1.0.14 #> [28] knitr_1.47 future.apply_1.11.2 R.utils_2.12.3 #> [31] Matrix_1.7-0 splines_4.4.0 nnet_7.3-19 #> [34] R.cache_0.16.0 timechange_0.3.0 tidyselect_1.2.1 #> [37] rstudioapi_0.16.0 yaml_2.3.8 timeDate_4032.109 #> [40] codetools_0.2-20 listenv_0.9.1 lattice_0.22-6 #> [43] plyr_1.8.9 withr_3.0.0 evaluate_0.23 #> [46] survival_3.7-0 pillar_1.9.0 foreach_1.5.2 #> [49] ellipse_0.5.0 generics_0.1.3 hms_1.1.3 #> [52] munsell_0.5.1 plotmo_3.6.3 globals_0.16.3 #> [55] class_7.3-22 glue_1.7.0 mda_0.5-4 #> [58] tools_4.4.0 data.table_1.15.4 gower_1.0.1 #> [61] fs_1.6.4 grid_4.4.0 plotrix_3.8-4 #> [64] ipred_0.9-14 colorspace_2.1-0 earth_5.3.3 #> [67] Formula_1.2-5 cli_3.6.2 DiceDesign_1.10 #> [70] fansi_1.0.6 lava_1.8.0 gtable_0.3.5 #> [73] R.methodsS3_1.8.2 GPfit_1.0-8 digest_0.6.35 #> [76] htmltools_0.5.8.1 R.oo_1.26.0 lifecycle_1.0.4 #> [79] hardhat_1.4.0 MASS_7.3-60.2 ```