stan-dev / projpred

Projection predictive variable selection
https://mc-stan.org/projpred/

Simplify the required structure for `cvfits` #456

Closed: fweber144 closed this 9 months ago

fweber144 commented 9 months ago

This PR simplifies the required structure of the object passed to argument `cvfits` of `init_refmodel()` (see the `NEWS.md` entry added here for details). The reason for this change is that the output of `loo::kfold()` cannot be passed to `cvfits` directly anyway. It has to be restructured first, as the following example shows:

```r
# Example data with an artificial stratification factor (4 groups):
data("df_gaussian", package = "projpred")
dat <- data.frame(y = df_gaussian$y, df_gaussian$x)
n_strat <- 4L
dat$strat_fac <- gl(n = n_strat, k = floor(nrow(dat) / n_strat),
                    length = nrow(dat), labels = paste0("gr", seq_len(n_strat)))
# Reference model fit:
library(rstanarm)
refm_fit <- stan_glm(y ~ X1 + X2 + X3 + X4 + X5,
                     data = dat,
                     chains = 1,
                     iter = 500,
                     seed = 1140350788,
                     refresh = 0)
# Stratified K-fold CV (K = 10), keeping the K refits:
set.seed(3424511)
refm_kfold <- kfold(
  refm_fit,
  folds = loo::kfold_split_stratified(K = 10, x = dat$strat_fac),
  save_fits = TRUE,
  cores = 1
)
# Object for `cvfits` in the current structure: a list with sub-list `fits`
# and attribute `folds` (the fold index of each observation):
cvfits_crr <- structure(
  list(fits = refm_kfold$fits[, "fit"]),
  folds = sapply(seq_len(nrow(dat)), function(ii) {
    which(sapply(refm_kfold$fits[, "omitted"], "%in%", x = ii))
  })
)
# `refm_kfold$fits` is a K x 2 matrix with columns "fit" and "omitted":
length(refm_kfold$fits)
## --> 20
length(cvfits_crr$fits)
## --> 10
lapply(refm_kfold$fits, class)
## --> First 10 are `stanreg`s, last 10 are `integer` vectors.
```

Since users have to build the `cvfits` object by hand anyway, having the K reference model refits in a sub-list called `fits` is unnecessary and only complicates things.
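A minimal base-R sketch contrasting the two structures, assuming that the simplified one is just the list of K refits with the `folds` attribute attached directly (which is what dropping the `fits` sub-list implies; the `NEWS.md` entry is authoritative). To stay self-contained and avoid fitting any models, it builds a hypothetical mock of `refm_kfold$fits` (a K x 2 list matrix with columns `"fit"` and `"omitted"`) in which plain lists stand in for the `stanreg` refits. The names `fits_mat`, `refits`, `cvfits_old` and `cvfits_new` are illustrative only.

```r
# Hypothetical mock of `refm_kfold$fits`: a K x 2 matrix of mode "list"
# with columns "fit" (the refits) and "omitted" (held-out observations).
# Plain lists stand in for the `stanreg` refits, so nothing is fitted.
N <- 12L
K <- 3L
set.seed(1)
folds_true <- sample(rep_len(seq_len(K), N))  # fold index of each observation

fits_mat <- matrix(vector("list", 2L * K), nrow = K, ncol = 2L,
                   dimnames = list(NULL, c("fit", "omitted")))
for (k in seq_len(K)) {
  fits_mat[[k, "fit"]] <- list(fold = k)          # placeholder refit
  fits_mat[[k, "omitted"]] <- which(folds_true == k)
}

# Matrices are stored column-major, so the K refits come first, followed by
# the K integer vectors of omitted indices (hence length 2 * K, not K):
length(fits_mat)
vapply(fits_mat, class, character(1))

# The K refits as a plain list:
refits <- fits_mat[, "fit"]

# Fold index of each observation, computed as in the PR description ...
folds_sapply <- sapply(seq_len(N), function(ii) {
  which(sapply(fits_mat[, "omitted"], "%in%", x = ii))
})
# ... or, equivalently, by filling in the omitted indices fold by fold:
folds <- integer(N)
for (k in seq_len(K)) folds[fits_mat[[k, "omitted"]]] <- k

# Current structure: refits in a sub-list `fits`, plus attribute `folds`.
cvfits_old <- structure(list(fits = refits), folds = folds)
# Simplified structure (assumption): the refits themselves, plus `folds`.
cvfits_new <- structure(refits, folds = folds)

length(cvfits_old)  # 1: only the `fits` sub-list
length(cvfits_new)  # K: one element per refit
```

With a real `kfold()` result, `fits_mat` would be `refm_kfold$fits`. A simpler alternative is to store the vector passed to the `folds` argument of `kfold()` before the call (e.g., `folds <- loo::kfold_split_stratified(K = 10, x = dat$strat_fac)`) and attach that directly, which avoids reconstructing the fold indices from the `"omitted"` column.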