stan-dev / projpred

Projection predictive variable selection
https://mc-stan.org/projpred/

Fix `subfit` issues (mainly due to the lack of `xlev`) #403

Closed fweber144 closed 1 year ago

fweber144 commented 1 year ago

This PR resolves several bugs related to submodel fits of class `subfit`; see the messages of commits fcf1408ca7dee7a15e69f3fe3d77241866e5f50d and d08920a345d49112567e00bdda36f17ed8d26a88. The main problem was that `predict.subfit()` did not pass argument `xlev` to `model.matrix()`. Unfortunately, deriving these `xlevels` at fitting time is not straightforward, which is why this PR adds a fair amount of code for deriving them at the beginning of `fit_glm_ridge_callback()` and `search_L1()`.
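
To illustrate why `xlev` matters, here is a minimal base-R sketch (not projpred's actual code; the objects `dat`, `trms`, `xlvls`, and the `mm_*` matrices are made up for this example). It assumes that the levels seen at fitting time can be obtained via `stats::.getXlevels()`, which is simpler than what `fit_glm_ridge_callback()` and `search_L1()` have to do in this PR.

```r
# Toy "training" data with a character predictor (hypothetical example data)
dat <- data.frame(xcat = rep(c("gr1", "gr2", "gr3"), 2),
                  stringsAsFactors = FALSE)
trms <- ~ xcat

# Design matrix and levels at fitting time
mf <- model.frame(trms, data = dat)
mm_fit <- model.matrix(trms, data = dat)
xlvls <- .getXlevels(terms(mf), mf)

# Case 1: new data where the character predictor takes a single value
dat_new_chr <- data.frame(xcat = rep("gr2", 2), stringsAsFactors = FALSE)
# Without `xlev`, the predictor is turned into a one-level factor, so the
# contrasts cannot be set up and model.matrix() throws an error:
mm_bad_chr <- try(model.matrix(trms, data = dat_new_chr), silent = TRUE)
# With `xlev`, the columns match those from fitting time:
mm_ok_chr <- model.matrix(trms, data = dat_new_chr, xlev = xlvls)

# Case 2: new data where the factor levels are in a different order
dat_new_fac <- data.frame(
  xcat = factor(rep("gr2", 2), levels = c("gr3", "gr2", "gr1"))
)
# Without `xlev`, there is no error, but the baseline level (and hence the
# set of dummy columns) differs from fitting time:
mm_bad_fac <- model.matrix(trms, data = dat_new_fac)
# With `xlev`, the fitting-time coding is restored:
mm_ok_fac <- model.matrix(trms, data = dat_new_fac, xlev = xlvls)

print(colnames(mm_fit))
print(colnames(mm_bad_fac))
print(colnames(mm_ok_fac))
```

The second case is arguably the more dangerous one, since a design matrix with differently coded columns could be multiplied with the fitted coefficients without any error being thrown.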

All of the issues resolved by commit fcf1408ca7dee7a15e69f3fe3d77241866e5f50d can be reproduced using the following code:

```r
options(mc.cores = parallel::detectCores(logical = FALSE))
N <- 41L
K <- 5L
K_fac <- 4L
set.seed(457324)
dat <- data.frame(
  y = rnorm(N),
  xcat = gl(n = K, k = floor(N / K), length = N,
            labels = paste0("gr", seq_len(K))),
  xfac = sample(gl(n = K_fac, k = floor(N / K_fac), length = N,
                   labels = paste0("fgr", seq_len(K_fac)))),
  xlog = sample(rep_len(c(TRUE, FALSE), length.out = N))
)
# Add two unused levels to `xfac`:
levels(dat$xfac) <- c(levels(dat$xfac),
                      paste0("fgr", (K_fac + 1L):(K_fac + 2L)))
# Turn `xcat` into a character vector (so it has no `levels` attribute):
dat$xcat <- as.character(dat$xcat)

# Reference model:
library(rstanarm)
rfit <- stan_glm(y ~ xcat + xfac + xlog,
                 data = dat,
                 seed = 1140350788,
                 chains = 1, iter = 500,
                 refresh = 0)

library(projpred)

# debug(fit_glm_ridge_callback)
prj <- project(rfit, solution_terms = c("xcat", "xfac", "xlog"),
               nclusters = 1)

# New data where `xfac` has reversed levels plus one additional unused level:
dat_new <- data.frame(
  y = c(-1.8, 0.1),
  xcat = rep(paste0("gr", 2), 2),
  xfac = factor(rep(paste0("fgr", 2), 2),
                levels = c(rev(levels(dat$xfac)), paste0("fgr", K_fac + 3L))),
  xlog = rep(TRUE, 2)
)
# debug(predict.subfit)
prjlp <- proj_linpred(prj, newdata = dat_new)

# debug(search_L1)
cvvs <- cv_varsel(rfit,
                  nclusters = 1,
                  nclusters_pred = 1,
                  seed = 46782345)
# Same as above, but with `refit_prj = FALSE`:
cvvs <- cv_varsel(rfit,
                  nclusters = 1,
                  nclusters_pred = 1,
                  refit_prj = FALSE,
                  seed = 46782345)
```

(All of the calls that follow a commented-out `debug()` call failed before this PR and succeed with it.)

For the issue resolved by commit d08920a345d49112567e00bdda36f17ed8d26a88, see the corresponding commit message for instructions on how to reproduce it.