tidymodels / tune

Tools for tidy parameter tuning
https://tune.tidymodels.org
Other
273 stars 42 forks source link

select_best() fails after tuning with manually defined grid_folds #886

Closed joshuagi closed 5 months ago

joshuagi commented 5 months ago

The problem

I've encountered an issue that I think is tied to the collect_metrics() function. I am manually assigning the rows of my data frame to folds in order to cross-validate the lambda (penalty) hyperparameter of a lasso model. After tuning with my custom 'grid_folds', collect_metrics() shows NA for every value of the metric I am using (roc_auc) to select the best lambda. I think select_best() fails because of this.

This issue came up after I updated the tidymodels packages on my local machine today. However, I can still run my analyses without any issues on my university HPC, which has the older package versions.

Below I've provided a reproducible example of my issue. I ran it on both my local machine and on my HPC. The last step to select the best lambda does not run on my local machine, but it does run on my HPC. I've included session info for both my local machine and my HPC.

thank you!!

josh

Reproducible example

library(reshape2)
library(tidymodels)
library(stringr)
library(dplyr)
library(healthyR.ai)
library(doParallel)
library(finetune)
library(vip) 

# Reprex data (dput() output): a 31-row tibble with
#   ID        - integer identifier; several IDs occur on two rows
#   timepoint - "T1" / "T2"
#   group     - two-level factor outcome (levels "0" and "1")
#   V1..V7    - numeric columns (V3 is all zeros)
trainingdata <- structure(list(ID = c(24L, 72L, 72L, 100L, 125L, 125L, 132L, 
                                      132L, 161L, 169L, 182L, 182L, 183L, 188L, 188L, 189L, 209L, 226L, 
                                      226L, 234L, 234L, 236L, 239L, 240L, 240L, 241L, 241L, 248L, 248L, 
                                      255L, 255L), timepoint = c("T1", "T1", "T2", "T1", "T1", "T2", 
                                                                 "T1", "T2", "T1", "T1", "T1", "T2", "T1", "T1", "T2", "T1", "T1", 
                                                                 "T1", "T2", "T1", "T2", "T1", "T2", "T1", "T2", "T1", "T2", "T1", 
                                                                 "T2", "T1", "T2"), group = structure(c(2L, 2L, 2L, 2L, 1L, 1L, 
                                                                                                        1L, 1L, 1L, 1L, 2L, 2L, 1L, 2L, 2L, 1L, 2L, 1L, 1L, 1L, 1L, 1L, 
                                                                                                        1L, 1L, 1L, 2L, 2L, 1L, 1L, 1L, 1L), levels = c("0", "1"), class = "factor"), 
                               V1 = c(0.0414340209999999, 0.025224471, 0.04231266, 0.039416785, 
                                      -0.100547952, 0.00753425800000002, 0.053571377, -0.055701425, 
                                      0.187008187, 0.084571454, 0.021884699, 0.011765373, 0.092953308, 
                                      0.082803165, 0.165637544, -0.011435301, -0.020401237, -0.023440767, 
                                      0.037270655, 0.0934873380000001, 0.056128516, 0.004593027, 
                                      -0.081403072, -0.058107866, 0.00582489999999999, 0.011055028, 
                                      -0.010955359, 0.028924498, -0.036961629, -0.018467265, -0.140053071
                               ), V2 = c(0.096589547, 0.00942918099999995, -0.023681238, 
                                         0.00248310499999999, 0.01989384, -0.00939109799999999, 0.047211409, 
                                         -0.036252554, 0.011121487, 0.049359805, 0.08612576, 0.02894711, 
                                         0.018181926, 0.032300977, 0.05097082, 0.011368588, 0.04922788, 
                                         0.00207389300000002, 0.045706516, 0.038476003, 0.035818352, 
                                         0.092883021, -0.014287559, 0.022199844, 0.041880865, 0.024305687, 
                                         -0.00444903299999999, 0.033323643, 0.08933341, -0.039026361, 
                                         -0.038871654), V3 = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
                                                               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), 
                               V4 = c(0.00812114000000008, -0.00212878200000066, 0.0268026419999998, 
                                      -0.0564599980000002, -0.0698764079999998, 0.0424415600000003, 
                                      0.0233709270000002, -0.10382344, 0.0187751650000001, 0.154421979, 
                                      -0.0506083460000002, -0.114158665, 0.0224504620000001, -0.0698718649999996, 
                                      -0.0657564419999996, -0.0441098890000005, -0.0456882009999999, 
                                      -0.0471626139999994, 0.0133814889999995, 0.0280674359999997, 
                                      -0.106653236000001, -0.0014476160000001, -0.0541871250000003, 
                                      -0.123672098, 0.000489790000000045, -0.0311820410000001, 
                                      0.0545155729999998, -0.000270543000000067, -0.0456917360000002, 
                                      0.00655041100000009, -0.0635483560000001), V5 = c(0.01922686, 
                                                                                        -0.00928328900000001, 0.015601901, 0, -0.004813204, -0.007899604, 
                                                                                        -0.010793222, -0.00902706, 0.016550936, 0, 0.0410770869999999, 
                                                                                        0.010523249, 0, 0.00906663600000002, 0, 0, 0, -0.011362604, 
                                                                                        0, 0.013091598, 0.016162166, 0, -0.027932629, 0, -0.021681544, 
                                                                                        0, 0, -0.000248178, 0.007455193, -0.033860835, -0.047112061
                                      ), V6 = c(0.252536383, -0.191686357, 0.193634125, -0.112218412, 
                                                -0.18883989, -0.027855115, -0.108012743, -0.48960312, 0.184219795, 
                                                0.063316516, 0.020122683, -0.112871676, -0.029190738, 0.022005026, 
                                                0.26734916, -0.163538612, -0.124999495, -0.238249848, 0.288404799, 
                                                0.242375702, -0.0407274099999999, 0.074124986, -0.597763213, 
                                                -0.211803165, -0.031423014, -0.016929837, -0.142037141, -0.064138229, 
                                                0.033956957, -0.437783971, -0.552079896), V7 = c(0.106550431, 
                                                                                                 0.0417641999999999, 0.091150763, -0.048006197, -0.113301346, 
                                                                                                 0.034946745, -0.00293663, -0.194671822, 0.066748464, -0.0312702269999999, 
                                                                                                 0.069282392, -0.033883316, -0.014992392, 0.023497632, 0.072943182, 
                                                                                                 -0.066608122, -0.239288944, -0.083516309, 0.049463438, 0.070188417, 
                                                                                                 0.026881999, 0.086103675, -0.244003518, -0.124041214, 0.00447314899999995, 
                                                                                                 -0.019669773, -0.248940058, -0.017095608, -0.066120953, -0.133190457, 
                                                                                                 -0.180641605)), row.names = c(NA, -31L), class = c("tbl_df", 
                                                                                                                                                    "tbl", "data.frame"))

# Model spec: logistic regression on the glmnet engine with `mixture = 1`
# and `penalty` left as a tune() placeholder to be filled in by the grid.
lasso_reg_spec <- set_engine(
  logistic_reg(mode = "classification", penalty = tune(), mixture = 1),
  "glmnet"
)

# Case weights
# When `weighted` is TRUE every row gets 1 minus the share of rows that
# belong to its own class; otherwise every row gets weight 1.
weighted <- TRUE
train.y <- trainingdata$group
n_obs <- length(train.y)
is_class_0 <- train.y == 0
is_class_1 <- train.y == 1
fraction_0 <- rep(1 - sum(is_class_0) / n_obs, sum(is_class_0))
fraction_1 <- rep(1 - sum(is_class_1) / n_obs, sum(is_class_1))

weights <- rep(1, n_obs)
if (weighted == TRUE) {
  weights <- numeric(n_obs)
  weights[is_class_0] <- fraction_0
  weights[is_class_1] <- fraction_1
}

# Store the weights on the data as an importance-weights column so that
# later steps can refer to them as `case_wts`.
trainingdata <- mutate(
  trainingdata,
  case_wts = importance_weights(weights)
)

# Recipe: `group` is the outcome and every other column starts as a
# predictor. ID and timepoint are then moved to their own roles so the
# all_predictors() selectors below skip them.
rec <- recipe(group ~ ., data = trainingdata)
rec <- update_role(rec, ID, new_role = "ID")
rec <- update_role(rec, timepoint, new_role = "timepoint")

# Preprocessing steps, applied in this order.
rec <- step_impute_median(rec, all_predictors())
rec <- step_hai_winsorized_truncate(rec, all_predictors(), fraction = 0.05)
rec <- step_rm(rec, starts_with("raw++"))
rec <- step_zv(rec, all_predictors(), group = "group")
rec <- step_center(rec, all_predictors())
rec <- step_scale(rec, all_predictors())

# Workflowset: one preprocessor crossed with one model. The list names
# ("preprocess", "lasso_reg_spec") are the ones used later to look the
# workflow up by its id.
preprocessors <- list(preprocess = rec)
model_specs <- list(lasso_reg_spec = lasso_reg_spec)

wfset <- workflow_set(
  preproc = preprocessors,
  models = model_specs,
  case_weights = case_wts # Use class weights to fit the model
)

# Custom resampling: leave-one-ID-out folds.
# Each unique ID defines one fold: all rows carrying that ID form the
# assessment set and all remaining rows form the analysis set, so rows
# that share an ID are never split across the two sets.
foldid_cv <- unique(trainingdata$ID)

# Row positions come straight from a logical mask over the ID column
# instead of being parsed back out of rownames(). lapply() also handles
# an empty `foldid_cv` correctly, where `1:length(foldid_cv)` would
# iterate over c(1, 0).
indices <- lapply(foldid_cv, function(id) {
  held_out <- trainingdata$ID == id
  list(
    analysis = which(!held_out),
    assessment = which(held_out)
  )
})

# Turn each index pair into an rsplit and bundle them into an rset with
# ids "Fold1", "Fold2", ...
splits <- lapply(indices, make_splits, data = trainingdata)
grid_folds <- manual_rset(splits, paste0("Fold", seq_along(splits)))

# run tuning
# Grid-search control options: verbose logging, keep the out-of-sample
# predictions, do not keep the workflow, and use the "everything"
# setting for `parallel_over`.
grid_ctrl <- control_grid(
  parallel_over = "everything",
  save_workflow = FALSE,
  save_pred = TRUE,
  verbose = TRUE
)

# Run tune_grid() over every workflow in the set, using the custom folds,
# a grid of size 100 and roc_auc as the only metric.
tune_results <- workflow_map(
  wfset,
  fn = "tune_grid",
  verbose = TRUE,
  seed = 2023,
  resamples = grid_folds,
  grid = 100,
  metrics = metric_set(roc_auc),
  control = grid_ctrl
)

# This fails
# Pull the tuning result for the single workflow in the set, then ask for
# the penalty value with the best roc_auc.
# NOTE(review): in `trainingdata` all rows of a given ID share one `group`
# level and each fold in `grid_folds` holds out a single ID, so every
# assessment set contains only one class. A per-fold roc_auc is then
# presumably undefined (NA) - confirm with
# collect_metrics(summarize = FALSE).
lasso_tune_result <- extract_workflow_set_result(
  tune_results,
  id = "preprocess_lasso_reg_spec"
)
tune_tmp <- select_best(lasso_tune_result, metric = "roc_auc")
# Session Info Local
R version 4.3.3 (2024-02-29)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Sonoma 14.4.1

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.11.0

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: America/Los_Angeles
tzcode source: internal

attached base packages:
[1] parallel  stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] glmnet_4.1-8       Matrix_1.6-5       vip_0.4.1          finetune_1.2.0     doParallel_1.0.17  iterators_1.0.14   foreach_1.5.2      healthyR.ai_0.0.13 stringr_1.5.1      reshape2_1.4.4    
[11] yardstick_1.3.1    workflowsets_1.1.0 workflows_1.1.4    tune_1.2.0         tidyr_1.3.1        tibble_3.2.1       rsample_1.2.1      recipes_1.0.10     purrr_1.0.2        parsnip_1.2.1     
[21] modeldata_1.3.0    infer_1.0.7        ggplot2_3.5.0      dplyr_1.1.4        dials_1.2.1        scales_1.3.0       broom_1.0.5        tidymodels_1.2.0  

loaded via a namespace (and not attached):
 [1] rlang_1.1.3         magrittr_2.0.3      furrr_0.3.1         compiler_4.3.3      vctrs_0.6.5         lhs_1.1.6           pkgconfig_2.0.3     shape_1.4.6.1       fastmap_1.1.1       backports_1.4.1    
[11] ellipsis_0.3.2      utf8_1.2.4          promises_1.3.0      rmarkdown_2.26      prodlim_2023.08.28  xfun_0.43           later_1.3.2         tweenr_2.0.3        prettyunits_1.2.0   R6_2.5.1           
[21] stringi_1.8.3       parallelly_1.37.1   rpart_4.1.23        lubridate_1.9.3     Rcpp_1.0.12         knitr_1.45          future.apply_1.11.2 httpuv_1.6.15       splines_4.3.3       nnet_7.3-19        
[31] timechange_0.3.0    tidyselect_1.2.1    rstudioapi_0.16.0   yaml_2.3.8          timeDate_4032.109   codetools_0.2-19    miniUI_0.1.1.1      listenv_0.9.1       lattice_0.22-5      plyr_1.8.9         
[41] shiny_1.8.1.1       withr_3.0.0         evaluate_0.23       future_1.33.2       survival_3.5-8      polyclip_1.10-6     pillar_1.9.0        generics_0.1.3      munsell_0.5.1       globals_0.16.3     
[51] xtable_1.8-4        class_7.3-22        glue_1.7.0          tools_4.3.3         data.table_1.15.4   gower_1.0.1         cowplot_1.1.3       grid_4.3.3          ipred_0.9-14        colorspace_2.1-0   
[61] ggforce_0.4.2       cli_3.6.2           DiceDesign_1.10     fansi_1.0.6         lava_1.8.0          gtable_0.3.4        GPfit_1.0-8         digest_0.6.35       farver_2.1.1        htmltools_0.5.8.1  
[71] lifecycle_1.0.4     hardhat_1.3.1       mime_0.12           ggExtra_0.10.1      MASS_7.3-60.0.1   
# Session Info HPC
> sessionInfo()
R version 4.2.0 (2022-04-22)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: CentOS Linux 7 (Core)
Matrix products: default
BLAS/LAPACK: /share/software/user/open/openblas/0.3.10/lib/libopenblas_haswellp-r0.3.10.so
locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
attached base packages:
[1] parallel  stats     graphics  grDevices utils     datasets  methods  
[8] base     
other attached packages:
 [1] glmnet_4.1-8       Matrix_1.4-1       vip_0.4.1          finetune_1.1.0    
 [5] doParallel_1.0.17  iterators_1.0.14   foreach_1.5.2      healthyR.ai_0.0.13
 [9] stringr_1.5.0      yardstick_1.2.0    workflowsets_1.0.1 workflows_1.1.3   
[13] tune_1.1.2         tidyr_1.3.0        tibble_3.2.1       rsample_1.2.0     
[17] recipes_1.0.8      purrr_1.0.2        parsnip_1.1.1      modeldata_1.2.0   
[21] infer_1.0.5        ggplot2_3.4.4      dplyr_1.1.3        dials_1.2.0       
[25] scales_1.2.1       broom_1.0.5        tidymodels_1.1.1   reshape2_1.4.4    
loaded via a namespace (and not attached):
 [1] Rcpp_1.0.11         lubridate_1.9.3     lattice_0.20-45    
 [4] listenv_0.9.0       prettyunits_1.2.0   class_7.3-20       
 [7] digest_0.6.29       ipred_0.9-14        utf8_1.2.2         
[10] parallelly_1.36.0   R6_2.5.1            plyr_1.8.9         
[13] backports_1.4.1     hardhat_1.3.0       pillar_1.9.0       
[16] rlang_1.1.1         rstudioapi_0.15.0   data.table_1.14.8  
[19] DiceDesign_1.9      furrr_0.3.1         rpart_4.1.16       
[22] splines_4.2.0       gower_1.0.1         munsell_0.5.0      
[25] compiler_4.2.0      pkgconfig_2.0.3     shape_1.4.6        
[28] globals_0.16.2      nnet_7.3-17         tidyselect_1.2.0   
[31] prodlim_2023.08.28  codetools_0.2-18    GPfit_1.0-8        
[34] fansi_1.0.3         future_1.33.0       withr_2.5.0        
[37] MASS_7.3-56         grid_4.2.0          gtable_0.3.4       
[40] lifecycle_1.0.3     magrittr_2.0.3      future.apply_1.11.0
[43] cli_3.6.1           stringi_1.7.8       timeDate_4022.108  
[46] ellipsis_0.3.2      lhs_1.1.6           generics_0.1.3     
[49] vctrs_0.6.4         lava_1.7.2.1        tools_4.2.0        
[52] glue_1.6.2          survival_3.3-1      timechange_0.2.0   
[55] colorspace_2.1-0   
github-actions[bot] commented 4 months ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.