As one of the next features in RBesT, I intend to allow mixture priors as defined in RBesT to be used in brms models. Below is a prototype of how this could look in practice. This is a very first version (requiring RBesT 1.7.0), and feedback from users would be very welcome. I hope the usage is clear from the examples given (for the univariate normal and the multivariate normal case). I have iterated a bit on the design, but would like to hear from users whether the syntax is easy and straightforward to use:
A small example demonstrating a possible workflow for using the
posterior of one brms fit as the prior for the next one. As an
example we use the equivalence of MAP (meta-analytic-predictive) and
MAC (meta-analytic-combined). MAP: fit a hierarchical model to the
historical data, approximate the resulting prior for a new unit and
then fit the data of interest with that prior. MAC: do a joint fit
of all data, which results in the same inference as MAP.
library(RBesT)
#> This is RBesT version 1.7.0 (released 2023-07-20, git-sha d5f8521)
library(brms)
#> Loading required package: Rcpp
#> Loading 'brms' package (version 2.19.0). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
#>
#> Attaching package: 'brms'
#> The following object is masked from 'package:stats':
#>
#> ar
library(posterior)
#> This is posterior version 1.4.1
#>
#> Attaching package: 'posterior'
#> The following objects are masked from 'package:stats':
#>
#> mad, sd, var
#> The following objects are masked from 'package:base':
#>
#> %in%, match
library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
library(knitr)
library(mvtnorm)
library(bayesplot)
#> Warning: package 'bayesplot' was built under R version 4.1.1
#> This is bayesplot version 1.10.0
#> - Online documentation and vignettes at mc-stan.org/bayesplot
#> - bayesplot theme set to bayesplot::theme_default()
#> * Does _not_ affect other ggplot2 plots
#> * See ?bayesplot_theme_set for details on theme setting
#>
#> Attaching package: 'bayesplot'
#> The following object is masked from 'package:posterior':
#>
#> rhat
#> The following object is masked from 'package:brms':
#>
#> rhat
suppressPackageStartupMessages(library(here))
# instruct brms to use cmdstanr as backend and cache all Stan binaries
options(brms.backend="cmdstanr", cmdstanr_write_stan_file_dir=here("brms-cache"))
# create cache directory if not yet available
dir.create(here("brms-cache"), FALSE)
options(mc.cores=4)
####### Utility functions which will be RBesT internal #################
mix2brms <- function(mix, name, verbose=FALSE) UseMethod("mix2brms")
mix2brms.mvnormMix <- function(mix, name, verbose=FALSE) {
Nc <- ncol(mix)
p <- RBesT:::mvnormdim(mix[-1,1])
Sigma <- array(NA, dim=c(Nc, p, p))
for(i in 1:Nc) {
Rho_c <- diag(nrow=p)
Rho_c[lower.tri(Rho_c)] <- mix[(1+2*p+1):nrow(mix),i,drop=FALSE]
Rho_c[upper.tri(Rho_c)] <- t(Rho_c)[upper.tri(Rho_c)]
s <- mix[(1+p+1):(1+p+p),i,drop=TRUE]
Sigma[i,,] <- diag(s, nrow=p) %*% Rho_c %*% diag(s, nrow=p)
}
prefix <- paste0(name, "_")
mvprior <- stanvar(Nc, glue("{prefix}Nc")) +
stanvar(p, glue("{prefix}p")) +
stanvar(array(mix[1,,drop=TRUE], dim=Nc), glue("{prefix}w"), scode=glue("vector[{prefix}Nc] {prefix}w;")) +
stanvar(t(mix[2:(p+1),,drop=FALSE]), glue("{prefix}m"), scode=glue("vector[{prefix}p] {prefix}m[{prefix}Nc];")) +
stanvar(Sigma, glue("{prefix}sigma"), scode=glue("matrix[{prefix}p, {prefix}p] {prefix}sigma[{prefix}Nc];")) +
stanvar(scode=glue("
matrix[{{prefix}}p, {{prefix}}p] {{prefix}}sigma_L[{{prefix}}Nc];
for (i in 1:{{prefix}}Nc) {
{{prefix}}sigma_L[i] = cholesky_decompose({{prefix}}sigma[i]);
}", .open="{{", .close="}}"), block="tdata")
if(verbose) {
mvprior <- mvprior +
stanvar(scode=glue('
print("Mixture prior {{name}}");
for(i in 1:{{prefix}}Nc) {
print("Component ", i, ": w = ", {{prefix}}w[i]);
print("Component ", i, ": m = ", {{prefix}}m[i]);
print("Component ", i, ": Sigma = ", {{prefix}}sigma[i]);
}
', .open="{{", .close="}}"), position="end", block="tdata")
}
mvprior
}
mix2brms.normMix <- function(mix, name, verbose=FALSE) {
Nc <- ncol(mix)
prefix <- paste0(name, "_")
prior <- stanvar(Nc, glue("{prefix}Nc")) +
stanvar(array(mix[1,,drop=TRUE], dim=Nc), glue("{prefix}w"), scode=glue("vector[{prefix}Nc] {prefix}w;")) +
stanvar(array(mix[2,,drop=TRUE], dim=Nc), glue("{prefix}m"), scode=glue("vector[{prefix}Nc] {prefix}m;")) +
stanvar(array(mix[3,,drop=TRUE], dim=Nc), glue("{prefix}s"), scode=glue("vector[{prefix}Nc] {prefix}s;"))
if(verbose) {
prior <- prior +
stanvar(name=paste0("verbose_", name), scode=glue('
print("Mixture prior {{name}}");
for(i in 1:{{prefix}}Nc) {
print("Component ", i, ": w = ", {{prefix}}w[i]);
print("Component ", i, ": m = ", {{prefix}}m[i]);
print("Component ", i, ": s = ", {{prefix}}s[i]);
}
', .open="{{", .close="}}"), position="end", block="tdata")
}
prior
}
## user facing function which one passes in all mixture densities one
## wants to use in a given brms model. See the examples below for
## their usage.
mixstanvar <- function(..., verbose=FALSE) {
default_variable_names <- lapply(rlang::enquos(...), rlang::as_label)
require(glue)
mixpriors <- list(...)
if(is.null(names(mixpriors))) {
variable_names <- default_variable_names
} else {
variable_names <- names(mixpriors)
not_set <- which(variable_names == "")
variable_names[not_set] <- default_variable_names[not_set]
}
sv <- mix2brms(mixpriors[[1]], variable_names[[1]], verbose)
for(i in seq_len(length(mixpriors)-1)) {
mix <- mixpriors[[i+1]]
variable <- variable_names[i+1]
sv <- sv + mix2brms(mix, variable, verbose)
}
includes_density <- function(density) any(sapply(mixpriors, inherits, density))
if(includes_density("mvnormMix")) {
sv <- sv + stanvar(name="mixmvnorm_lpdf", scode="
real mixmvnorm_lpdf(vector y, vector w, vector[] m, matrix[] L) {
int Nc = rows(w);
vector[Nc] lp_mix;
for(i in 1:Nc) {
lp_mix[i] = multi_normal_cholesky_lpdf(y | m[i], L[i]);
}
return log_sum_exp(log(w) + lp_mix);
}", block="functions")
}
if(includes_density("normMix")) {
sv <- sv + stanvar(name="mixnorm_lpdf", scode="
real mixnorm_lpdf(real y, vector w, vector m, vector s) {
int Nc = rows(w);
vector[Nc] lp_mix;
for(i in 1:Nc) {
lp_mix[i] = normal_lpdf(y | m[i], s[i]);
}
return log_sum_exp(log(w) + lp_mix);
}", block="functions")
}
sv
}
####### End utility functions which will be RBesT internal #################
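The Stan function `mixnorm_lpdf` above evaluates the mixture log density with a log-sum-exp over the components. The following is a minimal base R sketch of the same calculation, meant only as an illustration; the helper name `mixnorm_lpdf_r` and the mixture values are made up and are not part of RBesT or brms.

```r
# Log density of a univariate normal mixture, mirroring the Stan function
# mixnorm_lpdf (weights w, means m, standard deviations s).
# mixnorm_lpdf_r is a hypothetical helper used for illustration only.
mixnorm_lpdf_r <- function(y, w, m, s) {
  # log weight plus log density for each component
  lp <- log(w) + dnorm(y, mean = m, sd = s, log = TRUE)
  # numerically stable log-sum-exp
  max(lp) + log(sum(exp(lp - max(lp))))
}

# illustrative two-component mixture (values are made up)
w <- c(0.7, 0.3)
m <- c(0, 1)
s <- c(1, 2)

# naive evaluation on the natural scale versus the log-sum-exp version
direct <- log(sum(w * dnorm(0.5, m, s)))
stable <- mixnorm_lpdf_r(0.5, w, m, s)
all.equal(direct, stable)

# far in the tails the naive version underflows, the stable one does not
is.finite(log(sum(w * dnorm(100, m, s))))
is.finite(mixnorm_lpdf_r(100, w, m, s))
```

Working on the log scale matters here because the sampler may visit regions where every component density underflows on the natural scale.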
set.seed(4365467)
Let’s say we have some continuous covariate “l” measured per trial.
We need the covariance of the random effects for each draw, so we use
the rvar facility from posterior for this. We then sample the
correlated MAP prior directly by drawing from the MVN once per posterior draw.
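The per-draw sampling described above can be sketched in base R as follows. This is only an illustration under stated assumptions: the "posterior draws" are simulated with made-up values rather than taken from a brms fit, the object names are hypothetical, and plain matrices are used in place of the rvar facility. With mvtnorm loaded, `rmvnorm(1, mean, sigma)` could replace the manual Cholesky step.

```r
set.seed(1)
n_draws <- 1000
p <- 2  # e.g. intercept and slope of the covariate

# hypothetical posterior draws of the population means (n_draws x p)
mu_draws <- cbind(rnorm(n_draws, -1, 0.2), rnorm(n_draws, 0.5, 0.1))
# hypothetical draws of the random-effect standard deviations and correlation
sd_draws <- cbind(rlnorm(n_draws, log(0.3), 0.3), rlnorm(n_draws, log(0.1), 0.3))
rho_draws <- runif(n_draws, -0.5, 0.5)

# for each posterior draw, build the covariance and sample one MVN vector
map_sample <- t(sapply(seq_len(n_draws), function(i) {
  s <- sd_draws[i, ]
  Rho <- matrix(c(1, rho_draws[i], rho_draws[i], 1), p, p)
  Sigma <- diag(s, nrow = p) %*% Rho %*% diag(s, nrow = p)
  # one MVN draw: mu + L z, where Sigma = L L'
  L <- t(chol(Sigma))
  drop(mu_draws[i, ] + L %*% rnorm(p))
}))

dim(map_sample)
```

Each row of `map_sample` is one joint draw for a new unit, so the correlation between the coefficients is retained in the sample.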
The `mixstanvar` command below creates a prior for the entire regression coefficient vector of the linear model (as we have `center=FALSE`). As prior for brms we have to refer to the `mixmvnorm` density (defined by the `mixstanvar` command), to
which we pass three arguments referring to the weights, means and
Cholesky factors of the mixture components. These three items are defined by
the `mixstanvar` command within the context of the Stan model. Their
names are formed from the object name (map in this case), followed by “_”
as separator and the respective element (w, m or sigma_L):
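As a hedged sketch of the naming scheme just described, the prior string can be assembled from the object name and handed to `set_prior` from brms. The object name `map` comes from the text; the variable names `prior_string` and `mix_prior` are illustrative, and the exact call in the attached script may differ.

```r
library(brms)

# the mixture object is assumed to be called "map", as in the text
name <- "map"

# weights, means and Cholesky factors are named <name>_w, <name>_m, <name>_sigma_L
prior_string <- sprintf("mixmvnorm(%s_w, %s_m, %s_sigma_L)", name, name, name)

# prior on the full regression coefficient vector (class "b")
mix_prior <- set_prior(prior_string, class = "b")
mix_prior
```

In a model fit this prior would presumably be passed as `prior=mix_prior` together with `stanvars=mixstanvar(map)`, so that the Stan program contains the data and the `mixmvnorm_lpdf` function the prior refers to.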
Now compare the results against a joint fit of the data, which avoids
any approximation:
fit_mac <- brm(model_hist,
data=AS_cov,
prior=model_hist_prior,
seed=4767,
control=list(adapt_delta=0.99),
refresh=0, silent=0)
#> Start sampling
#> Running MCMC with 4 sequential chains...
#> Chain 1 Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:
#> Chain 1 Exception: binomial_logit_lpmf: Probability parameter[1] is nan, but must be finite! (in '/var/folders/71/2631nvxx07vdw464x8y8nk4h0000gn/T/RtmpawzKRm/model-2af011d2f2d4.stan', line 75, column 4 to column 50)
#> Chain 1 If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,
#> Chain 1 but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.
#> Chain 1
#> Chain 1 finished in 0.5 seconds.
#> Chain 2 finished in 0.4 seconds.
#> Chain 3 finished in 0.5 seconds.
#> Chain 4 finished in 0.6 seconds.
#>
#> All 4 chains finished successfully.
#> Mean chain execution time: 0.5 seconds.
#> Total execution time: 2.4 seconds.
Source code: issue-rbest-mvnmix2brms_R.txt
Let’s say we have 7 historical studies.
MAP approach: fit the historical studies and then approximate the predictive distribution of the mean for a new study.
Generate an approximate parametric representation.
Now fit the current data using this as prior:
For the univariate case we can get nice diagnostic plots to check that the EM worked fine:
The MAP and MAC approaches have to match one another:
… which they do up to MCMC error. The univariate MAP discards the correlations and therefore need not match the MAC result:
Stan code generated for the multivariate normal mixture case:
Created on 2023-07-24 with reprex v2.0.2