stan-dev / projpred

Projection predictive variable selection
https://mc-stan.org/projpred/
Other
110 stars 26 forks source link

CV parallelization #422

Closed fweber144 closed 1 year ago

fweber144 commented 1 year ago

This addresses the major concern from #77, namely parallelizing the fold-wise searches and performance evaluations. In contrast to #77, the fold-wise searches and performance evaluations are also parallelized in the case of K-fold CV (so, in short, `validate_search = TRUE` is parallelized). The idea from the comment in #77 that was linked here (link lost in this copy) is not implemented yet.

For details, see the commit messages and the NEWS.md entry added here.

fweber144 commented 1 year ago

Illustration:

# Data --------------------------------------------------------------------

# Use only the first 24 observations of projpred's `df_gaussian` example data.
data("df_gaussian", package = "projpred")
df_gaussian <- df_gaussian[1:24, ]
dat <- data.frame(y = df_gaussian$y, df_gaussian$x)

# Assign the observations to `n_grp` groups of consecutive observations.
# `length = nrow(dat)` guarantees that the factor has exactly one element per
# row; without it, `gl()` returns `n_grp * k` elements and the assignment fails
# whenever `nrow(dat)` is not a multiple of `n_grp`. For 24 rows and 8 groups,
# this gives 8 groups of 3 observations each. If `nrow(dat)` is not a multiple
# of `n_grp`, the last group(s) are smaller (or empty).
n_grp <- 8L
dat$grp <- gl(n = n_grp, k = ceiling(nrow(dat) / n_grp), length = nrow(dat),
              labels = paste0("gr", seq_len(n_grp)))

# Add true group-specific intercepts, drawn from N(0, 6^2), to the response so
# that the data contain the group-level effects modeled by `(1 | grp)` in the
# reference model.
set.seed(457211)
grp_icpts_truth <- rnorm(nlevels(dat$grp), sd = 6)
dat$y <- dat$y + grp_icpts_truth[as.integer(dat$grp)]

# Reference model ---------------------------------------------------------

library(brms)
# `mc.cores` sets the number of cores on which brms may run the MCMC chains in
# parallel. The cmdstanr option only matters if the cmdstanr backend is used;
# it makes cmdstanr write its Stan files to the current working directory.
options(
  mc.cores = parallel::detectCores(logical = FALSE),
  cmdstanr_write_stan_file_dir = "."
)
# Reference model: population-level effects for X1, X2, and X3 plus a varying
# intercept per group. The fixed seed makes the fit reproducible, and
# `refresh = 0` suppresses the sampling progress output.
rfit <- brm(
  formula = y ~ X1 + X2 + X3 + (1 | grp),
  data = dat,
  iter = 1000,
  refresh = 0,
  seed = 1140350788
)

# projpred ----------------------------------------------------------------

## Setup ------------------------------------------------------------------

library(projpred)
# Print warnings as soon as they occur. The previous value of the `warn` option
# is kept in `warn_orig` so that it can be restored in the teardown.
warn_orig <- options(warn = 1)

# Number of parallel workers to register:
ncores <- 4L

# Choose the foreach backend: doParallel on Unix-alikes, doFuture on Windows.
if (!identical(.Platform$OS.type, "windows")) {
  dopar_backend <- "doParallel"
} else {
  dopar_backend <- "doFuture"
}
switch(
  dopar_backend,
  doParallel = {
    doParallel::registerDoParallel(ncores)
  },
  doFuture = {
    doFuture::registerDoFuture()
    # Restrict the globals exported by doFuture to those given in foreach's
    # `.export` argument (see the doFuture documentation for the possible
    # values of this option). The previous value is restored in the teardown.
    export_default <- options(doFuture.foreach.export = ".export")
    # export_default <- options(
    #   doFuture.foreach.export = ".export-and-automatic-with-warning"
    # )
    # With the backend choice above, this block only runs on Windows, so the
    # "multicore" branch is only reached if `dopar_backend` is set manually.
    if (!identical(.Platform$OS.type, "windows")) {
      future_plan <- "multicore"
    } else {
      # future_plan <- "multisession"
      future_plan <- "callr"
    }
    switch(
      future_plan,
      multicore = future::plan(future::multicore, workers = ncores),
      multisession = future::plan(future::multisession, workers = ncores),
      callr = future::plan(future.callr::callr, workers = ncores),
      stop("Unrecognized `future_plan`.")
    )
  },
  stop("Unrecognized `dopar_backend`.")
)

# Check that the registered backend really provides `ncores` workers. Compare
# numerically: `identical()` would also fail if one of the two values were
# stored as a double and the other as an integer.
if (!isTRUE(foreach::getDoParWorkers() == ncores)) {
  stop("The registered foreach backend does not provide `ncores` (= ",
       ncores, ") workers.")
}

### On Unix systems, it might make sense to limit memory usage (note that when
### forking, this applies to each of the parallel worker sessions; furthermore,
### all of these parallel worker sessions "inherit" the memory usage from the
### main session):
# unix::rlimit_as(1.3 * 1e9)
###

## Run cv_varsel() --------------------------------------------------------

### K-fold ----------------------------------------------------------------

# Sequential run (no `parallel = TRUE`):
system.time({
  cvvs_K_seq <- cv_varsel(
    rfit,
    cv_method = "kfold",
    K = 4,
    seed = 5602
  )
})
## Time (seconds):
# 287.566
##

# Parallel run (same settings and seed as above, plus `parallel = TRUE`, which
# distributes the fold-wise searches and performance evaluations across the
# workers of the foreach backend registered in the setup):
system.time({
  cvvs_K_par <- cv_varsel(
    rfit,
    cv_method = "kfold",
    K = 4,
    seed = 5602,
    parallel = TRUE
  )
})
## Time (seconds):
# 100.314
##
# (Recorded speedup: 287.566 / 100.314, i.e., about 2.9 times faster.)

### PSIS-LOO --------------------------------------------------------------

# Sequential run (default `cv_method`):
system.time({
  cvvs_loo_seq <- cv_varsel(
    rfit,
    seed = 5602
  )
})
## Time (seconds):
# 1767.823
##

# Parallel run:
system.time({
  cvvs_loo_par <- cv_varsel(
    rfit,
    seed = 5602,
    parallel = TRUE
  )
})
## Time (seconds):
# 565.725
##
# (Recorded speedup: 1767.823 / 565.725, i.e., about 3.1 times faster.)

## Teardown ---------------------------------------------------------------

# Shut down the parallel backend from the setup and restore the options that
# were changed there.
switch(
  dopar_backend,
  doParallel = doParallel::stopImplicitCluster(),
  doFuture = {
    future::plan(future::sequential)
    options(export_default)
    rm(export_default)
  },
  stop("Unrecognized `dopar_backend`.")
)
# Reset foreach to its sequential backend so that any later `%dopar%` loop in
# this session (e.g., from `parallel = TRUE`) does not keep using the parallel
# backend registered in the setup.
foreach::registerDoSEQ()

options(warn_orig)
rm(warn_orig)