stan-dev / projpred

Projection predictive variable selection
https://mc-stan.org/projpred/

Interactions between grouping variables #445

Closed fweber144 closed 1 year ago

fweber144 commented 1 year ago

This fixes a bug affecting rstanarm (and custom) multilevel reference models with interactions (: syntax) between grouping variables. The bug was caused by columns missing from the reference model's data.frame (for brms reference models, these columns were already created correctly).

Illustration:

# Data --------------------------------------------------------------------

data("df_gaussian", package = "projpred")
dat <- data.frame(y = df_gaussian$y, df_gaussian$x)
dat$group <- gl(n = 5, k = nrow(dat) %/% 5, length = nrow(dat),
                labels = paste0("gr", seq_len(5)))
dat$addgrp <- gl(n = 6, k = nrow(dat) %/% 6, length = nrow(dat),
                 labels = paste0("agr", seq_len(6)))
dat$thirdgrp <- gl(n = 3, k = nrow(dat) %/% 3, length = nrow(dat),
                   labels = paste0("tgr", seq_len(3)))
set.seed(457211)
dat$addgrp <- sample(dat$addgrp)
dat$thirdgrp <- sample(dat$thirdgrp)
group_icpts_truth <- rnorm(nlevels(dat$group), sd = 6)
group_X1_truth <- rnorm(nlevels(dat$group), sd = 6)
icpt <- -4.2
dat$y <- icpt +
  group_icpts_truth[dat$group] +
  group_X1_truth[dat$group] * dat$X1
dat$y <- rnorm(nrow(dat), mean = dat$y, sd = 4)
# Split up into training and test (hold-out) dataset:
idcs_test <- sample.int(nrow(dat), size = nrow(dat) %/% 3)
dat_train <- dat[-idcs_test, , drop = FALSE]
dat_test <- dat[idcs_test, , drop = FALSE]

# Reference model fit -----------------------------------------------------

suppressPackageStartupMessages(library(rstanarm))
rfit <- stan_glmer(
  y ~ X1 + X2 + X3 + (1 | group) + (1 | group:addgrp) + (1 | group:addgrp:thirdgrp),
  data = dat_train,
  chains = 1,
  iter = 500,
  refresh = 0,
  seed = 1140350788
)

# projpred ----------------------------------------------------------------

# Load the development version of projpred (assumes the working directory is
# the package root):
devtools::load_all()

refmodel_obj <- get_refmodel(rfit)

prj <- project(refmodel_obj,
               solution_terms = c("(1 | group)", "(1 | group:addgrp)"),
               nclusters = 1,
               seed = 457324)

dat_test$`group:addgrp` <- paste(dat_test$group,
                                 dat_test$addgrp,
                                 sep = ":")
dat_test$`group:addgrp:thirdgrp` <- paste(dat_test$group,
                                          dat_test$addgrp,
                                          dat_test$thirdgrp,
                                          sep = ":")
d_test_list <- list(
  data = dat_test[, names(dat_test) != "y"],
  offset = rep(0, nrow(dat_test)),
  weights = rep(1, nrow(dat_test)),
  y = dat_test[["y"]]
)
vs <- varsel(refmodel_obj,
             d_test = d_test_list,
             nclusters = 1,
             refit_prj = FALSE,
             seed = 46782345)

cvvs <- cv_varsel(refmodel_obj,
                  cv_method = "kfold",
                  nclusters = 1,
                  refit_prj = FALSE,
                  seed = 46782345)

These calls failed before this PR and succeed now.

fweber144 commented 1 year ago

NB: In the repair_re() methods, this PR also catches a : on the right-hand side of the bar (|) in group-level terms and then throws an adapted error message (although this error should not occur, given the automatic creation of data columns implemented in this PR).
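
For readers who want to see what "creating the data columns" for such interactions can look like, here is a minimal base-R sketch. It is not the code from this PR: the helper `add_grp_interaction_cols()` and the toy data are hypothetical, and the sketch simply assumes that each interaction column is named like the term itself and built with `paste(..., sep = ":")`, as done manually for `dat_test` in the illustration above.

```r
# Hypothetical helper (NOT part of projpred): for every grouping term that
# contains `:`, add a column named like the term, built by pasting the
# involved grouping variables together with ":" as separator.
add_grp_interaction_cols <- function(data, grp_terms) {
  for (term in grp_terms) {
    vars <- strsplit(term, ":", fixed = TRUE)[[1]]
    # Only interaction terms need a new column; skip existing columns.
    if (length(vars) > 1 && !term %in% names(data)) {
      # unname() avoids clashes between column names and paste() arguments.
      data[[term]] <- do.call(paste, c(unname(as.list(data[vars])), sep = ":"))
    }
  }
  data
}

# Toy data (assumed for illustration only):
toy <- data.frame(
  group = c("gr1", "gr1", "gr2"),
  addgrp = c("agr1", "agr2", "agr1"),
  thirdgrp = c("tgr1", "tgr1", "tgr2")
)
toy <- add_grp_interaction_cols(
  toy,
  grp_terms = c("group", "group:addgrp", "group:addgrp:thirdgrp")
)
toy[["group:addgrp"]]
# "gr1:agr1" "gr1:agr2" "gr2:agr1"
```

Plain grouping terms such as `group` are left untouched, so the helper can be given all grouping terms of a formula without filtering them first.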