easystats / performance

:muscle: Models' quality and performance metrics (R2, ICC, LOO, AIC, BF, ...)
https://easystats.github.io/performance/
GNU General Public License v3.0

Reproduce manually rstan's pp_check #202

Closed DominiqueMakowski closed 1 year ago

DominiqueMakowski commented 3 years ago

Opening this for myself, will close soon

m <- rstanarm::stan_glm(mpg ~ am, data=mtcars, refresh=0)
rstanarm::pp_check(m)

Created on 2021-02-22 by the reprex package (v0.3.0)

strengejacke commented 3 years ago

https://easystats.github.io/performance/reference/pp_check.html https://easystats.github.io/see/articles/performance.html#posterior-predictive-checks

DominiqueMakowski commented 3 years ago

Haha, I was looking for that in 'see'!

DominiqueMakowski commented 3 years ago

Wait no but actually I want to replicate the rstanarm thing for Bayesian models:

m <- rstanarm::stan_glm(mpg ~ am, data=mtcars, refresh=0)
performance::posterior_predictive_check(m)
#> Warning in rnorm(ntot, sd = sqrt(vars)): NAs produced
#> Error in if (min(replicated) > min(original)) {: missing value where TRUE/FALSE needed

Created on 2021-02-23 by the reprex package (v0.3.0)

@strengejacke, should we have two versions of it, namely pp_check() and check_prediction_distribution()?
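
The warning and the error above are consistent with a missing standard deviation reaching `rnorm()`, which then feeds `NaN` values into the `if ()` comparison. A minimal base-R sketch of that mechanism follows; setting `vars` to `NA` is an assumption made for illustration, since the output does not show what `vars` actually contained.

``` r
# Assumption for illustration: the residual variance could not be computed.
vars <- NA_real_

# rnorm() returns NaN for a missing sd and warns "NAs produced".
replicated <- suppressWarnings(rnorm(5, sd = sqrt(vars)))
original <- mtcars$mpg

# Comparing NaN with a number gives NA, which if () cannot evaluate.
comparison <- min(replicated) > min(original)
result <- tryCatch(
  if (comparison) "larger minimum" else "not a larger minimum",
  error = function(e) conditionMessage(e)
)
print(result)
```

If this is indeed the cause, the fix would belong upstream of the comparison (where the variance is obtained), not in the comparison itself.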

DominiqueMakowski commented 3 years ago

library(ggplot2)
library(bayestestR)
#> Note: The default CI width (currently `ci=0.89`) might change in future versions (see https://github.com/easystats/bayestestR/discussions/250). To prevent any issues, please set it explicitly when using bayestestR functions, via the 'ci' argument.

model <- rstanarm::stan_glm(mpg ~ am, data=mtcars, refresh=0)

pred <- as.data.frame(insight::get_predicted(model, ci_type = "prediction"))
data_density <- estimate_density(pred, precision = 100)

ggplot(data_density, aes(x = x, y = y)) +
  geom_line(aes(group = Parameter), alpha=0.01) +
  geom_line(data= estimate_density(insight::get_response(model)), color="red", size=2) + 
  scale_y_continuous(expand = c(0, 0)) +
  see::theme_modern()

Created on 2021-02-23 by the reprex package (v0.3.0)
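
For readers without a Stan installation, the same kind of overlay can be sketched in base R. This is a frequentist stand-in, not the rstanarm method: it swaps `stan_glm()` for `lm()` and uses `simulate()`, so every replicated data set is drawn from the same fitted coefficients and no posterior uncertainty is included. The number of draws and the plot colors are arbitrary choices.

``` r
# Swapped-in model: lm() instead of stan_glm(), so Stan is not required.
set.seed(123)
model <- lm(mpg ~ am, data = mtcars)

# Draw replicated response vectors from the fitted model.
n_draws <- 50
replicated <- simulate(model, nsim = n_draws)  # data frame, one column per draw

# One density per replicated data set, plus one for the observed response.
dens_rep <- lapply(replicated, density)
dens_obs <- density(mtcars$mpg)

# Axis limits that cover every curve.
y_max <- max(dens_obs$y, vapply(dens_rep, function(d) max(d$y), numeric(1)))
x_range <- range(dens_obs$x, unlist(lapply(dens_rep, function(d) range(d$x))))

# Replicated densities as thin gray lines, observed density as a thick blue line.
plot(NULL, xlim = x_range, ylim = c(0, y_max), xlab = "mpg", ylab = "Density")
for (d in dens_rep) lines(d, col = "gray70")
lines(dens_obs, col = "blue", lwd = 2)
```

If the observed curve sits well outside the band of replicated curves, the model is reproducing the data poorly, which is the same reading as for the Bayesian plot.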

strengejacke commented 3 years ago

> Wait no but actually I want to replicate the rstanarm thing for Bayesian models:

Yeah, but I did this for freq. models, because for Bayesian models it's already available?

For Bayesian models, pp_check is from stan, but the other one would return our version (rather than the one using the default aesthetic of rstan)

What would be the difference?

DominiqueMakowski commented 3 years ago

our version is more beautiful? + we can add more info like a title? + it gives a unified output for freq and bayesian (same colors, etc)

strengejacke commented 3 years ago

ok, got it.

DominiqueMakowski commented 3 years ago

Also, our pp_check is currently not defined for brmsfit objects :(, and for complex models the brms/rstantools version fails:

``` r
library(brms)
#> Loading required package: Rcpp
#> Loading 'brms' package (version 2.15.0). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
#>
#> Attaching package: 'brms'
#> The following object is masked from 'package:stats':
#>
#>     ar

# Stan family -------------------------------------------------------------

rtmix <- brms::custom_family(
  "rtmix",
  dpars = c("mu", "sigma", "mix", "shiftprop"), # Those will be estimated
  links = c("identity", "log", "logit", "logit"),
  type = "real",
  lb = c(NA, 0, 0, 0), # bounds for the parameters
  ub = c(NA, NA, 1, 1),
  vars = c("vreal1[n]", "vreal2[n]") # Data for max_shift and upper (known)
)

# Stan model -------------------------------------------------------------------

stan_functions <- brms::stanvar(block = "functions", scode = "
  real rtmix_lpdf(real y, real mu, real sigma, real mix, real shiftprop, real max_shift, real upper) {
    real shift = shiftprop * max_shift;
    if(y <= shift) {
      // Could only be created by the contamination
      return log(mix) + uniform_lpdf(y | 0, upper);
    } else if(y >= upper) {
      // Could only come from the lognormal
      return log1m(mix) + lognormal_lpdf(y - shift | mu, sigma);
    } else {
      // Actually mixing
      real lognormal_llh = lognormal_lpdf(y - shift | mu, sigma);
      real uniform_llh = uniform_lpdf(y | 0, upper);
      return log_mix(mix, uniform_llh, lognormal_llh);
    }
  }
") + brms::stanvar(block = "functions", scode = "
  real rtmix_lcdf(real y, real mu, real sigma, real mix, real shiftprop, real max_shift, real upper) {
    real shift = shiftprop * max_shift;
    if(y <= shift) {
      return log(mix) + uniform_lcdf(y | 0, upper);
    } else if(y >= upper) {
      // The whole uniform part is below, so the mixture part is log(1) = 0
      return log_mix(mix, 0, lognormal_lcdf(y - shift | mu, sigma));
    } else {
      real lognormal_llh = lognormal_lcdf(y - shift | mu, sigma);
      real uniform_llh = uniform_lcdf(y | 0, upper);
      return log_mix(mix, uniform_llh, lognormal_llh);
    }
  }

  real rtmix_lccdf(real y, real mu, real sigma, real mix, real shiftprop, real max_shift, real upper) {
    real shift = shiftprop * max_shift;
    if(y <= shift) {
      // The whole lognormal part is above, so the mixture part is log(1) = 0
      return log_mix(mix, uniform_lccdf(y | 0, upper), 0);
    } else if(y >= upper) {
      return log1m(mix) + lognormal_lccdf(y - shift | mu, sigma);
    } else {
      real lognormal_llh = lognormal_lccdf(y - shift | mu, sigma);
      real uniform_llh = uniform_lccdf(y | 0, upper);
      return log_mix(mix, uniform_llh, lognormal_llh);
    }
  }
")

# Model components --------------------------------------------------------

posterior_predict_rtmix <- function(i, prep, ...) {
  if ((!is.null(prep$data$lb) && prep$data$lb[i] > 0) ||
      (!is.null(prep$data$ub) && prep$data$ub[i] < Inf)) {
    stop("Predictions for truncated distributions not supported")
  }

  mu <- brms:::get_dpar(prep, "mu", i = i)
  sigma <- brms:::get_dpar(prep, "sigma", i = i)
  mix <- brms:::get_dpar(prep, "mix", i = i)
  shiftprop <- brms:::get_dpar(prep, "shiftprop", i = i)

  max_shift <- prep$data$vreal1[i]
  upper <- prep$data$vreal2[i]
  shift <- shiftprop * max_shift

  rtmix(prep$nsamples,
    meanlog = mu, sdlog = sigma,
    mix = mix, shift = shift, upper = upper
  )
}

# Needed for numerical stability
# from http://tr.im/hH5A
logsumexp <- function(x) {
  y <- max(x)
  y + log(sum(exp(x - y)))
}

rtmix_lpdf <- function(y, meanlog, sdlog, mix, shift, upper) {
  unif_llh <- dunif(y, min = 0, max = upper, log = TRUE)
  lognormal_llh <- dlnorm(y - shift, meanlog = meanlog, sdlog = sdlog, log = TRUE) -
    plnorm(upper - shift, meanlog = meanlog, sdlog = sdlog, log.p = TRUE)

  # Computing logsumexp(log(mix) + unif_llh, log1p(-mix) + lognormal_llh)
  # but vectorized
  llh_matrix <- array(NA_real_, dim = c(2, max(length(unif_llh), length(lognormal_llh))))
  llh_matrix[1, ] <- log(mix) + unif_llh
  llh_matrix[2, ] <- log1p(-mix) + lognormal_llh
  apply(llh_matrix, MARGIN = 2, FUN = logsumexp)
}

log_lik_rtmix <- function(i, draws) {
  mu <- brms:::get_dpar(draws, "mu", i = i)
  sigma <- brms:::get_dpar(draws, "sigma", i = i)
  mix <- brms:::get_dpar(draws, "mix", i = i)
  shiftprop <- brms:::get_dpar(draws, "shiftprop", i = i)

  max_shift <- draws$data$vreal1[i]
  upper <- draws$data$vreal2[i]
  shift <- shiftprop * max_shift
  y <- draws$data$Y[i]

  rtmix_lpdf(y,
    meanlog = mu, sdlog = sigma,
    mix = mix, shift = shift, upper = upper
  )
}

# Test --------------------------------------------------------------------

library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#>     filter, lag
#> The following objects are masked from 'package:base':
#>
#>     intersect, setdiff, setequal, union
library(ggplot2)

# Generate Data
generate_RT <- function(n, meanlog, sdlog, mix, shift, upper) {
  ifelse(runif(n) < mix,
    runif(n, 0, upper),
    shift + rlnorm(n, meanlog = meanlog, sdlog = sdlog)
  )
}

# Parameters
n <- 3
n_obs <- 500
max_shift <- runif(n, 0.25, 0.5)
shift <- runif(n) * max_shift
upper <- runif(n, 8, 10)
mix <- runif(n, 0, 0.2)
intercept <- runif(n, 0.2, 1)
beta <- abs(rnorm(n, 0.5, 0.5))
sigma <- abs(rnorm(n, 0.5, 0.2))

df <- data.frame()
for(i in 1:n){
  X <- rnorm(n_obs)
  mu <- rep(intercept[i], n_obs) + beta[i] * X
  df <- rbind(df, data.frame(
    RT = generate_RT(n = n_obs, meanlog = mu, sdlog = sigma[i], mix = mix[i], shift = shift[i], upper = upper[i]),
    x = X,
    max_shift = max_shift[i],
    upper = upper[i],
    participant = as.factor(i)))
  df$RT[df$RT > 10] <- NA
}

ggplot(df, aes(x = RT, color = participant)) +
  geom_density()
#> Warning: Removed 56 rows containing non-finite values (stat_density).
```

![](https://i.imgur.com/ehvv17X.png)

``` r
f <- brmsformula(
  # RT | vreal(max_shift, upper) ~ x
  RT | vreal(max_shift, upper) ~ x + (1|participant)
  # beta ~ 1 + (1|G|Participant), # beta is the "tau"
  # sigma ~ 1 + (1|G|Participant)
)

model <- brms::brm(f,
  data = df,
  family = rtmix,
  stanvars = stan_functions,
  refresh = 0,
  iter = 500,
  prior = c(brms::prior(beta(1, 5), class = "mix")))
#> Warning: Rows containing NAs were excluded from the model.
#> Compiling Stan program...
#> Start sampling
#> Warning: There were 8 divergent transitions after warmup. See
#> http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
#> to find out why this is a problem and how to eliminate them.
#> Warning: Examine the pairs() plot to diagnose sampling problems
#> Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
#> Running the chains for more iterations may help. See
#> http://mc-stan.org/misc/warnings.html#bulk-ess
#> Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
#> Running the chains for more iterations may help. See
#> http://mc-stan.org/misc/warnings.html#tail-ess

brms::pp_check(model)
#> Using 10 posterior samples for ppc type 'dens_overlay' by default.
#> Error in rtmix(prep$nsamples, meanlog = mu, sdlog = sigma, mix = mix, : could not find function "rtmix"

performance::pp_check(model)
#> Error in UseMethod("pp_check"): no applicable method for 'pp_check' applied to an object of class "brmsfit"
```

Created on 2021-04-03 by the [reprex package](https://reprex.tidyverse.org) (v1.0.0)

strengejacke commented 3 years ago

Of course it's not defined for Stan models, as those methods are in rstan/rstanarm/brms. I didn't want to overwrite them.
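
The "no applicable method" error is ordinary S3 dispatch: a generic only works for the classes it has methods for. A small sketch of that behavior, using a hypothetical generic `my_check()` and a mock object that merely carries the class `"brmsfit"` (neither is part of performance or brms):

``` r
# Hypothetical generic, standing in for a package-level check function.
my_check <- function(x, ...) UseMethod("my_check")

# Method defined only for frequentist lm objects.
my_check.lm <- function(x, ...) "check for lm"

fit <- lm(mpg ~ am, data = mtcars)
print(my_check(fit))

# Mock object with class "brmsfit"; this is not a real brms model.
fake_brms <- structure(list(), class = "brmsfit")

# No method exists for this class yet, so dispatch fails.
msg <- tryCatch(my_check(fake_brms), error = function(e) conditionMessage(e))
print(msg)

# Adding a method to this generic does not touch generics owned by other packages.
my_check.brmsfit <- function(x, ...) "check for brmsfit"
print(my_check(fake_brms))
```

The design question in the thread is therefore about naming: a method added under a separately named generic coexists with the Stan packages' own `pp_check()` methods, whereas reusing the same generic name risks masking them.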

strengejacke commented 3 years ago

> will close soon

🤨

DominiqueMakowski commented 3 years ago

oops, I got a bit ahead of myself back then

bwiernik commented 3 years ago

> our version is more beautiful? + we can add more info like a title? + it gives a unified output for freq and bayesian (same colors, etc)

Could we change to something other than red for the MLE line? Blue or green are more colorblind accessible on top of the black posterior lines.

> of course it's not defined for stan models, as those methods are in rstan/rstanarm/brms. I didn't want to overwrite them.

One option might be to define a see print/plot method for them to apply consistent theming without overwriting the core method.
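
As a rough illustration of the color suggestion, base R (4.0.0 or later) ships the Okabe-Ito palette, which was designed with color vision deficiency in mind. The sketch below uses bootstrap resamples of `mtcars$mpg` as a stand-in for the black lines, which is an assumption made only to have something to draw; it is not how see or performance builds the plot.

``` r
# Okabe-Ito palette from grDevices; the sixth entry is its blue (#0072B2).
okabe_ito <- palette.colors(palette = "Okabe-Ito")
observed_col <- unname(okabe_ito[6])

set.seed(1)
observed <- density(mtcars$mpg)

# Empty frame sized to the observed density.
plot(observed, type = "n", main = "", xlab = "mpg")

# Stand-in for the replicated lines: bootstrap resamples drawn in black.
for (i in 1:20) {
  lines(density(sample(mtcars$mpg, replace = TRUE)), col = "black")
}

# Highlighted line in the palette's blue instead of red.
lines(observed, col = observed_col, lwd = 3)
```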

strengejacke commented 3 years ago

@bwiernik That was just a quick sketch from @DominiqueMakowski. Our current implementation (only for frequentist) already has a blue line on gray lines: https://easystats.github.io/see/articles/performance.html#posterior-predictive-checks

strengejacke commented 1 year ago

Closing in favor of #477