stan-dev / projpred

Projection predictive variable selection
https://mc-stan.org/projpred/
Other
110 stars 26 forks source link

Progress bar for `project()` #421

Closed fweber144 closed 1 year ago

fweber144 commented 1 year ago

This PR adds a progress bar to `project()` (at least when the built-in divergence minimizers are used). Illustration for the traditional (and latent) projection:

# Data --------------------------------------------------------------------

# Simulate a small grouped Gaussian dataset. In the true data-generating
# process, the response depends only on X1 plus group-specific intercepts and
# X1 slopes; X2-X5 and the `addgrp` grouping have no effect.

# Reuse the predictors of projpred's `df_gaussian` dataset (first 41 rows).
data("df_gaussian", package = "projpred")
df_gaussian <- df_gaussian[1:41, ]
# NOTE(review): the model formula further below refers to X1, ..., X5, so this
# presumably relies on data.frame() naming the columns of `df_gaussian$x`
# that way -- confirm.
dat <- data.frame(y = df_gaussian$y, df_gaussian$x)
# Grouping factor with 8 levels ("gr1", ..., "gr8") in consecutive runs of
# floor(nrow(dat) / 8) rows; gl() recycles the pattern up to `length` rows.
dat$group <- gl(n = 8, k = floor(nrow(dat) / 8), length = nrow(dat),
                labels = paste0("gr", seq_len(8)))
# Second grouping variable with 10 levels, stored as character (not factor).
dat$addgrp <- gl(n = 10, k = floor(nrow(dat) / 10), length = nrow(dat),
                 labels = paste0("agr", seq_len(10)))
dat$addgrp <- as.character(dat$addgrp)
# Seed right before the random draws below so the simulated data are
# reproducible (the order of the rnorm() calls matters).
set.seed(457211)
# True group-specific intercept offsets and X1 slopes (one per group level).
group_icpts_truth <- rnorm(nlevels(dat$group), sd = 6)
group_X1_truth <- rnorm(nlevels(dat$group), sd = 6)
icpt <- -4.2
# Linear predictor; this overwrites the `y` taken from `df_gaussian`.
# Indexing a vector by the factor `dat$group` uses its integer codes, i.e.
# it looks up each row's group-specific effects.
dat$y <- icpt +
  group_icpts_truth[dat$group] +
  group_X1_truth[dat$group] * dat$X1
# Add Gaussian noise (sd = 4) around the linear predictor.
dat$y <- rnorm(nrow(dat), mean = dat$y, sd = 4)

# Fit with brms -----------------------------------------------------------

# Reference model: all five predictors as population-level terms, a varying
# intercept and X1 slope by `group`, and a varying intercept by `addgrp`.
suppressPackageStartupMessages(library(brms))
# brm() takes its default number of cores from the `mc.cores` option; use
# the number of physical cores.
options(mc.cores = parallel::detectCores(logical = FALSE))
fit <- brm(y ~ X1 + X2 + X3 + X4 + X5 + (1 + X1 | group) + (1 | addgrp),
           data = dat,
           # Raised adapt_delta, presumably to reduce divergent transitions.
           control = list(adapt_delta = 0.9),
           # Fixed seed for a reproducible fit.
           seed = 1140350788)

# projpred ----------------------------------------------------------------

devtools::load_all() # Use `library(projpred)` if installed.

# Reference model object shared by all project() calls below.
refm <- get_refmodel(fit)

# All four calls project onto the same submodel (`solution_terms`) with
# `ndraws = 30`; only the verbosity settings differ. This shows when the
# progress bar added by this PR appears.

# 1. Default settings: the progress bar is expected to be shown.
prj <- project(
  refm,
  ndraws = 30,
  solution_terms = c("(1 | group)", "X1", "X3", "(X1 | group)", "X2")
)

# 2. Progress bar suppressed for this call via the `verbose` argument.
prj <- project(
  refm,
  ndraws = 30,
  solution_terms = c("(1 | group)", "X1", "X3", "(X1 | group)", "X2"),
  verbose = FALSE
)

# 3. Progress bar suppressed globally via the `projpred.verbose_project`
#    option.
options(projpred.verbose_project = FALSE)
prj <- project(
  refm,
  ndraws = 30,
  solution_terms = c("(1 | group)", "X1", "X3", "(X1 | group)", "X2")
)

# 4. Option unset again (NULL), so project() falls back to its default
#    verbosity.
options(projpred.verbose_project = NULL)
prj <- project(
  refm,
  ndraws = 30,
  solution_terms = c("(1 | group)", "X1", "X3", "(X1 | group)", "X2")
)

Illustration for the augmented-data projection:

# Data --------------------------------------------------------------------

# Use brms's `inhaler` data and recode `rating` as a factor with levels
# "rtg<value>" to serve as the categorical response of the ordinal model.
data(inhaler, package = "brms")
inhaler$rating <- paste0("rtg", inhaler$rating)
# NOTE(review): as.factor() orders levels alphabetically. That matches the
# numeric order only while all rating values are single digits.
inhaler$rating <- as.factor(inhaler$rating)

# Fit with rstanarm -------------------------------------------------------

library(rstanarm)
# Use the number of physical cores for parallel sampling.
options(mc.cores = parallel::detectCores(logical = FALSE))
# Ordinal regression reference model for `rating`.
fit <- stan_polr(rating ~ period + carry + treat,
                 data = inhaler,
                 # R2 prior with prior median 0.5 (`what = "median"`).
                 prior = R2(location = 0.5, what = "median"),
                 # Fixed seed for a reproducible fit.
                 seed = 1140350788)

# projpred ----------------------------------------------------------------

devtools::load_all() # Use `library(projpred)` if installed.

# Reference model object for the ordinal fit. Per the PR text, this case
# illustrates the augmented-data projection.
refm <- get_refmodel(fit)

# All four calls project onto the same submodel (`solution_terms`) with
# `ndraws = 30`; only the verbosity settings differ.

# 1. Default settings: the progress bar is expected to be shown.
prj <- project(
  refm,
  ndraws = 30,
  solution_terms = c("carry", "treat")
)

# 2. Progress bar suppressed for this call via the `verbose` argument.
prj <- project(
  refm,
  ndraws = 30,
  solution_terms = c("carry", "treat"),
  verbose = FALSE
)

# 3. Progress bar suppressed globally via the `projpred.verbose_project`
#    option.
options(projpred.verbose_project = FALSE)
prj <- project(
  refm,
  ndraws = 30,
  solution_terms = c("carry", "treat")
)

# 4. Option unset again (NULL), so project() falls back to its default
#    verbosity.
options(projpred.verbose_project = NULL)
prj <- project(
  refm,
  ndraws = 30,
  solution_terms = c("carry", "treat")
)