Issue with future.callr and tuning a grid via tidymodels #28

Open issue, opened by WouterDH-UZL 7 months ago

WouterDH-UZL commented 7 months ago

Describe the bug

I wasn't sure whether to post this here or in the tidymodels repo, but future.callr does not seem to work with tune::tune_grid(). I've been using future.callr to keep grid searches and model fitting from holding on to all GPU memory, and it has worked well so far. However, combining this approach with the tidymodels framework leads to an error.
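
The memory isolation relied on here comes from how the callr backend works: each future is evaluated in a fresh background R process, which ends once the future's value has been collected, so memory held by that process is released when it exits. A minimal sketch, assuming only that future and future.callr are installed, that checks the futures really run outside the main session:

```r
library(future)
library(future.callr)

# Each callr future runs in a brand-new background R process
plan(callr, workers = 2)

main_pid <- Sys.getpid()
# Resolve two futures, each reporting the PID of the process that evaluated it
worker_pids <- vapply(1:2, function(i) value(future(Sys.getpid())), integer(1))

plan(sequential)  # restore the default plan
```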

Reproducible example

```r
library(tidymodels)
library(future)
library(future.callr)
library(doFuture)

# tune 1.1.x parallelizes through foreach, so register the future-based backend
registerDoFuture()

#~ Model specification
xgb_build <- boost_tree(
  trees = 30L,
  stop_iter = 3L,
  tree_depth = 2L,
  min_n = 3L,
  loss_reduction = tune(),
  sample_size = tune(),
  mtry = tune(),
  learn_rate = tune()
) %>%
  set_engine("xgboost", booster = "gbtree", objective = "reg:squarederror") %>%
  set_mode("regression")

#~ Generate recipe
xgb_recipe <- recipe(
  Sepal.Length ~ .,
  data = iris
) %>%
  # assumed step: xgboost needs numeric predictors, so dummy-encode Species
  step_dummy(all_nominal_predictors())

#~ Generate hypergrid (needs one column per tuned parameter)
xgb_hypergrid <- grid_latin_hypercube(
  loss_reduction(),
  sample_size = sample_prop(c(0.7, 1)),
  finalize(mtry(), iris),
  learn_rate(),
  size = 20
)

xgb_wf <- workflow() %>%
  add_recipe(xgb_recipe) %>%
  add_model(xgb_build)

xgb_folds <- vfold_cv(
  iris,
  v = 5,
  repeats = 3
)

# 'multisession' works
plan(multisession, workers = 4)
xgb_tune_multisession <- tune_grid(
  xgb_wf,
  resamples = xgb_folds,
  grid = xgb_hypergrid,
  metrics = metric_set(rmse),
  control = control_grid(extract = extract_fit_engine, allow_par = TRUE)
)

# 'cluster' works
cl <- parallel::makeCluster(4)
plan(cluster, workers = cl)
xgb_tune_cluster <- tune_grid(
  xgb_wf,
  resamples = xgb_folds,
  grid = xgb_hypergrid,
  metrics = metric_set(rmse),
  control = control_grid(extract = extract_fit_engine, allow_par = TRUE)
)

# 'callr' fails
plan(callr, workers = 4)
xgb_tune <- tune_grid(
  xgb_wf,
  resamples = xgb_folds,
  grid = xgb_hypergrid,
  metrics = metric_set(rmse),
  control = control_grid(extract = extract_fit_engine, allow_par = TRUE)
)
```

Using tune_grid() with plan(callr) leads to an error that seems to be due to a namespacing issue:

```
Error in UseMethod("load_pkgs"): no applicable method for 'load_pkgs' applied to an object of class "function"
```

Expected behavior

tune::tune_grid() supports parallel processing with futures (through doFuture), so I expected plan(callr) to work just as the other plan() calls do.
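
To separate a backend problem from a tune problem, one could check that the foreach, doFuture and callr chain works on its own. A minimal sketch, assuming (as the attached packages above suggest) that tune 1.1.x reaches the futures through foreach with doFuture registered as the %dopar% backend; the loop itself is purely illustrative:

```r
library(foreach)
library(doFuture)
library(future.callr)

# Route foreach's %dopar% through futures, then select the callr backend
registerDoFuture()
plan(callr, workers = 2)

# A trivial loop with no package-level globals; doFuture chunks the
# iterations and evaluates each chunk as a callr future
squares <- foreach(i = 1:4, .combine = c) %dopar% {
  i^2
}

plan(sequential)  # restore the default plan
```

If a loop like this succeeds, the failure is more likely tied to how tune's loop body and its globals reach the callr workers than to future.callr in general.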

Session information

```
R version 4.3.1 (2023-06-16)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.6 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0 
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8     LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                  LC_ADDRESS=C               LC_TELEPHONE=C             LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

time zone: Europe/Brussels
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] yardstick_1.2.0    workflowsets_1.0.1 workflows_1.1.3    tune_1.1.2         tidyr_1.3.0        tibble_3.2.1       rsample_1.2.0      recipes_1.0.9      purrr_1.0.2       
[10] parsnip_1.1.1      modeldata_1.2.0    infer_1.0.5        ggplot2_3.4.4      dplyr_1.1.4        dials_1.2.0        scales_1.2.1       broom_1.0.5        tidymodels_1.1.1  
[19] future.callr_0.8.2 doFuture_1.0.1     foreach_1.5.2      future_1.33.1     

loaded via a namespace (and not attached):
 [1] gtable_0.3.4        processx_3.8.3      lattice_0.22-5      callr_3.7.3         vctrs_0.6.4         tools_4.3.1         ps_1.7.5            generics_0.1.3     
 [9] parallel_4.3.1      fansi_1.0.5         pkgconfig_2.0.3     Matrix_1.6-1.1      data.table_1.14.10  lhs_1.1.6           GPfit_1.0-8         lifecycle_1.0.3    
[17] compiler_4.3.1      munsell_0.5.0       codetools_0.2-19    DiceDesign_1.10     class_7.3-22        prodlim_2023.08.28  pillar_1.9.0        furrr_0.3.1        
[25] MASS_7.3-60         gower_1.0.1         iterators_1.0.14    rpart_4.1.21        parallelly_1.36.0   lava_1.7.3          tidyselect_1.2.0    digest_0.6.34      
[33] listenv_0.9.0       splines_4.3.1       grid_4.3.1          colorspace_2.1-0    cli_3.6.1           magrittr_2.0.3      survival_3.5-7      utf8_1.2.3         
[41] future.apply_1.11.1 withr_2.5.1         backports_1.4.1     lubridate_1.9.3     timechange_0.2.0    globals_0.16.2      nnet_7.3-19         timeDate_4032.109  
[49] hardhat_1.3.0       rlang_1.1.3         Rcpp_1.0.11         glue_1.6.2          ipred_0.9-14        rstudioapi_0.15.0   R6_2.5.1           
```

simonpcouch commented 4 months ago

Dropping in to confirm that I see this as well, and that it is possibly related to #23 and HenrikBengtsson/doFuture#82. The issue in this case is that load_pkgs() is called as load_pkgs(workflow), where workflow is a local variable, but workflow is being passed as workflows::workflow. We see the aforementioned error rather than object '' not found because, even though the worker isn't able to find the local variable, it can find the function exported from workflows.
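
The lookup behaviour described above can be reproduced in base R without tune or futures. A minimal sketch, where load_pkgs_demo() is a hypothetical stand-in for tune's internal load_pkgs() and plain environments stand in for the workflows namespace, the main session and the worker:

```r
# Hypothetical generic standing in for tune's internal load_pkgs()
load_pkgs_demo <- function(x, ...) UseMethod("load_pkgs_demo")
load_pkgs_demo.workflow <- function(x, ...) "dispatched on a workflow object"

# The call the worker is asked to evaluate
expr <- quote(load_pkgs_demo(workflow))

# Stand-in for the workflows namespace: it exports a *function* named workflow()
pkg_env <- new.env(parent = environment())
pkg_env$workflow <- function() structure(list(), class = "workflow")

# Main session: a local object named `workflow` shadows the exported function
main_env <- new.env(parent = pkg_env)
main_env$workflow <- structure(list(), class = "workflow")
in_main <- eval(expr, main_env)

# Worker: the local object was not exported, so the name falls through to the
# function, and S3 dispatch fails on class "function"
worker_env <- new.env(parent = pkg_env)
in_worker <- tryCatch(eval(expr, worker_env), error = conditionMessage)
```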