markjrieke / workboots

Generate bootstrap :hiking_boot: prediction intervals from a tidymodels workflow!
https://markjrieke.github.io/workboots/

general speed improvements needed #39

Open markjrieke opened 2 years ago

markjrieke commented 2 years ago

Right now, workboots doesn't do anything special in terms of parallelization for speed improvements. Replacing the purrr::map() functions with furrr::future_map() functions under the hood doesn't actually speed things up (I believe this may be because parsnip sets separate environments when fitting/predicting, but I am totally out of my depth there). Some model types (e.g., boosted models) are pretty inherently sequential, but I feel like there's potential for speed improvements via parallelization ~somewhere~.
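
For illustration, here's a minimal sketch of the map-to-future_map swap outside of workboots, with a plain lm() fit standing in for the parsnip workflow (this is hypothetical, not the package's internal code):

library(furrr)
library(rsample)

future::plan(future::multisession, workers = 4)

boots <- bootstraps(mtcars, times = 100)

# sequential: one model fit + prediction per bootstrap resample
preds_seq <- purrr::map(boots$splits, function(split) {
  fit <- lm(qsec ~ wt, data = analysis(split))
  predict(fit, newdata = mtcars)
})

# parallel: identical code with map() swapped for future_map()
preds_par <- future_map(boots$splits, function(split) {
  fit <- lm(qsec ~ wt, data = analysis(split))
  predict(fit, newdata = mtcars)
}, .options = furrr_options(seed = TRUE))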

Any help here would be appreciated

spsanderson commented 2 years ago

I am looking to see what kind of speedup I get from using modeltime::parallel_start() (which I think uses the parallel package under the hood) and from the tidymodels resampling framework:

rsamp <- bootstraps(mtcars, times = 2000)
keep_pred <- control_resamples(
  save_pred = TRUE, 
  save_workflow = TRUE, 
  parallel_over = "everything"
)
spsanderson commented 2 years ago

This is what I have so far:

Libs

pacman::p_load(
  "tidyverse",
  "tidymodels",
  "modeltime",
  "timetk",
  "workboots",
  "tictoc"
)

No parallel processing

> wf <- workflow() %>%
+   add_recipe(recipe(qsec ~ wt, data = mtcars)) %>%
+   add_model(linear_reg())
+
> tic()
> set.seed(123)
> output <- wf %>%
+   predict_boots(n = 2000, training_data = mtcars, new_data = mtcars)
There were 50 or more warnings (use warnings() to see the first 50)
> toc()
329.85 sec elapsed

Modeltime parallel_start()

> parallel_start(5)
> tic()
> set.seed(123)
> output <- wf %>%
+   predict_boots(n = 2000, training_data = mtcars, new_data = mtcars)
There were 50 or more warnings (use warnings() to see the first 50)
> toc()
300.63 sec elapsed
> parallel_stop()

Only a ~29-second speedup, which tells me something was probably cached and that true parallel processing is probably not taking place.

Tidymodels approach

rsamp <- bootstraps(mtcars, times = 2000)
keep_pred <- control_resamples(
  save_pred = TRUE, 
  save_workflow = TRUE, 
  parallel_over = "everything"
)
tic()
wf %>% 
  fit_resamples(rsamp, control = keep_pred) %>%
  collect_predictions()
toc()

# A tibble: 23,077 x 5
   id            .pred  .row  qsec .config             
   <chr>         <dbl> <int> <dbl> <chr>               
 1 Bootstrap0001  18.2     5  17.0 Preprocessor1_Model1
 2 Bootstrap0001  18.2     7  15.8 Preprocessor1_Model1
 3 Bootstrap0001  18.2    12  17.4 Preprocessor1_Model1
 4 Bootstrap0001  18.2    13  17.6 Preprocessor1_Model1
 5 Bootstrap0001  18.2    14  18   Preprocessor1_Model1
 6 Bootstrap0001  18.1    17  17.4 Preprocessor1_Model1
 7 Bootstrap0001  18.4    19  18.5 Preprocessor1_Model1
 8 Bootstrap0001  18.2    24  15.4 Preprocessor1_Model1
 9 Bootstrap0001  18.2    25  17.0 Preprocessor1_Model1
10 Bootstrap0001  18.3    26  18.9 Preprocessor1_Model1
# ... with 23,067 more rows
Warning messages:
1: In names(x) : closing unused connection 7 (<-FIN-MS-05.BMHMC.ORG:11948)
2: In names(x) : closing unused connection 6 (<-FIN-MS-05.BMHMC.ORG:11948)
3: In names(x) : closing unused connection 5 (<-FIN-MS-05.BMHMC.ORG:11948)
4: In names(x) : closing unused connection 4 (<-FIN-MS-05.BMHMC.ORG:11948)
5: In names(x) : closing unused connection 3 (<-FIN-MS-05.BMHMC.ORG:11948)
> toc()
366.41 sec elapsed

This performed even worse than the base run, so some type of parallel::parLapply() call probably needs to happen inside the function itself.
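
As a rough sketch of that idea, with lm() standing in for the workflow (not the workboots internals):

library(parallel)

cl <- makeCluster(5)
clusterSetRNGStream(cl, 123)

# each worker refits on a bootstrap resample and predicts for all of mtcars
boot_preds <- parLapply(cl, seq_len(2000), function(i) {
  idx <- sample(nrow(mtcars), replace = TRUE)
  fit <- lm(qsec ~ wt, data = mtcars[idx, ])
  predict(fit, newdata = mtcars)
})

stopCluster(cl)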

Here is another approach using furrr::future_map():

tic()
set.seed(123)
furrr_test <- furrr::future_map(
  .x = rsamp,
  .f = ~ wf %>% fit_resamples(rsamp, control = keep_pred)
)
toc()

773.21 sec elapsed
> furrr_test
$splits
# Resampling results
# Bootstrap sampling 
# A tibble: 2,000 x 5
   splits          id            .metrics         .notes           .predictions     
   <list>          <chr>         <list>           <list>           <list>           
 1 <split [32/11]> Bootstrap0001 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [11 x 4]>
 2 <split [32/9]>  Bootstrap0002 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [9 x 4]> 
 3 <split [32/10]> Bootstrap0003 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [10 x 4]>
 4 <split [32/14]> Bootstrap0004 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [14 x 4]>
 5 <split [32/11]> Bootstrap0005 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [11 x 4]>
 6 <split [32/8]>  Bootstrap0006 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [8 x 4]> 
 7 <split [32/12]> Bootstrap0007 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [12 x 4]>
 8 <split [32/13]> Bootstrap0008 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [13 x 4]>
 9 <split [32/15]> Bootstrap0009 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [15 x 4]>
10 <split [32/12]> Bootstrap0010 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [12 x 4]>
# ... with 1,990 more rows

$id
# Resampling results
# Bootstrap sampling 
# A tibble: 2,000 x 5
   splits          id            .metrics         .notes           .predictions     
   <list>          <chr>         <list>           <list>           <list>           
 1 <split [32/11]> Bootstrap0001 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [11 x 4]>
 2 <split [32/9]>  Bootstrap0002 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [9 x 4]> 
 3 <split [32/10]> Bootstrap0003 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [10 x 4]>
 4 <split [32/14]> Bootstrap0004 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [14 x 4]>
 5 <split [32/11]> Bootstrap0005 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [11 x 4]>
 6 <split [32/8]>  Bootstrap0006 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [8 x 4]> 
 7 <split [32/12]> Bootstrap0007 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [12 x 4]>
 8 <split [32/13]> Bootstrap0008 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [13 x 4]>
 9 <split [32/15]> Bootstrap0009 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [15 x 4]>
10 <split [32/12]> Bootstrap0010 <tibble [2 x 4]> <tibble [0 x 3]> <tibble [12 x 4]>
# ... with 1,990 more rows

So ~773 seconds

Even with what appears to be proper use of a multisession plan, this method still takes ~396 seconds:

future::plan(future::multisession, workers = 5)
tic()
set.seed(123)
furrr_test <- furrr::future_map(
  .x = rsamp,
  .f = ~ wf %>% fit_resamples(rsamp, control = keep_pred)
)
toc()

You may want to look into doParallel, since the bootstrapping is a loop process in your code.
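
For reference, a minimal doParallel/foreach version of the same loop (again with lm() standing in for the workflow, just as a sketch):

library(doParallel)

cl <- makeCluster(5)
registerDoParallel(cl)

# each iteration refits on a bootstrap resample and returns predictions for mtcars
boot_preds <- foreach(i = 1:2000, .combine = cbind) %dopar% {
  idx <- sample(nrow(mtcars), replace = TRUE)
  fit <- lm(qsec ~ wt, data = mtcars[idx, ])
  predict(fit, newdata = mtcars)
}

stopCluster(cl)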

Also take a look here: https://www.r-bloggers.com/2015/12/speeding-up-the-cluster-bootstrap-in-r/

markjrieke commented 2 years ago

Thanks for digging into this in pretty holistic detail! I'll definitely look into the resource you listed at the end. One thing I stumbled across in furrr's documentation is this nugget (https://furrr.futureverse.org/articles/articles/gotchas.html#package-development) on common package development errors --- it's possible that the future test I ran didn't offer any improvement because of the devtools::load_all() issue mentioned in the docs (though, given the last example you provided with furrr, it may still not improve speed). We'll see --- still lots to explore!

spsanderson commented 2 years ago

No problem at all


spsanderson commented 2 years ago

You may want to look at the boot library, as it has a function cv.glm() that works with glm(), which I suppose would be a better choice than lm().
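
For context, a minimal sketch of a case-resampling bootstrap with the boot package, which also has built-in parallel/ncpus arguments (glm() here is just an example statistic, not the workboots approach):

library(boot)

# statistic: refit on the bootstrap sample, predict for all of mtcars
boot_stat <- function(data, idx) {
  fit <- glm(qsec ~ wt, data = data[idx, ])
  predict(fit, newdata = mtcars)
}

set.seed(123)
b <- boot(mtcars, boot_stat, R = 2000, parallel = "snow", ncpus = 5)
dim(b$t)  # 2000 x 32 matrix of bootstrap predictions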

markjrieke commented 2 years ago

Some thoughts on speed

The longest process is fitting & predicting. We just need, however, enough fits for a reasonable estimate of the residual standard deviation. Using the notation of algorithm 6.4 from Bootstrap Methods and their Application, workboots currently creates $R$ models, then samples $M = 1$ residuals for each observation. The book notes that it's reasonable to set $M = 1$ (which is why workboots is set up this way), but so long as $RM$ is sufficiently large the method still holds.

We could possibly cut down the default number of models to create and ramp up sampling from the residual distribution (probably keeping the total at ~2000). In Improvements on Cross Validation: the Bootstrap 0.632+ Method, Efron and Tibshirani use 50 resamples for their estimation of prediction error, though I really need to dig further on this to be sure. If that works out & is consistent with the current setup, we can make the switch (& possibly keep the models for new predictions?) as part of the 0.3.0 release. Lots to unpack there.
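
As a rough illustration of the $R$-versus-$M$ trade-off with a plain lm() residual bootstrap (not the workboots implementation; the 50/40 split is just an example so that $RM = 2000$):

library(purrr)

r_models <- 50   # R: bootstrap model fits
m_resid  <- 40   # M: residual draws per fit, so R * M = 2000

sims <- map(seq_len(r_models), function(i) {
  idx  <- sample(nrow(mtcars), replace = TRUE)
  fit  <- lm(qsec ~ wt, data = mtcars[idx, ])
  pred <- predict(fit, newdata = mtcars)
  # M simulated prediction vectors from this single fit
  map(seq_len(m_resid), ~ pred + sample(resid(fit), length(pred), replace = TRUE))
})

sims <- flatten(sims)  # R * M = 2000 simulated prediction vectors
quantile(map_dbl(sims, 1), c(0.025, 0.975))  # interval for the first row of mtcars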

spsanderson commented 2 years ago

See the example at this link; it works pretty fast with 2,000 bootstrap samples:

https://www.tidymodels.org/learn/statistics/bootstrap/

This too, maybe with modelr:

https://padpadpadpad.github.io/post/bootstrapping-non-linear-regressions-with-purrr/

Also, this works fast but has its own issues (https://stats.stackexchange.com/questions/226565/bootstrap-prediction-interval):

library(tidyverse)
library(tictoc)

n <- 2000
pred <- numeric(0)
df <- mtcars
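# note: df_test (a 7-row held-out subset of mtcars, per the output below) is assumed
# to be defined earlier -- it is not created in this snippet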
> tic()
> for (i in 1:n){
+   boot <- sample(nrow(df), n, replace = TRUE, orig.ids = TRUE)
+   fit <- lm(mpg ~ wt, data = df[boot,])
+   pred[i] <- predict(fit, newdata = df_test) %>%
+     as.data.frame() %>%
+     rownames_to_column() %>%
+     as_tibble() %>%
+     rename(".pred" = ".") %>%
+     mutate(.pred = .pred + sample(resid(fit), size = nrow(df_test))) %>%
+     nest(data = everything())
+ }
> toc()
16.54 sec elapsed
> pred_df <- map_dfr(pred, ~ .x[[1]] %>% as_tibble)
> pred_df %>%
+   group_by(rowname) %>%
+   summarise(mean_pred = mean(.pred))
# A tibble: 7 × 2
  rowname        mean_pred
  <chr>              <dbl>
1 AMC Javelin         18.8
2 Fiat 128            25.6
3 Ford Pantera L      20.2
4 Hornet 4 Drive      20.0
5 Merc 230            20.4
6 Merc 280            19.0
7 Valiant             18.9

The issue, I think, is going to be getting the exact model call out of the workflow and being able to pass it to predict().

topepo commented 10 months ago

I’m finally getting around to using this and parallelism was going to be a suggestion.

I think that using furrr or futures is a good idea; it should generate significant speedups. It is an “embarrassingly parallel” computing problem and it certainly helps during resampling in tune, finetune, and so on.

We have some APIs that might help with the worker environments. required_pkgs() can be used on the workflow to get the tidymodels and engine package requirements. That list can be used to load packages on the workers (as well as to check that they are installed).
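
A minimal sketch of how that could look with furrr, assuming the wf workflow from the examples above (the exact wiring into predict_boots() is still open):

pkgs <- required_pkgs(wf)  # tidymodels + engine packages needed by the workflow

future::plan(future::multisession, workers = 5)

# add workflows so fit.workflow is available on the workers
opts <- furrr::furrr_options(packages = c("workflows", pkgs), seed = TRUE)

boots <- rsample::bootstraps(mtcars, times = 2000)

# fit the workflow on each resample, loading the required packages on each worker
fits <- furrr::future_map(
  boots$splits,
  ~ fit(wf, data = rsample::analysis(.x)),
  .options = opts
)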