stan-dev / projpred

Projection predictive variable selection
https://mc-stan.org/projpred/
Other
110 stars 25 forks source link

Refine runtime estimation for forward search #459

Closed fweber144 closed 1 year ago

fweber144 commented 1 year ago

This refines the estimation of the runtime of the forward search remaining after the projection onto the intercept-only submodel, in particular allowing for an interval estimate in case of multilevel and/or additive ("smooth") terms.

The factors used for scaling up the runtime estimate (coming from the intercept-only projection) were derived empirically as follows:

# Source for the data-generating mechanism and the reference model: Example
# section of `?rstanarm::stan_gamm4`.
# NOTE(review): the statement order below matters -- `gamSim()`, `sample()`,
# and `rnorm()` all draw from the stream seeded here, so reordering would
# change `dat` and invalidate the recorded benchmark timings further down.
set.seed(7456)
dat <- mgcv::gamSim(1, n = 200, scale = 2)
# Add a 20-level grouping factor and group-specific intercept shifts so that
# the reference model can contain a multilevel ("(1 | fac)") term.
dat$fac <- fac <- as.factor(sample(1:20, 200, replace = TRUE))
dat$y <- dat$y + model.matrix(~ fac - 1) %*% rnorm(20) * 0.5
# Reference model combining smooth (`s()`), linear, and group-level terms,
# i.e. a GAMM -- the most general submodel type benchmarked below.
rfit <- rstanarm::stan_gamm4(
  y ~ s(x0) + x1 + s(x2),
  random = ~ (1 | fac),
  data = dat,
  cores = 4,
  seed = 1140350788,
  adapt_delta = 0.99,
  refresh = 0
)

# With projpred at commit cc0d3064:
devtools::load_all(".")
set.seed(8234467)
# Reference-model object and a clustered reference distribution (20 clusters
# of posterior draws); the same `refd` is projected in every benchmark below.
refm <- get_refmodel(rfit)
refd <- get_refdist(refm, nclusters = 20)

# Baseline: projection onto the empty (intercept-only) submodel
# (`solution_terms = character()`). The recorded timings below are the
# denominator for the intercept-only -> GLM scaling factor.
prj_expr <- expression(
  prj_out <- get_submodl_prj(
    solution_terms = character(),
    p_ref = refd, refmodel = refm, regul = 1e-04
  )
)
microbenchmark::microbenchmark(list = prj_expr, times = 100)
# Unit: milliseconds
#     min       lq     mean  median       uq      max neval
# 13.9557 14.16054 17.08022 14.3025 14.82752 153.7715   100

# GLM-type submodel: a single linear ("x1") term.
prj_expr <- expression(
  prj_out <- get_submodl_prj(
    solution_terms = c("x1"),
    p_ref = refd, refmodel = refm, regul = 1e-04
  )
)
microbenchmark::microbenchmark(list = prj_expr, times = 100)
# Unit: milliseconds
#      min       lq     mean   median       uq      max neval
# 17.81272 18.35631 19.42337 18.52197 18.82939 61.19784   100

# GLMM-type submodel: linear term plus the multilevel ("(1 | fac)") term.
prj_expr <- expression(
  prj_out <- get_submodl_prj(
    solution_terms = c("x1", "(1 | fac)"),
    p_ref = refd, refmodel = refm, regul = 1e-04
  )
)
microbenchmark::microbenchmark(list = prj_expr, times = 100)
# Unit: milliseconds
#      min       lq     mean   median      uq      max neval
# 473.7028 489.0631 502.0515 497.7506 503.722 761.9663   100

# GAM-type submodel: linear term plus a smooth ("s(x0)") term.
prj_expr <- expression(
  prj_out <- get_submodl_prj(
    solution_terms = c("x1", "s(x0)"),
    p_ref = refd, refmodel = refm, regul = 1e-04
  )
)
microbenchmark::microbenchmark(list = prj_expr, times = 100)
# Unit: milliseconds
#     min       lq     mean   median       uq      max neval
# 166.975 170.1548 179.7964 177.2802 185.0535 228.0465   100

# GAMM-type submodel: linear, multilevel, and smooth terms combined.
# Note the unit change: this benchmark is reported in seconds, not ms.
prj_expr <- expression(
  prj_out <- get_submodl_prj(
    solution_terms = c("x1", "(1 | fac)", "s(x0)"),
    p_ref = refd, refmodel = refm, regul = 1e-04
  )
)
microbenchmark::microbenchmark(list = prj_expr, times = 100)
# Unit: seconds
#      min       lq     mean   median       uq     max neval
# 1.017806 1.043563 1.079292 1.061689 1.084159 1.35676   100
###

From these microbenchmark results, we obtain the following factors (I should have assigned the microbenchmark::microbenchmark() outputs to different objects instead of working with the hard-coded times here, but I was too lazy to re-run):

# Hard-coded microbenchmark summary statistics from the runs above, each in
# the order (min, lq, mean, median, uq, max) and measured in milliseconds.
bm_empty <- c(13.9557, 14.16054, 17.08022, 14.3025, 14.82752, 153.7715)
bm_glm <- c(17.81272, 18.35631, 19.42337, 18.52197, 18.82939, 61.19784)
bm_glmm <- c(473.7028, 489.0631, 502.0515, 497.7506, 503.722, 761.9663)
bm_gam <- c(166.975, 170.1548, 179.7964, 177.2802, 185.0535, 228.0465)
# The GAMM run was reported in seconds, so convert it to milliseconds here.
bm_gamm <- 1e3 * c(1.017806, 1.043563, 1.079292, 1.061689, 1.084159, 1.35676)

# Elementwise ratios of the summary statistics.
# NOTE(review): the mean (3rd) and max (6th) ratios involving `bm_empty` are
# distorted by its single 153.77 ms outlier, so the "ca." factors are read
# off the stable statistics (lq / median / uq).
print(bm_glm / bm_empty)
# [1] 1.2763760 1.2963001 1.1371850 1.2950163 1.2698948 0.3979791
## --> ca. 1.3 from intercept-only to GLM submodel

print(bm_glmm / bm_glm)
# [1] 26.59351 26.64278 25.84781 26.87352 26.75190 12.45087
## --> ca. 26.9 from GLM to GLMM submodel

print(bm_gam / bm_glm)
# [1] 9.373919 9.269554 9.256705 9.571347 9.827907 3.726382
## --> ca. 9.8 from GLM to GAM submodel

print(bm_gamm / bm_glm)
# [1] 57.13928 56.85037 55.56667 57.32052 57.57802 22.17006
## --> ca. 57.6 from GLM to GAMM submodel
fweber144 commented 1 year ago

On my machine, the following example causes the runtime message to be displayed:

# Data --------------------------------------------------------------------
# Gaussian outcome with an 8-level grouping factor; true group-specific
# intercepts and X1 slopes are simulated so that a multilevel reference
# model is appropriate.
# NOTE(review): seeded RNG order matters (two `rnorm()` calls, then
# `sample.int()`); do not reorder these statements.

data("df_gaussian", package = "projpred")
dat <- data.frame(y = df_gaussian$y, df_gaussian$x)
dat$group <- gl(n = 8, k = floor(nrow(dat) / 8), length = nrow(dat),
                labels = paste0("gr", seq_len(8)))
set.seed(457211)
group_icpts_truth <- rnorm(nlevels(dat$group), sd = 6)
group_X1_truth <- rnorm(nlevels(dat$group), sd = 6)
icpt <- -4.2
dat$y <- icpt +
  group_icpts_truth[dat$group] +
  group_X1_truth[dat$group] * dat$X1
dat$y <- rnorm(nrow(dat), mean = dat$y, sd = 4)
# Make the dataset artificially long:
dat <- do.call(rbind, replicate(6, dat, simplify = FALSE))
# Split up into training and test (hold-out) dataset:
idcs_test <- sample.int(nrow(dat), size = nrow(dat) / 3)
dat_train <- dat[-idcs_test, , drop = FALSE]
dat_test <- dat[idcs_test, , drop = FALSE]

# Reference model fit -----------------------------------------------------
# Multilevel reference model (note the `(1 | group)` term) -- on the
# author's machine, this is the variant that triggers the runtime message.

rfit_train <- rstanarm::stan_glmer(
  y ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10 + X11 + X12 + X13 + X14 +
    X15 + X16 + X17 + X18 + X19 + X20 + (1 | group),
  data = dat_train,
  cores = 4,
  refresh = 0,
  seed = 1140350788
)

# projpred ----------------------------------------------------------------

# With projpred at commit c7b1d2d7:
devtools::load_all(".")
options(projpred.extra_verbose = TRUE)

# Hold-out test data in the list format expected by `varsel()`'s `d_test`
# argument (predictors, offsets, observation weights, response).
d_test_list <- list(
  data = dat_test[, names(dat_test) != "y"],
  offset = rep(0, nrow(dat_test)),
  weights = rep(1, nrow(dat_test)),
  y = dat_test[["y"]]
)
# Wall-clock timestamps bracketing the `varsel()` run.
# `refit_prj = FALSE` presumably re-uses the search's projections for the
# performance evaluation -- confirm against `?varsel`.
Sys.time()
vs <- varsel(rfit_train,
             d_test = d_test_list,
             refit_prj = FALSE,
             seed = 46782345)
Sys.time()

And the following example does not cause the runtime message to be displayed:

# Data --------------------------------------------------------------------
# Identical data-generating mechanism (and seed) as in the first example
# above; the two examples differ only in the reference model fitted below.

data("df_gaussian", package = "projpred")
dat <- data.frame(y = df_gaussian$y, df_gaussian$x)
dat$group <- gl(n = 8, k = floor(nrow(dat) / 8), length = nrow(dat),
                labels = paste0("gr", seq_len(8)))
set.seed(457211)
group_icpts_truth <- rnorm(nlevels(dat$group), sd = 6)
group_X1_truth <- rnorm(nlevels(dat$group), sd = 6)
icpt <- -4.2
dat$y <- icpt +
  group_icpts_truth[dat$group] +
  group_X1_truth[dat$group] * dat$X1
dat$y <- rnorm(nrow(dat), mean = dat$y, sd = 4)
# Make the dataset artificially long:
dat <- do.call(rbind, replicate(6, dat, simplify = FALSE))
# Split up into training and test (hold-out) dataset:
idcs_test <- sample.int(nrow(dat), size = nrow(dat) / 3)
dat_train <- dat[-idcs_test, , drop = FALSE]
dat_test <- dat[idcs_test, , drop = FALSE]

# Reference model fit -----------------------------------------------------
# Plain GLM reference model (no `(1 | group)` term) -- on the author's
# machine, this variant does NOT trigger the runtime message.

rfit_train <- rstanarm::stan_glm(
  y ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10 + X11 + X12 + X13 + X14 +
    X15 + X16 + X17 + X18 + X19 + X20,
  data = dat_train,
  cores = 4,
  refresh = 0,
  seed = 1140350788
)

# projpred ----------------------------------------------------------------

# With projpred at commit c7b1d2d7:
devtools::load_all(".")
options(projpred.extra_verbose = TRUE)

# Hold-out test data in the list format expected by `varsel()`'s `d_test`
# argument (predictors, offsets, observation weights, response).
d_test_list <- list(
  data = dat_test[, names(dat_test) != "y"],
  offset = rep(0, nrow(dat_test)),
  weights = rep(1, nrow(dat_test)),
  y = dat_test[["y"]]
)
# Wall-clock timestamps bracketing the `varsel()` run, mirroring the first
# example for a runtime comparison.
Sys.time()
vs <- varsel(rfit_train,
             d_test = d_test_list,
             refit_prj = FALSE,
             seed = 46782345)
Sys.time()