stan-dev / projpred

Projection predictive variable selection
https://mc-stan.org/projpred/

Reduce peak memory usage in forward search #442

Closed: fweber144 closed this 1 year ago

fweber144 commented 1 year ago

While working on a reply to this Stan Discourse thread, I experienced high peak memory usage during the forward search. This PR should alleviate that. Here is an illustration using the code that brought me there (the example dataset may be found here; also note that the following code writes some files to the current working directory):

# Setup -------------------------------------------------------------------

library(brms)

options(warn = 1)

options(mc.cores = min(parallel::detectCores(logical = FALSE), 4))
options(brms.backend = "cmdstanr")
options(brms.file_refit = "on_change")
options(cmdstanr_write_stan_file_dir = ".")

options(projpred.extra_verbose = TRUE)
options(projpred.check_conv = TRUE)
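# (`projpred.extra_verbose` makes the search print the predictor terms chosen
# so far at each verbosity checkpoint (see its use in `search_forward_old()`
# below); `projpred.check_conv` enables convergence checks for the submodel
# fits.)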

# Data --------------------------------------------------------------------

indicator_data <- read.csv("example_dataset.csv")

# Subsample `N` observations:
N <- 3000
set.seed(123)
indicator_data_N <- indicator_data[
  sample.int(nrow(indicator_data), size = N), , drop = FALSE
]
# Avoid `:` between grouping variables:
indicator_data_N[["iso_country_code_IA_village"]] <- paste(
  indicator_data_N$iso_country_code, indicator_data_N$village, sep = "_"
)

# Reference model fit -----------------------------------------------------

rfit <- brm(
  formula = log_tva ~ 1 + log_hh_size + education_cleaned + log_livestock_tlu +
    log_land_cultivated + off_farm_any + till_not_by_hand + external_labour +
    pesticide + debts_have + aidreceived + livestock_inputs_any +
    land_irrigated_any + norm_growing_period + log_min_travel_time +
    log_pop_dens + norm_gdl_country_shdi + (1 | iso_country_code) +
    (1 | iso_country_code_IA_village),
  data = indicator_data_N,
  prior = c(
    set_prior("normal(0, 1)", class = "b"),
    set_prior("normal(0, 1)", class = "sd"),
    set_prior("normal(0, 1)", class = "sigma"),
    set_prior("normal(0, 1)", class = "Intercept")
  ),
  family = gaussian(),
  file = "rfit",
  seed = 584356,
  refresh = 0
)

# projpred ----------------------------------------------------------------

devtools::load_all("<path_to_projpred_at_state_of_this_PR>")
refmodel_obj <- get_refmodel(rfit)

search_forward_old <- function(p_ref, refmodel, nterms_max, verbose = TRUE, opt,
                               search_terms, ...) {
  nterms_max_with_icpt <- nterms_max + 1L
  iq <- ceiling(quantile(seq_len(nterms_max_with_icpt), 1:10 / 10))
  if (is.null(search_terms)) {
    stop("Did not expect `search_terms` to be `NULL`. Please report this.")
  }

  chosen <- character()
  outdmins <- c()

  for (size in seq_len(nterms_max_with_icpt)) {
    cands <- select_possible_terms_size(chosen, search_terms, size = size)
    if (is.null(cands))
      next
    full_cands <- lapply(cands, function(cand) c(chosen, cand))
    submodls <- lapply(full_cands, get_submodl_prj, p_ref = p_ref,
                       refmodel = refmodel, regul = opt$regul, ...)

    ## select best candidate
    imin <- which.min(sapply(submodls, "[[", "ce"))
    chosen <- c(chosen, cands[imin])

    ## append `outdmin`
    outdmins <- c(outdmins, list(submodls[[imin]]$outdmin))

    ct_chosen <- count_terms_chosen(chosen)
    if (verbose && ct_chosen %in% iq) {
      vtxt <- paste(names(iq)[max(which(ct_chosen == iq))], "of terms selected")
      if (getOption("projpred.extra_verbose", FALSE)) {
        vtxt <- paste0(vtxt, ": ", paste(chosen, collapse = " + "))
      }
      verb_out(vtxt)
    }
  }

  return(nlist(solution_terms = setdiff(chosen, "1"), outdmins))
}
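
## For contrast with search_forward_old() above, a rough sketch of the
## memory-saving idea (a hypothetical simplification, not the actual new
## search_forward(); see this PR's diff for that): instead of materializing
## the fits of *all* candidate models of a given size at once via lapply()
## and only then picking the best one, the candidates can be fitted one at a
## time while keeping only the running best, so every discarded fit becomes
## collectable immediately. fit_one_submodl() is a hypothetical stand-in for
## get_submodl_prj().
search_step_sketch <- function(cands, chosen, fit_one_submodl) {
  best <- NULL
  for (cand in cands) {
    submodl <- fit_one_submodl(c(chosen, cand))
    if (is.null(best) || submodl$ce < best$submodl$ce) {
      best <- list(cand = cand, submodl = submodl)
    }
    # Here, at most one non-best fit is still referenced, instead of
    # length(cands) fits as with the lapply() approach above.
  }
  best
}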

library(peakRAM)
debug(select)
vs <- varsel(
  refmodel_obj,
  seed = 1,
  # ndraws = 400, # set `nclusters` to `NULL` to use this
  nclusters = 50,
  refit_prj = FALSE,
  nterms_max = 5,
  control = lme4::lmerControl(
    optimizer = "Nelder_Mead"
  )
)
### Now step through select() in the debugger until right before search_forward() is called. Then run:
# peakRAM(search_path <- search_forward(p_sel, refmodel, nterms_max, verbose, opt,
#                                       search_terms = search_terms, ...),
#         search_path_old <- search_forward_old(p_sel, refmodel, nterms_max, verbose, opt,
#                                               search_terms = search_terms, ...))
###
## Result:
#                                Function_Call Elapsed_Time_sec  Total_RAM_Used_MiB Peak_RAM_Used_MiB
# 1         search_path<-search_forward([...])          285.503               192.2             969.9
# 2 search_path_old<-search_forward_old([...])          286.368               191.4            2068.1
##

So with this PR, peak memory usage during the forward search (column `Peak_RAM_Used_MiB`) is about 1 GiB lower than before ("old" state: commit https://github.com/stan-dev/projpred/commit/8e067a23d1e618b131e25c0965bb8c33cac9e9de), at essentially unchanged runtime.

NB: Initially, I tried that code with `ndraws = 1000` (and `nclusters = NULL`), but that crashed my R session and eventually my whole machine. The code above ran through on my machine (16 GB of RAM).
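
That crash is consistent with peak memory during the search growing with the number of projected draws (or clusters). A hedged way to gauge this at a smaller scale would be to compare two `nclusters` values directly (a sketch reusing `refmodel_obj` from above; `vs10` and `vs50` are just hypothetical names, and the call is left commented out like the `peakRAM()` comparison above):

# peakRAM(
#   vs10 <- varsel(refmodel_obj, seed = 1, nclusters = 10, refit_prj = FALSE,
#                  nterms_max = 5),
#   vs50 <- varsel(refmodel_obj, seed = 1, nclusters = 50, refit_prj = FALSE,
#                  nterms_max = 5)
# )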

Another reprex (which doesn't require external data):

# Data --------------------------------------------------------------------

data("df_gaussian", package = "projpred")
dat <- data.frame(y = df_gaussian$y, df_gaussian$x)
dat$group <- gl(n = 8, k = floor(nrow(dat) / 8), length = nrow(dat),
                labels = paste0("gr", seq_len(8)))
set.seed(457211)
group_icpts_truth <- rnorm(nlevels(dat$group), sd = 6)
group_X1_truth <- rnorm(nlevels(dat$group), sd = 6)
icpt <- -4.2
dat$y <- icpt +
  group_icpts_truth[dat$group] +
  group_X1_truth[dat$group] * dat$X1
dat$y <- rnorm(nrow(dat), mean = dat$y, sd = 4)
# Make the dataset artificially long:
dat <- do.call(rbind, replicate(10, dat, simplify = FALSE))
# Split up into training and test (hold-out) dataset:
idcs_test <- sample.int(nrow(dat), size = nrow(dat) / 3)
dat_train <- dat[-idcs_test, , drop = FALSE]
dat_test <- dat[idcs_test, , drop = FALSE]

# Reference model fit -----------------------------------------------------

suppressPackageStartupMessages(library(brms))
rfit_train <- brm(
  y ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + (1 | group),
  data = dat_train,
  chains = 1,
  iter = 500,
  refresh = 0,
  seed = 1140350788
)

# projpred ----------------------------------------------------------------

options(projpred.extra_verbose = TRUE)
d_test_list <- list(
  data = dat_test[, names(dat_test) != "y"],  # predictors of the hold-out data
  offset = rep(0, nrow(dat_test)),  # no offsets
  weights = rep(1, nrow(dat_test)),  # unit observation weights
  y = dat_test[["y"]]  # hold-out response
)
library(peakRAM)
vs_expr <- expression(
  vs <- varsel(rfit_train,
               d_test = d_test_list,
               # nclusters = 50,
               refit_prj = FALSE,
               seed = 46782345)
)
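# (Wrapping the varsel() call in expression() keeps it unevaluated, so the
# very same code can be passed both to peakRAM(eval(vs_expr)) and to
# microbenchmark(list = vs_expr) below.)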

## Previous state ---------------------------------------------------------

# With projpred at the state before this PR (commit:
# <https://github.com/stan-dev/projpred/commit/8e067a23d1e618b131e25c0965bb8c33cac9e9de>),
# run the following:
devtools::load_all()
peak_old <- replicate(20, peakRAM(eval(vs_expr)), simplify = FALSE)
quantile(sapply(peak_old, "[[", "Peak_RAM_Used_MiB"))
#     0%    25%    50%    75%   100%
# 119.20 127.15 128.95 157.30 160.30

## State of this PR -------------------------------------------------------

# For this, switch projpred to the state of this PR.

### Without gc() (default) ------------------------------------------------

# Now restart the R session, run all the code from the beginning up to and
# including the `vs_expr` creation, then run the following:
devtools::load_all()
peak_new_nogc <- replicate(20, peakRAM(eval(vs_expr)), simplify = FALSE)
quantile(sapply(peak_new_nogc, "[[", "Peak_RAM_Used_MiB"))
#    0%   25%   50%   75%  100%
#  78.4 101.8 101.8 101.8 102.0
# Besides, here are some runtime benchmarking results for investigating how
# much runtime the gc() call adds on top (compare with the results further
# below):
microbenchmark::microbenchmark(list = vs_expr, times = 20)
# Unit: seconds
#                expr      min       lq     mean   median       uq      max neval
# vs <- varsel([...]) 36.37918 36.57373 36.84235 36.68676 36.99082 37.84213    20

### With gc() -------------------------------------------------------------

# Now restart the R session, run all the code from the beginning up to and
# including the `vs_expr` creation, then run the following:
devtools::load_all()
options(projpred.run_gc = TRUE)
peak_new_gc <- replicate(20, peakRAM(eval(vs_expr)), simplify = FALSE)
quantile(sapply(peak_new_gc, "[[", "Peak_RAM_Used_MiB"))
#   0%   25%   50%   75%  100%
# 75.1 102.1 102.1 102.1 102.3
# Besides, here are some runtime benchmarking results for investigating how
# much runtime the gc() call adds on top (compare with the results further
# above):
microbenchmark::microbenchmark(list = vs_expr, times = 20)
# Unit: seconds
#                expr      min       lq    mean   median       uq     max neval
# vs <- varsel([...]) 39.28884 39.68015 39.8375 39.76543 40.15218 40.3328    20
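
For completeness: the opt-in `gc()` presumably boils down to a conditional call along the following lines inside the search (a minimal sketch based only on the `projpred.run_gc` option used above; the actual placement is in this PR's diff):

# Only collect garbage if the user opted in via options(projpred.run_gc = TRUE):
if (getOption("projpred.run_gc", FALSE)) {
  gc(verbose = FALSE)
}

So the gc() variant trades roughly 3 seconds of extra runtime here (about 8 %) for a negligible change in peak memory, which is why it is off by default.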