stan-dev / projpred

Projection predictive variable selection
https://mc-stan.org/projpred/
Other

Collect draw-wise projection warnings and check projection convergence #478

Status: closed (fweber144 closed this PR 8 months ago)

fweber144 commented 8 months ago

This PR makes projpred catch messages and warnings thrown by the draw-wise divergence minimizers and also check their convergence (as far as possible). Previously, projpred suppressed such messages and warnings and did not check convergence. PRs #259 and #444 started and later modified the convergence checker, but because it was unfinished, it has remained a hidden feature until now.
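The general pattern (collect each draw's warnings and messages instead of suppressing them, then read the fitter's own convergence flag) can be sketched in base R. This is not projpred's implementation. fit_drawwise() is a hypothetical helper, and stats::glm() stands in for a draw-wise divergence minimizer. The toy "draws" differ only in the iteration limit, so that the second fit cannot converge.

```r
# Hypothetical helper (not projpred's code): run 'fitter' once per draw,
# collecting warnings and messages via withCallingHandlers() instead of
# suppressing them, and record each fit's convergence status.
# Assumption: the fitter returns an object with a logical 'converged'
# component, as stats::glm() does.
fit_drawwise <- function(draw_list, fitter) {
  lapply(seq_along(draw_list), function(s) {
    warns <- character(0)
    msgs <- character(0)
    fit <- withCallingHandlers(
      fitter(draw_list[[s]]),
      warning = function(w) {
        warns <<- c(warns, conditionMessage(w))  # store, then keep going
        invokeRestart("muffleWarning")
      },
      message = function(m) {
        msgs <<- c(msgs, conditionMessage(m))
        invokeRestart("muffleMessage")
      }
    )
    list(fit = fit, warnings = warns, messages = msgs,
         converged = isTRUE(fit$converged))
  })
}

# Toy "draws": the same logistic regression, but the second draw is limited
# to a single IRLS iteration, so it does not converge.
draws <- list(list(maxit = 25), list(maxit = 1))
res <- fit_drawwise(draws, function(d) {
  glm(am ~ wt, family = binomial(), data = mtcars,
      control = glm.control(maxit = d$maxit))
})

# Per-draw convergence status and collected warnings:
vapply(res, function(r) r$converged, logical(1))
lapply(res, function(r) r$warnings)
```

Collecting the conditions per draw (rather than letting R print them once at top level) keeps track of which draw produced which warning, which is what makes such output useful for diagnosing a projection.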

To deactivate these two features, this PR adds the global options projpred.warn_prj_drawwise and projpred.check_conv (see the NEWS.md entries added here).
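A minimal sketch of switching both features off temporarily, assuming the two options take logical values and that FALSE deactivates them (check the NEWS.md entries for the accepted values). without_drawwise_checks() is a hypothetical helper, not a projpred function.

```r
# Hypothetical helper (not part of projpred): evaluate 'expr' with draw-wise
# warnings and convergence checks switched off, then restore the previous
# option values, even if 'expr' throws an error.
# Assumption: setting these options to FALSE deactivates the features.
without_drawwise_checks <- function(expr) {
  old <- options(projpred.warn_prj_drawwise = FALSE,
                 projpred.check_conv = FALSE)
  on.exit(options(old), add = TRUE)
  expr  # 'expr' is evaluated lazily, i.e., only now, with the options set
}

# Example usage with the binomial reference model from the illustration:
# prj <- without_drawwise_checks(
#   project(fit_glm, predictor_terms = c("X1"), nclusters = 1, thresh = 0)
# )
```

Restoring the old values via on.exit() keeps the change local to one call, so other projpred calls in the same session still get the warnings and convergence checks.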

In my opinion, the convergence checker in particular is a crucial feature (see, e.g., issue #323). The messages and warnings from the draw-wise divergence minimizers are meant to help users find out what might be going wrong without having to debug.

The convergence checks for additive models are probably still incomplete, even with this PR. I'll open a new issue for this.

Illustration:

# Setup -------------------------------------------------------------------

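# Raise the truncation limit for warning messages to its maximum (8170 bytes)
# so that long collected draw-wise warnings are not cut off:
warn_length_orig <- options(warning.length = 8170)
# Load the development version of projpred (run from its source directory):
devtools::load_all()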

# glm_ridge(), glm_elnet() as submodel fitters ----------------------------

data("df_binom", package = "projpred")
dat <- data.frame(y = df_binom$y, df_binom$x)
fit_glm <- rstanarm::stan_glm(y ~ X1 + X2 + X3,
                              family = binomial(),
                              data = dat,
                              chains = 1,
                              iter = 500,
                              seed = 1140350788,
                              refresh = 0)

# Warning from glm_ridge():
prj <- project(fit_glm, predictor_terms = c("X1"), nclusters = 1, thresh = 0)

# Warning from glm_ridge() during the refits for performance evaluation:
vs <- varsel(fit_glm, method = "L1", nclusters_pred = 2, qa_updates_max = 2)
# Alternatively (this is a different warning, though):
vs <- varsel(fit_glm, method = "L1", nclusters_pred = 2, thresh_conv = 0)

# Warning from glm_ridge() during the forward search as well as during the
# refits for performance evaluation:
vs <- varsel(fit_glm, nclusters = 1, nclusters_pred = 2, qa_updates_max = 2)
# Alternatively (this is a different warning, though):
vs <- varsel(fit_glm, nclusters = 1, nclusters_pred = 2, thresh_conv = 0)

# Warning from glm_ridge() during the forward search:
vs <- varsel(fit_glm, nclusters = 1, nclusters_pred = 2,
             search_control = list(qa_updates_max = 2))
# Alternatively (this is a different warning, though):
vs <- varsel(fit_glm, nclusters = 1, nclusters_pred = 2,
             search_control = list(thresh_conv = 0))

# Warning from glm_ridge() during the refits for performance evaluation:
vs <- varsel(fit_glm, nclusters = 1, nclusters_pred = 2,
             search_control = list(), qa_updates_max = 2)
# Alternatively (this is a different warning, though):
vs <- varsel(fit_glm, nclusters = 1, nclusters_pred = 2,
             search_control = list(), thresh_conv = 0)

# Warning from glm_elnet() during the L1 search (note: 1e-330 is below the
# smallest positive double, so R parses it as 0):
vs <- varsel(fit_glm, method = "L1", refit_prj = FALSE,
             search_control = list(thresh = 1e-330, nlambda = 1))

# MASS::polr() as submodel fitter -----------------------------------------

data("inhaler", package = "brms")
inhaler$rating <- as.factor(paste0("rtg", inhaler$rating))

fit_polr <- rstanarm::stan_polr(
  rating ~ period + carry + treat,
  data = inhaler,
  prior = rstanarm::R2(location = 0.5, what = "median"),
  chains = 1,
  iter = 500,
  seed = 1140350788,
  refresh = 0
)

# Non-convergence in MASS::polr():
prj <- project(fit_polr, predictor_terms = c("carry", "treat"), nclusters = 1,
               control = list(maxit = 1))

# Teardown ----------------------------------------------------------------

options(warn_length_orig)