tidymodels / tune

Tools for tidy parameter tuning
https://tune.tidymodels.org

Parallel Speed-Ups Not As Expected #482

Closed pgoodling-usgs closed 2 years ago

pgoodling-usgs commented 2 years ago

Hello,

I'm working on parallelizing across many resamples with a large dataset. I ran into some of the issues described in #376 and #384, and was glad to see that tune 0.2.0 included a fix (described in #397). I reproduced the example from #397 below, though I switched it to a ranger random forest model.

With 5,000 rows and 10,000 columns, I also see that with a single non-parallel worker, memory usage peaks at ~3.2-3.5 GB, and that with 8 parallel workers each worker peaks at ~1.2-1.5 GB (for a total memory usage between 9 GB and 10 GB).

However, the decrease in execution time isn't quite as advantageous as I would expect. With 24 resamples and 8 cores, each core should work through 3 iterations of the model fitting. Based on this post it seems difficult to estimate the time savings a priori, but I still expected a greater efficiency than going from 752 seconds elapsed to 558 seconds elapsed.
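As a rough sketch of the expectation described above, the arithmetic can be written out in base R. The timings and resample counts come from the reprex in this post; the "ideal" figure assumes perfectly even scheduling and zero overhead, which is an assumption that real runs will not meet.

```r
# Figures taken from the reprex in this post
n_resamples <- 3 * 8      # v = 3 folds x repeats = 8
n_workers   <- 8
seq_elapsed <- 752.33     # seconds, single worker
par_elapsed <- 558.26     # seconds, 8 workers

# Resamples each worker would handle if the work were split evenly
resamples_per_worker <- ceiling(n_resamples / n_workers)

# Idealized lower bound on elapsed time, ignoring all overhead (assumption)
ideal_elapsed <- seq_elapsed * resamples_per_worker / n_resamples

ideal_speedup    <- seq_elapsed / ideal_elapsed
observed_speedup <- seq_elapsed / par_elapsed

round(c(ideal_elapsed    = ideal_elapsed,
        ideal_speedup    = ideal_speedup,
        observed_speedup = observed_speedup), 2)
```

The gap between the ideal and observed speed-up is the quantity in question here.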

In Chapter 10 of the Tidy Modeling with R book it says that:

Parallel processing with the tune package tends to provide linear speed-ups for the first few cores. This means that, with two cores, the computations are twice as fast. Depending on the data and type of model, the linear speedup deteriorates after 4-5 cores. Using more cores will still reduce the time it takes to complete the task; there are just diminishing returns for the additional cores.

Let’s wrap up with one final note about parallelism. For each of these technologies, the memory requirements multiply for each additional core used. For example, if the current data set is 2 GB in memory and three cores are used, the total memory requirement is 8 GB (2 for each worker process plus the original). Using too many cores might cause the computations (and the computer) to slow considerably.
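The memory rule in the quoted passage can be expressed as a one-line calculation. This is only a sketch of the book's example (one copy of the data per worker plus the original); the function name is hypothetical, and actual per-worker usage depends on the model and on what is exported to each worker.

```r
# Hypothetical helper: one copy of the data per worker, plus the original
total_memory_gb <- function(data_gb, n_workers) {
  data_gb * n_workers + data_gb
}

# The example from the quoted passage: a 2 GB data set and three cores
total_memory_gb(data_gb = 2, n_workers = 3)
```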

So I expect at least a 2-fold increase in computation speed, right? Is there something about the 0.2.0 upgrade that would have affected this? Is this typical behavior or am I missing something in my configuration or my understanding?

##### Libraries #####
require( "tidymodels" )
#> Loading required package: tidymodels
#> Warning: package 'tidymodels' was built under R version 4.1.3
#> Warning: package 'dials' was built under R version 4.1.3
#> Warning: package 'dplyr' was built under R version 4.1.3
#> Warning: package 'ggplot2' was built under R version 4.1.2
#> Warning: package 'infer' was built under R version 4.1.2
#> Warning: package 'modeldata' was built under R version 4.1.2
#> Warning: package 'parsnip' was built under R version 4.1.3
#> Warning: package 'recipes' was built under R version 4.1.3
#> Warning: package 'rsample' was built under R version 4.1.2
#> Warning: package 'tibble' was built under R version 4.1.2
#> Warning: package 'tidyr' was built under R version 4.1.3
#> Warning: package 'tune' was built under R version 4.1.3
#> Warning: package 'workflows' was built under R version 4.1.3
#> Warning: package 'workflowsets' was built under R version 4.1.3
#> Warning: package 'yardstick' was built under R version 4.1.2
require( "foreach" )
#> Loading required package: foreach
#> Warning: package 'foreach' was built under R version 4.1.3
#> 
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#> 
#>     accumulate, when
require( "doParallel" )
#> Loading required package: doParallel
#> Warning: package 'doParallel' was built under R version 4.1.3
#> Loading required package: iterators
#> Warning: package 'iterators' was built under R version 4.1.3
#> Loading required package: parallel
require( "tictoc" )
#> Loading required package: tictoc
#> Warning: package 'tictoc' was built under R version 4.1.1
require( "ranger" )
#> Loading required package: ranger
#> Warning: package 'ranger' was built under R version 4.1.2
require( "tidyverse" )
#> Loading required package: tidyverse
#> Warning: package 'readr' was built under R version 4.1.3

set.seed( 4235 )

# Create data set
predictor_names <- str_c( "pred", 1:10000)
id <- 1:5000
data <- expand_grid( id, predictor_names ) %>%
  mutate( value = rnorm( n = n() ) ) %>%
  pivot_wider( names_from = predictor_names, values_from = value ) %>%
  mutate( outcome = as.factor( as.character( rbernoulli( n(), p = 0.3 ) ) ) ) %>%
  select( id, outcome, everything() )

# Set up model training parameters
n_iters <- 8
prop_training <- 0.66
n_fold_hyperparameter <- 3
lambdas_range <- 10^seq( -2, 0, by = 0.5 )

##### Set model type and engine #####
glm_model <-     rand_forest(trees = 100,
                             mtry = tune(), 
                             mode = "classification") %>%
  set_engine("ranger",
             importance="impurity", oob.error = TRUE)

##### Set data processing recipe #####
cur_recipe <- recipe( data ) %>%
  # General processing
  update_role( id, new_role = "ID" ) %>%

  # Specific outcome/predictor processing
  update_role( outcome, new_role = "outcome" ) %>%
  update_role( starts_with( "pred" ), new_role = "predictor" ) %>%
  step_zv( all_predictors() )

##### Set the workflow #####
cur_workflow <- workflow() %>%
  add_model( glm_model ) %>%
  add_recipe( cur_recipe )

##### Select training and testing data #####
cur_testing_training_splits <- initial_split( data, prop = prop_training, strata = outcome )

# Get testing data
cur_testing_data <- testing( cur_testing_training_splits )

# Get training data
cur_training_data <- training( cur_testing_training_splits )

##### Train the model #####
# Set training grid
cur_training_grid <- expand_grid(mtry=c(1,round(sqrt(ncol(cur_training_data)-1))))

# Set CV
cur_cv_folds <- vfold_cv( cur_training_data, v = n_fold_hyperparameter, repeats = n_iters, strata = outcome )

# Train model on a single core
tic( "Train model without parallelization" )
cur_training_results <- cur_workflow %>%
  tune_grid( resamples = cur_cv_folds,
             grid = cur_training_grid,
             metrics = metric_set( roc_auc, pr_auc, accuracy, npv, ppv, yardstick::sensitivity, yardstick::specificity, yardstick::precision, yardstick::recall ) )
toc()
#> Train model without parallelization: 752.33 sec elapsed

##### Train the model in parallel #####
# Initialize cores for parallel processing
ncores <- 8
cl <- makeCluster( ncores )
registerDoParallel( cl )

ctrl <- control_grid(verbose = FALSE,
                     allow_par = TRUE,
                     #pkg=c("lubridate","doParallel","tidyverse","tidymodels"),
                     parallel_over = "resamples")

# Train model with parallelization
tic( "Train model with parallelization" )
cur_training_results_parallel <- cur_workflow %>%
  tune_grid( resamples = cur_cv_folds,
             grid = cur_training_grid,
             metrics = metric_set( roc_auc, pr_auc, accuracy, npv, ppv, yardstick::sensitivity, yardstick::specificity, yardstick::precision, yardstick::recall ),
             control=ctrl)
toc()
#> Train model with parallelization: 558.26 sec elapsed

# Clean up cores
stopCluster( cl )

Created on 2022-04-14 by the reprex package (v2.0.1)

UPDATE: I ran this code using just 2 cores and wound up with a completion time of 420 seconds, with memory in each core peaking around 2.5 GB. That it runs faster on 2 cores than on 8 cores is certainly unexpected and goes against the statement that "Using more cores will still reduce the time it takes to complete the task; there are just diminishing returns for the additional cores."

Here is my SessionInfo():

- Session info --------------------------------------------------------------------------------
 setting  value
 version  R version 4.1.0 (2021-05-18)
 os       Windows 10 x64 (build 19042)
 system   x86_64, mingw32
 ui       RStudio
 language (EN)
 collate  English_United States.1252
 ctype    English_United States.1252
 tz       America/New_York
 date     2022-04-14
 rstudio  1.4.1717 Juliet Rose (desktop)
 pandoc   2.11.4 @ C:/Program Files/RStudio/bin/pandoc/ (via rmarkdown)

- Packages ------------------------------------------------------------------------------------
 package      * version    date (UTC) lib source
 assertthat     0.2.1      2019-03-21 [1] CRAN (R 4.1.2)
 backports      1.4.1      2021-12-13 [1] CRAN (R 4.1.2)
 broom        * 0.7.12     2022-01-28 [1] CRAN (R 4.1.0)
 callr          3.7.0      2021-04-20 [1] CRAN (R 4.1.0)
 cellranger     1.1.0      2016-07-27 [1] CRAN (R 4.1.0)
 class          7.3-20     2022-01-13 [1] CRAN (R 4.1.2)
 cli            3.2.0      2022-02-14 [1] CRAN (R 4.1.3)
 clipr          0.8.0      2022-02-22 [1] CRAN (R 4.1.3)
 codetools      0.2-18     2020-11-04 [1] CRAN (R 4.1.0)
 colorspace     2.0-3      2022-02-21 [1] CRAN (R 4.1.3)
 crayon         1.5.1      2022-03-26 [1] CRAN (R 4.1.3)
 DBI            1.1.2      2021-12-20 [1] CRAN (R 4.1.2)
 dbplyr         2.1.1      2021-04-06 [1] CRAN (R 4.1.0)
 dials        * 0.1.0      2022-01-31 [1] CRAN (R 4.1.3)
 DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.1.0)
 digest         0.6.29     2021-12-01 [1] CRAN (R 4.1.2)
 doParallel   * 1.0.17     2022-02-07 [1] CRAN (R 4.1.3)
 dplyr        * 1.0.8      2022-02-08 [1] CRAN (R 4.1.3)
 ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.1.0)
 evaluate       0.15       2022-02-18 [1] CRAN (R 4.1.3)
 fansi          1.0.3      2022-03-24 [1] CRAN (R 4.1.3)
 fastmap        1.1.0      2021-01-25 [1] CRAN (R 4.1.0)
 forcats      * 0.5.1      2021-01-27 [1] CRAN (R 4.1.0)
 foreach      * 1.5.2      2022-02-02 [1] CRAN (R 4.1.3)
 fs             1.5.2      2021-12-08 [1] CRAN (R 4.1.2)
 furrr          0.2.3      2021-06-25 [1] CRAN (R 4.1.2)
 future         1.24.0     2022-02-19 [1] CRAN (R 4.1.3)
 future.apply   1.8.1      2021-08-10 [1] CRAN (R 4.1.2)
 generics       0.1.2      2022-01-31 [1] CRAN (R 4.1.3)
 ggplot2      * 3.3.5      2021-06-25 [1] CRAN (R 4.1.2)
 globals        0.14.0     2020-11-22 [1] CRAN (R 4.1.0)
 glue           1.6.2      2022-02-24 [1] CRAN (R 4.1.3)
 gower          1.0.0      2022-02-03 [1] CRAN (R 4.1.2)
 GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.1.0)
 gtable         0.3.0      2019-03-25 [1] CRAN (R 4.1.0)
 hardhat        0.2.0      2022-01-24 [1] CRAN (R 4.1.2)
 haven          2.4.3      2021-08-04 [1] CRAN (R 4.1.2)
 highr          0.9        2021-04-16 [1] CRAN (R 4.1.0)
 hms            1.1.1      2021-09-26 [1] CRAN (R 4.1.2)
 htmltools      0.5.2      2021-08-25 [1] CRAN (R 4.1.2)
 httr           1.4.2      2020-07-20 [1] CRAN (R 4.1.0)
 infer        * 1.0.0      2021-08-13 [1] CRAN (R 4.1.2)
 ipred          0.9-12     2021-09-15 [1] CRAN (R 4.1.2)
 iterators    * 1.0.14     2022-02-05 [1] CRAN (R 4.1.3)
 jsonlite       1.8.0      2022-02-22 [1] CRAN (R 4.1.3)
 knitr          1.38       2022-03-25 [1] CRAN (R 4.1.3)
 lattice        0.20-45    2021-09-22 [1] CRAN (R 4.1.2)
 lava           1.6.10     2021-09-02 [1] CRAN (R 4.1.2)
 lhs            1.1.5      2022-03-22 [1] CRAN (R 4.1.3)
 lifecycle      1.0.1      2021-09-24 [1] CRAN (R 4.1.2)
 listenv        0.8.0      2019-12-05 [1] CRAN (R 4.1.0)
 lubridate      1.8.0      2021-10-07 [1] CRAN (R 4.1.2)
 magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.1.0)
 MASS           7.3-56     2022-03-23 [1] CRAN (R 4.1.3)
 Matrix         1.4-1      2022-03-23 [1] CRAN (R 4.1.3)
 modeldata    * 0.1.1      2021-07-14 [1] CRAN (R 4.1.2)
 modelr         0.1.8      2020-05-19 [1] CRAN (R 4.1.0)
 munsell        0.5.0      2018-06-12 [1] CRAN (R 4.1.0)
 nnet           7.3-17     2022-01-13 [1] CRAN (R 4.1.2)
 parallelly     1.30.0     2021-12-17 [1] CRAN (R 4.1.2)
 parsnip      * 0.2.1      2022-03-17 [1] CRAN (R 4.1.3)
 pillar         1.7.0      2022-02-01 [1] CRAN (R 4.1.3)
 pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.1.0)
 plyr           1.8.7      2022-03-24 [1] CRAN (R 4.1.3)
 pROC           1.18.0     2021-09-03 [1] CRAN (R 4.1.2)
 processx       3.5.3      2022-03-25 [1] CRAN (R 4.1.3)
 prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.1.0)
 ps             1.6.0      2021-02-28 [1] CRAN (R 4.1.0)
 purrr        * 0.3.4      2020-04-17 [1] CRAN (R 4.1.0)
 R6             2.5.1      2021-08-19 [1] CRAN (R 4.1.2)
 ranger       * 0.13.1     2021-07-14 [1] CRAN (R 4.1.2)
 Rcpp           1.0.8.3    2022-03-17 [1] CRAN (R 4.1.3)
 readr        * 2.1.2      2022-01-30 [1] CRAN (R 4.1.3)
 readxl         1.4.0      2022-03-28 [1] CRAN (R 4.1.0)
 recipes      * 0.2.0      2022-02-18 [1] CRAN (R 4.1.3)
 reprex         2.0.1      2021-08-05 [1] CRAN (R 4.1.2)
 rlang          1.0.2      2022-03-04 [1] CRAN (R 4.1.3)
 rmarkdown      2.13       2022-03-10 [1] CRAN (R 4.1.3)
 rpart          4.1.16     2022-01-24 [1] CRAN (R 4.1.2)
 rsample      * 0.1.1      2021-11-08 [1] CRAN (R 4.1.2)
 rstudioapi     0.13       2020-11-12 [1] CRAN (R 4.1.0)
 rvest          1.0.2      2021-10-16 [1] CRAN (R 4.1.2)
 scales       * 1.1.1      2020-05-11 [1] CRAN (R 4.1.0)
 sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.1.2)
 stringi        1.7.6      2021-11-29 [1] CRAN (R 4.1.2)
 stringr      * 1.4.0      2019-02-10 [1] CRAN (R 4.1.0)
 survival       3.3-1      2022-03-03 [1] CRAN (R 4.1.3)
 tibble       * 3.1.6      2021-11-07 [1] CRAN (R 4.1.2)
 tictoc       * 1.0.1      2021-04-19 [1] CRAN (R 4.1.1)
 tidymodels   * 0.2.0      2022-03-19 [1] CRAN (R 4.1.3)
 tidyr        * 1.2.0      2022-02-01 [1] CRAN (R 4.1.3)
 tidyselect     1.1.2      2022-02-21 [1] CRAN (R 4.1.3)
 tidyverse    * 1.3.1      2021-04-15 [1] CRAN (R 4.1.0)
 timeDate       3043.102   2018-02-21 [1] CRAN (R 4.1.0)
 tune         * 0.2.0      2022-03-19 [1] CRAN (R 4.1.3)
 tzdb           0.3.0      2022-03-28 [1] CRAN (R 4.1.0)
 utf8           1.2.2      2021-07-24 [1] CRAN (R 4.1.2)
 vctrs          0.4.0      2022-03-30 [1] CRAN (R 4.1.0)
 withr          2.5.0      2022-03-03 [1] CRAN (R 4.1.3)
 workflows    * 0.2.6      2022-03-18 [1] CRAN (R 4.1.3)
 workflowsets * 0.2.1      2022-03-15 [1] CRAN (R 4.1.3)
 xfun           0.30       2022-03-02 [1] CRAN (R 4.1.3)
 xml2           1.3.3      2021-11-30 [1] CRAN (R 4.1.2)
 yaml           2.3.5      2022-02-21 [1] CRAN (R 4.1.2)
 yardstick    * 0.0.9      2021-11-22 [1] CRAN (R 4.1.2)

 [1] C:/Program Files/R/R-4.1.0/library

-----------------------------------------------------------------------------------------------

topepo commented 2 years ago

752/420 = 1.8 is close to a 2-fold speedup. 2-fold is the theoretical optimum speedup and is unlikely to happen.

On your machine, what are the results of parallel::detectCores()?

The issue with the 8-core run is that using n_fold_hyperparameter <- 3 limits the number of parallel workers to 3 (see the next paragraph, though). Using 8 should not really give you a >3-fold speedup and, if this is on Windows, there might be some jumping around of processes across cores that might make it slower.

If you want to use more cores, there is an option in the control function: parallel_over = "everything". The blog post that you reference describes the differences between the options and gives some intuition as to which option is best (spoiler: it depends on what you are doing).
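For reference, a minimal sketch of the control object with that option set. It assumes the tune package is installed; the object name ctrl_everything is hypothetical, and the task count uses the 24 resamples and the 2-row mtry grid from the reprex above.

```r
library(tune)

# Same settings as the reprex, except for parallel_over
ctrl_everything <- control_grid(
  verbose = FALSE,
  allow_par = TRUE,
  parallel_over = "everything"
)

# With "everything", the units of parallel work are resample x candidate
# combinations rather than resamples alone (counts taken from the reprex)
n_resamples  <- 24
n_candidates <- 2
n_resamples * n_candidates
```

The object would then be passed as control = ctrl_everything in the tune_grid() call. One trade-off to keep in mind is that preprocessing may be repeated for each candidate under this scheme, so it is not guaranteed to be faster.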

pgoodling-usgs commented 2 years ago

Aha- I didn't realize that the # of cores would be limited to n_fold_hyperparameter; I thought it would be the overall number of folds in cur_cv_folds.

With 2 cores and parallel_over = "everything" the runtime is 559 seconds, which is slower than the 420 seconds for parallel_over = "resamples". With 8 cores and parallel_over = "everything" the runtime is 801 seconds, which is slower than the 558 seconds for parallel_over = "resamples" (and even slower than the unparallelized ~750 seconds).
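To compare the runs side by side, the timings reported in this thread can be collected into a small table. This is a base R sketch; the column names are made up for illustration, and "efficiency" here simply means speed-up divided by the number of workers.

```r
# Elapsed times as reported in this thread (seconds)
timings <- data.frame(
  workers       = c(1, 2, 8, 2, 8),
  parallel_over = c(NA, "resamples", "resamples", "everything", "everything"),
  elapsed_sec   = c(752.33, 420, 558.26, 559, 801)
)

# Speed-up relative to the sequential run, and speed-up per worker
timings$speedup    <- timings$elapsed_sec[1] / timings$elapsed_sec
timings$efficiency <- timings$speedup / timings$workers

# Fastest configuration first
timings[order(timings$elapsed_sec), ]
```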

When I run parallel::detectCores(), I get 8. So I probably shouldn't expect 8 cores to fully be engaged, right?
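One thing worth checking is whether that 8 refers to logical or physical cores, since detectCores() counts logical cores by default. The sketch below uses only the base parallel package; leaving one core free for the main session is a common convention, not a requirement, and the variable names are illustrative.

```r
library(parallel)

logical_cores  <- detectCores(logical = TRUE)   # includes hyperthreads
physical_cores <- detectCores(logical = FALSE)  # may be NA on some platforms

# Prefer the physical count; fall back to logical, then to 1
usable <- physical_cores
if (is.na(usable)) usable <- logical_cores
if (is.na(usable)) usable <- 1

# Leave one core for the main R session (a convention, not a rule)
n_workers <- max(1, usable - 1)
n_workers
```

If the physical count turns out to be 4, a cluster of 8 workers would be oversubscribing the CPU for a compute-bound model fit, which could contribute to the slowdown.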

I guess overall the speedups are simply difficult to predict and are more related to underlying R / Windows features. You've documented this in various blog posts, so it might not really count as an issue with tune.

github-actions[bot] commented 2 years ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.