stan-dev / projpred

Projection predictive variable selection
https://mc-stan.org/projpred/

Memory issues in multilevel models #440

Closed (fweber144 closed this issue 11 months ago)

fweber144 commented 1 year ago

When performing variable selection for a multilevel model with a large number of observations and a large number of projected draws, we can quickly run out of memory. The reason is probably that projpred stores the whole submodel fit (which, for a multilevel model, is an lme4 fit) for each projected draw. Each of these submodel fits might not be very large on its own, but all of them together may require more memory than is available. One solution might be to reduce the size of the lme4 fits (I don't know whether this is possible without breaking downstream code, in particular the predict() method for such fits). Another would be to add a custom predict() method for lme4 fits that requires less information than the whole fit object. The latter might in fact not be that tedious to implement, considering that we already have predict.subfit() and repair_re().
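A minimal sketch of the "slim object plus custom predict()" idea, using lme4's built-in sleepstudy data and a random-intercept-only model. The names slim_fit and predict_slim() are hypothetical and unrelated to projpred's internal predict.subfit() or repair_re(); the sketch only illustrates that, for this model class, the fixed-effects coefficients plus the estimated group intercepts are enough to reproduce lme4's predict(). Note that object.size() does not count environments, so it understates the true size of an lme4 fit.

```r
# Assumes the lme4 package is installed (sleepstudy ships with lme4).
library(lme4)

# Full lme4 fit of a random-intercept model (stand-in for one submodel fit)
fit <- lmer(Reaction ~ Days + (1 | Subject), data = sleepstudy)

# Hypothetical slim object: only what prediction needs for this model class
re_subj <- ranef(fit)$Subject
slim_fit <- list(
  formula_fixed = ~ Days,  # fixed-effects part, hard-coded for this example
  beta = fixef(fit),
  re_icpt = setNames(re_subj[["(Intercept)"]], rownames(re_subj)),
  group_var = "Subject"
)

# Hypothetical predict function for the slim object (random intercepts only)
predict_slim <- function(object, newdata) {
  X <- model.matrix(object$formula_fixed, data = newdata)
  eta <- drop(X %*% object$beta[colnames(X)])
  b <- object$re_icpt[as.character(newdata[[object$group_var]])]
  b[is.na(b)] <- 0  # assumption: unseen groups get a zero group intercept
  unname(eta + b)
}

# Compare against lme4's own predict() and against the full object's size
p_full <- unname(predict(fit, newdata = sleepstudy))
p_slim <- predict_slim(slim_fit, sleepstudy)
all.equal(p_full, p_slim)
format(object.size(fit), units = "Kb")
format(object.size(slim_fit), units = "Kb")
```

Random slopes and multiple grouping factors would need extra bookkeeping, and treating new group levels as having a zero intercept is an assumption of this sketch, not necessarily what projpred does.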

Illustration: On my machine (Linux) with 16 GB of RAM, the following reprex crashes R and eventually the whole machine (hence the CAUTION warning below and the wrapping in if (FALSE)):

```r
# **CAUTION!**: This reprex is known to crash R and possibly the whole machine (at least under Linux).
if (FALSE) {
  # Data --------------------------------------------------------------------

  data("df_gaussian", package = "projpred")
  dat <- data.frame(y = df_gaussian$y, df_gaussian$x)
  dat$group <- gl(n = 8, k = floor(nrow(dat) / 8), length = nrow(dat),
                  labels = paste0("gr", seq_len(8)))
  set.seed(457211)
  group_icpts_truth <- rnorm(nlevels(dat$group), sd = 6)
  group_X1_truth <- rnorm(nlevels(dat$group), sd = 6)
  icpt <- -4.2
  dat$y <- icpt +
    group_icpts_truth[dat$group] +
    group_X1_truth[dat$group] * dat$X1
  dat$y <- rnorm(nrow(dat), mean = dat$y, sd = 4)
  # Make the dataset artificially long:
  dat <- do.call(rbind, replicate(60, dat, simplify = FALSE))
  # Split up into training and test (hold-out) dataset:
  idcs_test <- sample.int(nrow(dat), size = nrow(dat) / 3)
  dat_train <- dat[-idcs_test, , drop = FALSE]
  dat_test <- dat[idcs_test, , drop = FALSE]

  # Reference model fit -----------------------------------------------------

  suppressPackageStartupMessages(library(brms))
  options(mc.cores = min(parallel::detectCores(logical = FALSE), 4))
  rfit_train <- brm(
    y ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + (1 | group),
    data = dat_train,
    refresh = 0,
    seed = 1140350788
  )

  # projpred ----------------------------------------------------------------

  suppressPackageStartupMessages(library(projpred))
  options(projpred.extra_verbose = TRUE)
  d_test_list <- list(
    data = dat_test[, names(dat_test) != "y"],
    offset = rep(0, nrow(dat_test)),
    weights = rep(1, nrow(dat_test)),
    y = dat_test[["y"]]
  )
  S_ref <- nrow(as.matrix(rfit_train))
  vs <- varsel(rfit_train,
               d_test = d_test_list,
               nclusters = 1,
               ndraws_pred = S_ref,
               seed = 46782345)
}
```

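To get a feeling for why the reprex above exhausts 16 GB of RAM, a rough back-of-the-envelope estimate multiplies the size of one lme4 fit by the number of fits kept at the same time. This sketch assumes 4000 training rows (df_gaussian's 100 observations times 60 copies, with two thirds kept for training), S_ref = 4000 (the brms defaults of 4 chains with 1000 post-warmup draws each), and 11 submodel sizes held simultaneously (the empty model plus up to 10 terms). It fits a single lmer() model on simulated data that only mimic the reprex's shape and treats every submodel fit as if it were that large, while object.size() ignores environments, so the result is only an order of magnitude. estimate_stored_mem() is a hypothetical helper.

```r
# Assumes lme4 is installed; the data below only mimic the reprex's shape.
library(lme4)

set.seed(1)
n_train <- 4000  # assumed number of training rows in the reprex
X <- matrix(rnorm(n_train * 9), ncol = 9,
            dimnames = list(NULL, paste0("X", 1:9)))
dat <- as.data.frame(X)
dat$group <- gl(n = 8, k = n_train / 8, labels = paste0("gr", 1:8))
dat$y <- -4.2 + rnorm(8, sd = 6)[dat$group] + rnorm(n_train, sd = 4)

# One lme4 fit of the largest submodel (all 9 predictors plus the group term)
one_fit <- lmer(y ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + (1 | group),
                data = dat)

# Hypothetical helper: memory for `ndraws_pred` fits per submodel size, with
# `n_sizes` sizes held at the same time. object.size() ignores environments,
# and not every submodel is this large, so this is a rough order of magnitude.
estimate_stored_mem <- function(one_fit, ndraws_pred, n_sizes) {
  bytes <- as.numeric(object.size(one_fit)) * ndraws_pred * n_sizes
  structure(bytes, class = "object_size")
}

est <- estimate_stored_mem(one_fit, ndraws_pred = 4000, n_sizes = 11)
format(est, units = "Gb")
```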
fweber144 commented 1 year ago

At least GAMMs are probably affected by this issue as well, possibly also GAMs.

fweber144 commented 1 year ago

Another idea, which might be less tedious and less error-prone than implementing a custom predict() method (or reducing the size of the lme4 fits): it might be possible to solve this issue by combining get_sub_summaries() and get_submodls() into a single function, so that the submodel fits of the increasingly complex submodels along the predictor ranking are no longer all stored at the same time. However, caution is needed with respect to get_submodls() in project() and in the validate_search = FALSE case of loo_varsel().
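The reordering can be sketched as a loop that fits the per-draw submodels for one size, immediately reduces them to the summaries needed downstream (here: draw-wise predictions on test data), and drops the fits before moving on, so that peak memory is governed by one submodel size rather than by all of them. This is a toy sketch in base R: lm() stands in for the lme4 submodel fitter, and fit_submodel(), mu_ref, ranking, and summaries are hypothetical names, not projpred's get_sub_summaries() or get_submodls(). It does not address the project() and loo_varsel() caveats.

```r
# Base R only; lm() is a stand-in for the (lme4) submodel fitter.
set.seed(1)
n <- 200
ndraws <- 50
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
dat_test <- data.frame(x1 = rnorm(50), x2 = rnorm(50), x3 = rnorm(50))
# Hypothetical: each column plays the role of one draw of the reference
# model's fitted values, which the submodels are projected onto
mu_ref <- sapply(seq_len(ndraws), function(s) {
  1 + 2 * dat$x1 - dat$x2 + rnorm(n, sd = 0.1)
})
ranking <- c("x1", "x2", "x3")  # hypothetical predictor ranking

# Hypothetical helper: fit the submodel with predictors `preds` to one draw
fit_submodel <- function(preds, mu_s) {
  rhs <- if (length(preds) == 0) "1" else paste(preds, collapse = " + ")
  lm(as.formula(paste("mu_s ~", rhs)), data = cbind(dat, mu_s = mu_s))
}

# Fit and summarize one submodel size at a time, so at most one size's
# worth of fits is held in memory
summaries <- vector("list", length(ranking) + 1)
for (size in 0:length(ranking)) {
  preds <- ranking[seq_len(size)]
  fits <- lapply(seq_len(ndraws), function(s) fit_submodel(preds, mu_ref[, s]))
  # Keep only draw-wise test predictions (n_test x ndraws matrix)
  summaries[[size + 1]] <- sapply(fits, predict, newdata = dat_test)
  rm(fits)  # release this size's fits before fitting the next size
}
```

The trade-off is that the fits themselves are gone afterwards, which is exactly why callers that still need them (such as project()) would require separate handling.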