While working on a reply to this Stan Discourse thread, I experienced high peak memory usage during the forward search. This PR should alleviate that. An illustration using the code that brought me there (the example dataset may be found here; also note that the following code writes some files to the current working directory):
So with this PR, we need about 1 GB less peak memory than before ("old" state: commit https://github.com/stan-dev/projpred/commit/8e067a23d1e618b131e25c0965bb8c33cac9e9de).
NB: Initially, I tried that code with ndraws = 1000 (and nclusters = NULL), but that crashed my R session and thereby my whole machine. The above code ran through on my machine (with 16 GB of RAM).
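For reference, a rough sketch of what that crashing call looked like (`rfit` is just a placeholder for the reference model fit from the Discourse example; the point is the draws-related arguments):
# Sketch only (do not run): with `ndraws = 1000` and `nclusters = NULL`, the
# search uses 1000 unclustered posterior draws, which is far more
# memory-intensive than the clustered default:
# vs_crash <- varsel(rfit, ndraws = 1000, nclusters = NULL)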
Another reprex (which doesn't require external data):
# Data --------------------------------------------------------------------
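# df_gaussian is simulated example data shipped with projpred (a response
# vector `y` and a matrix `x` of continuous predictors X1, X2, ...):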
data("df_gaussian", package = "projpred")
dat <- data.frame(y = df_gaussian$y, df_gaussian$x)
dat$group <- gl(n = 8, k = floor(nrow(dat) / 8), length = nrow(dat),
labels = paste0("gr", seq_len(8)))
set.seed(457211)
group_icpts_truth <- rnorm(nlevels(dat$group), sd = 6)
group_X1_truth <- rnorm(nlevels(dat$group), sd = 6)
icpt <- -4.2
dat$y <- icpt +
group_icpts_truth[dat$group] +
group_X1_truth[dat$group] * dat$X1
dat$y <- rnorm(nrow(dat), mean = dat$y, sd = 4)
# Make the dataset artificially long:
dat <- do.call(rbind, replicate(10, dat, simplify = FALSE))
# Split up into training and test (hold-out) dataset:
idcs_test <- sample.int(nrow(dat), size = nrow(dat) / 3)
dat_train <- dat[-idcs_test, , drop = FALSE]
dat_test <- dat[idcs_test, , drop = FALSE]
# Reference model fit -----------------------------------------------------
suppressPackageStartupMessages(library(brms))
rfit_train <- brm(
y ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + (1 | group),
data = dat_train,
chains = 1,
iter = 500,
refresh = 0,
seed = 1140350788
)
# projpred ----------------------------------------------------------------
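# Ask projpred for more detailed progress output (mainly during the search):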
options(projpred.extra_verbose = TRUE)
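# Hold-out data in the list format expected by varsel()'s `d_test` argument
# (elements `data`, `offset`, `weights`, and `y`):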
d_test_list <- list(
data = dat_test[, names(dat_test) != "y"],
offset = rep(0, nrow(dat_test)),
weights = rep(1, nrow(dat_test)),
y = dat_test[["y"]]
)
library(peakRAM)
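# Wrap the varsel() call in an expression so that exactly the same call can be
# re-evaluated by peakRAM() and by microbenchmark() below: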
vs_expr <- expression(
vs <- varsel(rfit_train,
d_test = d_test_list,
# nclusters = 50,
refit_prj = FALSE,
seed = 46782345)
)
## Previous state ---------------------------------------------------------
# With projpred at the state before this PR (commit:
# <https://github.com/stan-dev/projpred/commit/8e067a23d1e618b131e25c0965bb8c33cac9e9de>),
# run the following:
devtools::load_all()
peak_old <- replicate(20, peakRAM(eval(vs_expr)), simplify = FALSE)
quantile(sapply(peak_old, "[[", "Peak_RAM_Used_MiB"))
# 0% 25% 50% 75% 100%
# 119.20 127.15 128.95 157.30 160.30
## State of this PR -------------------------------------------------------
# For this, switch projpred to the state of this PR.
### Without gc() (default) ------------------------------------------------
# Now restart the R session, run all the code from the beginning up to
# (including) the `vs_expr` creation, then run the following:
devtools::load_all()
peak_new_nogc <- replicate(20, peakRAM(eval(vs_expr)), simplify = FALSE)
quantile(sapply(peak_new_nogc, "[[", "Peak_RAM_Used_MiB"))
# 0% 25% 50% 75% 100%
# 78.4 101.8 101.8 101.8 102.0
# Besides, some runtime benchmarking results for investigating how much runtime
# the gc() call adds on top (compare with results further below):
microbenchmark::microbenchmark(list = vs_expr, times = 20)
# Unit: seconds
# expr min lq mean median uq max neval
# vs <- varsel([...]) 36.37918 36.57373 36.84235 36.68676 36.99082 37.84213 20
### With gc() -------------------------------------------------------------
# Now restart the R session, run all the code from the beginning up to
# (including) the `vs_expr` creation, then run the following:
devtools::load_all()
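# Activate the additional gc() calls introduced by this PR (off by default):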
options(projpred.run_gc = TRUE)
peak_new_gc <- replicate(20, peakRAM(eval(vs_expr)), simplify = FALSE)
quantile(sapply(peak_new_gc, "[[", "Peak_RAM_Used_MiB"))
# 0% 25% 50% 75% 100%
# 75.1 102.1 102.1 102.1 102.3
# Besides, some runtime benchmarking results for investigating how much runtime
# the gc() call adds on top (compare with results further above):
microbenchmark::microbenchmark(list = vs_expr, times = 20)
# Unit: seconds
# expr min lq mean median uq max neval
# vs <- varsel([...]) 39.28884 39.68015 39.8375 39.76543 40.15218 40.3328 20
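Just for convenience, a quick side-by-side view of the median peak RAM values from the three runs above (numbers copied manually from the quantile() outputs because the three runs come from separate R sessions):
# Median peak RAM (MiB), copied from the outputs above:
peak_medians_mib <- c(old = 128.95, new_without_gc = 101.8, new_with_gc = 102.1)
# Relative reduction compared to the old state (in percent):
round(100 * (1 - peak_medians_mib[-1] / peak_medians_mib["old"]), 1)
# -> roughly a 21 % reduction in median peak RAM for this reprex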