stan-dev / projpred

Projection predictive variable selection
https://mc-stan.org/projpred/
Other
110 stars 26 forks source link

Predictive performance gap / jumpy behavior at full size in Gaussian multilevel model #441

Open fweber144 opened 1 year ago

fweber144 commented 1 year ago

In this Stan Discourse reply (reference model: Gaussian, multilevel), we observed a gap in predictive performance between the submodels and the reference model when search_terms was NULL (with a "jump" towards the reference model's performance at the full model size), but not when forcing both group-level terms to be selected first. Reprex copied (and reduced to the relevant part) from that reply (the example dataset may be found here; also note that the reprex writes some files to the current working directory):

# Setup -------------------------------------------------------------------

library(brms)
library(projpred)

# Print warnings as they occur instead of collecting them:
options(warn = 1)

# Use at most 4 physical cores. `parallel::detectCores()` returns `NA` when the
# number of cores cannot be determined; `min(NA, 4)` would then propagate `NA`
# into `mc.cores`, so fall back to a single core in that case.
n_cores_physical <- parallel::detectCores(logical = FALSE)
if (is.na(n_cores_physical)) {
  n_cores_physical <- 1L
}
options(mc.cores = min(n_cores_physical, 4))
options(brms.backend = "cmdstanr")
# Refit file-cached brms models only when something relevant has changed:
options(brms.file_refit = "on_change")
# Write generated Stan files to the current working directory:
options(cmdstanr_write_stan_file_dir = ".")

# More verbose projpred output and convergence checks of the submodel fits:
options(projpred.extra_verbose = TRUE)
options(projpred.check_conv = TRUE)

# Data --------------------------------------------------------------------

indicator_data <- read.csv("example_dataset.csv")

# Randomly subsample `N` rows (without replacement), keeping all columns:
N <- 3000
set.seed(123)
indicator_data_N <- indicator_data[
  sample.int(n = nrow(indicator_data), size = N),
  ,
  drop = FALSE
]
# Combined country-village identifier. The levels are joined by "_" rather
# than ":" so that the grouping variable name contains no `:`.
indicator_data_N[["iso_country_code_IA_village"]] <- with(
  indicator_data_N,
  paste(iso_country_code, village, sep = "_")
)

# Reference model fit -----------------------------------------------------

# Gaussian multilevel reference model with varying intercepts for countries
# and for the combined country-village identifier. It uses the same
# `normal(0, 1)` prior for each of the parameter classes `b`, `sd`, `sigma`,
# and `Intercept`. The fit is stored in (and reused from) file "rfit".
rfit <- brm(
  formula = log_tva ~ 1 + log_hh_size + education_cleaned + log_livestock_tlu +
    log_land_cultivated + off_farm_any + till_not_by_hand + external_labour +
    pesticide + debts_have + aidreceived + livestock_inputs_any +
    land_irrigated_any + norm_growing_period + log_min_travel_time +
    log_pop_dens + norm_gdl_country_shdi + (1 | iso_country_code) +
    (1 | iso_country_code_IA_village),
  data = indicator_data_N,
  family = gaussian(),
  prior = do.call(c, lapply(
    c("b", "sd", "sigma", "Intercept"),
    function(prior_class) set_prior("normal(0, 1)", class = prior_class)
  )),
  file = "rfit",
  seed = 584356,
  refresh = 0
)

# projpred ----------------------------------------------------------------

# Run kfold() separately to save time later when running cv_varsel() multiple
# times:
set.seed(3424511)
refm_kfold <- kfold(rfit, K = 5, save_fits = TRUE)
# `refm_kfold$fits` is indexed below by the columns "fit" (the model refitted
# within a fold) and "omitted" (the indices of the observations held out in
# that fold), one row per fold.
fits_crr <- refm_kfold$fits[, "fit"]
omitted_crr <- refm_kfold$fits[, "omitted"]
n_obs_crr <- nrow(rfit$data)
# Invert the fold -> held-out-observations mapping into an observation -> fold
# integer vector in one vectorized pass. A per-observation search over all
# folds is quadratic in the number of observations, and it depends on
# `sapply()` simplifying to an integer vector.
folds_crr <- integer(n_obs_crr)
folds_crr[unlist(omitted_crr)] <- rep(seq_along(omitted_crr),
                                      lengths(omitted_crr))
# Every observation must be held out in exactly one fold:
stopifnot(
  length(folds_crr) == n_obs_crr,
  sum(lengths(omitted_crr)) == n_obs_crr,
  all(folds_crr > 0L)
)
cvfits_crr <- structure(
  list(fits = fits_crr),
  K = length(fits_crr),
  folds = folds_crr
)
# Reference model object that reuses the K-fold refits built above:
refmodel_obj <- get_refmodel(rfit, cvfits = cvfits_crr)
# Number of rows of the reference model's draws matrix:
S_ref <- nrow(as.matrix(rfit))

# With the default of `search_terms = NULL`:
cvvs3 <- cv_varsel(
  refmodel_obj,
  cv_method = "kfold",
  nclusters = 3,
  seed = 1,
  control = lme4::lmerControl(optimizer = "Nelder_Mead")
)
print(plot(cvvs3, ranking_nterms_max = NA))

Screenshot from 2023-08-18 11-53-38

# Plot the CV ranking proportions (kept local to avoid a global binding):
local({
  cvprops_cvvs3 <- cv_proportions(cvvs3)
  print(plot(cvprops_cvvs3))
})

Screenshot from 2023-08-18 11-53-48

# Forcing both group-level terms to be selected first:
# Build a projpred `search_terms` vector in which every element contains all
# of `forced_terms`, so that the forward search selects them first.
#
# @param forced_terms Non-empty character vector of predictor terms to include
#   in every element (they are joined by " + ").
# @param optional_terms Character vector of predictor terms; each one is
#   appended to the forced terms in its own element (may be empty).
# @return Character vector of length `1 + length(optional_terms)`: the forced
#   terms alone, followed by the forced terms plus each optional term.
get_search_terms_forced <- function(forced_terms, optional_terms) {
  # An empty `forced_terms` would collapse to "" and produce malformed search
  # terms such as " + x1", so reject it up front.
  stopifnot(
    is.character(forced_terms),
    length(forced_terms) > 0L,
    is.character(optional_terms)
  )
  forced_terms <- paste(forced_terms, collapse = " + ")
  c(forced_terms, paste0(forced_terms, " + ", optional_terms))
}
# Predictor terms that the search may add after the forced group-level terms:
optional_predictors <- c(
  "log_hh_size", "education_cleaned", "log_livestock_tlu",
  "log_land_cultivated", "off_farm_any", "till_not_by_hand",
  "external_labour", "pesticide", "debts_have", "aidreceived",
  "livestock_inputs_any", "land_irrigated_any", "norm_growing_period",
  "log_min_travel_time", "log_pop_dens", "norm_gdl_country_shdi"
)
# Varying-intercept terms that every candidate submodel must contain:
forced_predictors <- sprintf(
  "(1 | %s)",
  c("iso_country_code", "iso_country_code_IA_village")
)
search_terms_forcedGL <- get_search_terms_forced(
  forced_terms = forced_predictors,
  optional_terms = optional_predictors
)
# Same CV settings as for `cvvs3`, but with the forced search terms:
cvvs4 <- cv_varsel(
  refmodel_obj,
  search_terms = search_terms_forcedGL,
  cv_method = "kfold",
  nclusters = 3,
  seed = 1,
  control = lme4::lmerControl(optimizer = "Nelder_Mead")
)
print(plot(cvvs4, ranking_nterms_max = NA))

Screenshot from 2023-08-18 11-53-57

# Plot the CV ranking proportions (kept local to avoid a global binding):
local({
  cvprops_cvvs4 <- cv_proportions(cvvs4)
  print(plot(cvprops_cvvs4))
})

Screenshot from 2023-08-18 11-54-04

Details may be found in the Stan Discourse reply.

fweber144 commented 1 year ago

I think I have now been able to reproduce this with a "standalone" example, i.e., one that does not require the specific dataset used above:

# Data --------------------------------------------------------------------

# Simulate a new outcome from the `df_gaussian` predictors. Only `grpvar1`
# enters the outcome (group-specific intercepts and group-specific `X1`
# slopes); `grpvar2` and `grpvar3` are randomly permuted labels that do not
# affect it. The original `df_gaussian$y` is overwritten.
data("df_gaussian", package = "projpred")
dat <- data.frame(y = df_gaussian$y, df_gaussian$x)
dat[["grpvar1"]] <- gl(5, nrow(dat) %/% 5, length = nrow(dat),
                       labels = paste0("gr", 1:5))
dat[["grpvar2"]] <- gl(6, nrow(dat) %/% 6, length = nrow(dat),
                       labels = paste0("agr", 1:6))
dat[["grpvar3"]] <- gl(3, nrow(dat) %/% 3, length = nrow(dat),
                       labels = paste0("tgr", 1:3))
# Random-number draws below must keep this exact order for reproducibility:
set.seed(457211)
dat[["grpvar2"]] <- sample(dat[["grpvar2"]])
dat[["grpvar3"]] <- sample(dat[["grpvar3"]])
grpvar1_icpts_truth <- rnorm(nlevels(dat$grpvar1), sd = 6)
grpvar1_X1_truth <- rnorm(nlevels(dat$grpvar1), sd = 6)
icpt <- -4.2
# The factor's integer codes select each observation's group effects:
dat[["y"]] <- rnorm(
  nrow(dat),
  mean = icpt +
    grpvar1_icpts_truth[as.integer(dat$grpvar1)] +
    grpvar1_X1_truth[as.integer(dat$grpvar1)] * dat$X1,
  sd = 4
)

# Hold out a random third of the rows as test data; the rest is training data:
idcs_test <- sample.int(nrow(dat), size = nrow(dat) %/% 3)
dat_test <- dat[idcs_test, , drop = FALSE]
dat_train <- dat[-idcs_test, , drop = FALSE]

# Reference model fit -----------------------------------------------------

suppressPackageStartupMessages(library(rstanarm))
# Short single-chain fit with nested varying intercepts for all three
# grouping levels:
rfit <- stan_glmer(
  formula = y ~ X1 + X2 + X3 + (1 | grpvar1) + (1 | grpvar1:grpvar2) +
    (1 | grpvar1:grpvar2:grpvar3),
  data = dat_train,
  seed = 1140350788,
  chains = 1,
  iter = 500,
  refresh = 0
)

# projpred ----------------------------------------------------------------

devtools::load_all() # requires at least commit a8b25c178223e2cd607070962fa0732f9abc3d85
options(projpred.extra_verbose = TRUE)
refmodel_obj <- get_refmodel(rfit)
# Add the interaction grouping columns to the test data under the names used
# in the formula. Levels are joined by ":" for rstanarm (`stanreg`) fits and
# by "_" otherwise.
sep_char <- if (inherits(refmodel_obj$fit, "stanreg")) ":" else "_"
dat_test[["grpvar1:grpvar2"]] <- with(
  dat_test,
  paste(grpvar1, grpvar2, sep = sep_char)
)
dat_test[["grpvar1:grpvar2:grpvar3"]] <- with(
  dat_test,
  paste(grpvar1, grpvar2, grpvar3, sep = sep_char)
)
# Hold-out data in the list format expected by `d_test` (zero offsets, unit
# weights):
d_test_list <- list(
  data = dat_test[names(dat_test) != "y"],
  offset = numeric(nrow(dat_test)),
  weights = rep_len(1, nrow(dat_test)),
  y = dat_test$y
)
vs <- varsel(refmodel_obj,
             d_test = d_test_list,
             seed = 46782345,
             nclusters = 1,
             refit_prj = FALSE)
print(plot(vs))

[Figure: vs]

# 2-fold CV with a single projection cluster and `refit_prj = FALSE`:
cvvs <- cv_varsel(
  refmodel_obj,
  cv_method = "kfold",
  K = 2,
  seed = 46782345,
  nclusters = 1,
  refit_prj = FALSE
)
print(plot(cvvs, ranking_nterms_max = NA))

[Figure: cvvs]

# Plot the CV ranking proportions (kept local to avoid a global binding):
local({
  cvprops_cvvs <- cv_proportions(cvvs)
  print(plot(cvprops_cvvs))
})

[Figure: cvvs_cvprops]

# Forcing the group-level terms to be selected first:
# Build a projpred `search_terms` vector in which every element contains all
# of `forced_terms`, so that the forward search selects them first.
#
# @param forced_terms Non-empty character vector of predictor terms to include
#   in every element (they are joined by " + ").
# @param optional_terms Character vector of predictor terms; each one is
#   appended to the forced terms in its own element (may be empty).
# @return Character vector of length `1 + length(optional_terms)`: the forced
#   terms alone, followed by the forced terms plus each optional term.
get_search_terms_forced <- function(forced_terms, optional_terms) {
  # An empty `forced_terms` would collapse to "" and produce malformed search
  # terms such as " + x1", so reject it up front.
  stopifnot(
    is.character(forced_terms),
    length(forced_terms) > 0L,
    is.character(optional_terms)
  )
  forced_terms <- paste(forced_terms, collapse = " + ")
  c(forced_terms, paste0(forced_terms, " + ", optional_terms))
}
# Varying-intercept terms (all three nesting levels) forced into every
# candidate submodel, and the population-level terms the search may add:
forced_predictors <- sprintf(
  "(1 | %s)",
  c("grpvar1", "grpvar1:grpvar2", "grpvar1:grpvar2:grpvar3")
)
optional_predictors <- sprintf("X%d", 1:3)
search_terms_forcedGL <- get_search_terms_forced(
  forced_terms = forced_predictors,
  optional_terms = optional_predictors
)
# Same CV settings as for `cvvs`, but with the group-level terms forced in:
cvvs_forcedGL <- cv_varsel(
  refmodel_obj,
  search_terms = search_terms_forcedGL,
  cv_method = "kfold",
  K = 2,
  seed = 46782345,
  nclusters = 1,
  refit_prj = FALSE
)
print(plot(cvvs_forcedGL, ranking_nterms_max = NA))

[Figure: cvvs_forcedGL]

# Plot the CV ranking proportions (kept local to avoid a global binding):
local({
  cvprops_cvvs_forcedGL <- cv_proportions(cvvs_forcedGL)
  print(plot(cvprops_cvvs_forcedGL))
})

[Figure: cvvs_forcedGL_cvprops]