Novartis / RBesT

Tool-set to support Bayesian evidence synthesis in R
https://opensource.nibr.com/RBesT/
GNU General Public License v3.0
20 stars 2 forks source link

mixtures as priors for brms #12

Closed weberse2 closed 1 year ago

weberse2 commented 1 year ago

As one of the next features in RBesT, I intend to allow the use of mixture priors as defined in RBesT in the context of brms models. Below is a prototype showing how this can look in practice. This is a very first version (requiring RBesT 1.7.0), and feedback from users would be very welcome. I hope that the usage is clear from the examples given (for the univariate normal and the multivariate normal case). I have iterated a bit on the design, but would like to hear from users whether the syntax is easy and straightforward to use:

Source code: issue-rbest-mvnmix2brms_R.txt

A small example demonstrating a possible workflow for using the posterior of one brms fit as the prior for the next one. As an example we use the equivalence of the meta-analytic-predictive (MAP) and the meta-analytic-combined (MAC) approaches. MAP: fit the hierarchical model to the historical data, approximate the predictive prior for a new unit, and then fit the data item of interest with this prior. MAC: do a joint fit of all data, which results in the same inference as MAP.

library(RBesT)
#> This is RBesT version 1.7.0 (released 2023-07-20, git-sha d5f8521)
library(brms)
#> Loading required package: Rcpp
#> Loading 'brms' package (version 2.19.0). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
#> 
#> Attaching package: 'brms'
#> The following object is masked from 'package:stats':
#> 
#>     ar
library(posterior)
#> This is posterior version 1.4.1
#> 
#> Attaching package: 'posterior'
#> The following objects are masked from 'package:stats':
#> 
#>     mad, sd, var
#> The following objects are masked from 'package:base':
#> 
#>     %in%, match
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(knitr)
library(mvtnorm)
library(bayesplot)
#> Warning: package 'bayesplot' was built under R version 4.1.1
#> This is bayesplot version 1.10.0
#> - Online documentation and vignettes at mc-stan.org/bayesplot
#> - bayesplot theme set to bayesplot::theme_default()
#>    * Does _not_ affect other ggplot2 plots
#>    * See ?bayesplot_theme_set for details on theme setting
#> 
#> Attaching package: 'bayesplot'
#> The following object is masked from 'package:posterior':
#> 
#>     rhat
#> The following object is masked from 'package:brms':
#> 
#>     rhat

suppressPackageStartupMessages(library(here))
# instruct brms to use cmdstanr as backend and cache all Stan binaries
options(brms.backend="cmdstanr", cmdstanr_write_stan_file_dir=here("brms-cache"))
# create cache directory if not yet available; an already existing
# directory is fine, hence the warning is suppressed
dir.create(here("brms-cache"), showWarnings=FALSE)
# run the MCMC chains in parallel on 4 cores. The option must be spelled
# `mc.cores`: options() accepts any name, so a misspelled option is
# silently ignored and the chains then run sequentially.
options(mc.cores=4)

####### Utility functions which will be RBesT internal #################

## S3 generic translating an RBesT mixture density into brms stanvars.
##   mix     - the mixture object; the method is selected by its class
##   name    - prefix for all Stan variables created for this mixture
##   verbose - whether the generated Stan code prints the mixture
mix2brms <- function(mix, name, verbose=FALSE) {
    UseMethod("mix2brms")
}

## Translate an RBesT multivariate normal mixture (class `mvnormMix`)
## into brms stanvars. For the prefix `name` this declares the Stan data
##   {name}_Nc    - number of mixture components
##   {name}_p     - dimension of the multivariate normal
##   {name}_w     - vector of mixture weights
##   {name}_m     - array of component mean vectors
##   {name}_sigma - array of component covariance matrices
## and computes in transformed data the Cholesky factors {name}_sigma_L,
## which are the arguments expected by `mixmvnorm_lpdf`. With
## verbose=TRUE the transformed data block also prints the mixture.
mix2brms.mvnormMix <- function(mix, name, verbose=FALSE) {
    Nc <- ncol(mix)
    p <- RBesT:::mvnormdim(mix[-1,1])
    ## Per component (column) the rows of `mix` hold the weight (row 1),
    ## the p means, the p standard deviations and the p*(p-1)/2
    ## correlations filling the lower triangle. seq_len() yields no
    ## correlation rows for p=1, where a from:to range would count
    ## backwards and index beyond the last row.
    sd_rows <- 1 + p + seq_len(p)
    rho_rows <- 1 + 2*p + seq_len(p*(p-1)/2)
    Sigma <- array(NA, dim=c(Nc, p, p))
    for(i in seq_len(Nc)) {
        Rho_c <- diag(nrow=p)
        Rho_c[lower.tri(Rho_c)] <- mix[rho_rows,i,drop=FALSE]
        Rho_c[upper.tri(Rho_c)] <- t(Rho_c)[upper.tri(Rho_c)]
        s <- mix[sd_rows,i,drop=TRUE]
        Sigma[i,,] <- diag(s, nrow=p) %*% Rho_c %*% diag(s, nrow=p)
    }
    prefix <- paste0(name, "_")
    ## glue is namespace-qualified so this works without glue being attached
    mvprior <-  stanvar(Nc, glue::glue("{prefix}Nc")) +
        stanvar(p, glue::glue("{prefix}p")) +
        stanvar(array(mix[1,,drop=TRUE], dim=Nc), glue::glue("{prefix}w"), scode=glue::glue("vector[{prefix}Nc] {prefix}w;")) +
        stanvar(t(mix[2:(p+1),,drop=FALSE]), glue::glue("{prefix}m"), scode=glue::glue("vector[{prefix}p] {prefix}m[{prefix}Nc];")) +
        stanvar(Sigma, glue::glue("{prefix}sigma"), scode=glue::glue("matrix[{prefix}p, {prefix}p] {prefix}sigma[{prefix}Nc];")) +
        stanvar(scode=glue::glue("
matrix[{{prefix}}p, {{prefix}}p] {{prefix}}sigma_L[{{prefix}}Nc];
for (i in 1:{{prefix}}Nc) {
    {{prefix}}sigma_L[i] = cholesky_decompose({{prefix}}sigma[i]);
}", .open="{{", .close="}}"), block="tdata")
    if(verbose) {
        mvprior <- mvprior +
            stanvar(scode=glue::glue('
print("Mixture prior {{name}}");
for(i in 1:{{prefix}}Nc) {
  print("Component ", i, ": w     = ", {{prefix}}w[i]);
  print("Component ", i, ": m     = ", {{prefix}}m[i]);
  print("Component ", i, ": Sigma = ", {{prefix}}sigma[i]);
}
', .open="{{", .close="}}"), position="end", block="tdata")
    }
    mvprior
}

## Translate an RBesT univariate normal mixture (class `normMix`) into
## brms stanvars. For the prefix `name` this declares the Stan data
##   {name}_Nc - number of mixture components
##   {name}_w  - vector of mixture weights
##   {name}_m  - vector of component means
##   {name}_s  - vector of component standard deviations
## which are the arguments expected by `mixnorm_lpdf`. With verbose=TRUE
## the transformed data block also prints the mixture.
mix2brms.normMix <- function(mix, name, verbose=FALSE) {
    Nc <- ncol(mix)
    prefix <- paste0(name, "_")
    ## glue is namespace-qualified so this works without glue being attached
    prior <-  stanvar(Nc, glue::glue("{prefix}Nc")) +
        stanvar(array(mix[1,,drop=TRUE], dim=Nc), glue::glue("{prefix}w"), scode=glue::glue("vector[{prefix}Nc] {prefix}w;")) +
        stanvar(array(mix[2,,drop=TRUE], dim=Nc), glue::glue("{prefix}m"), scode=glue::glue("vector[{prefix}Nc] {prefix}m;")) +
        stanvar(array(mix[3,,drop=TRUE], dim=Nc), glue::glue("{prefix}s"), scode=glue::glue("vector[{prefix}Nc] {prefix}s;"))
    if(verbose) {
        prior <- prior +
            stanvar(name=paste0("verbose_", name), scode=glue::glue('
print("Mixture prior {{name}}");
for(i in 1:{{prefix}}Nc) {
  print("Component ", i, ": w = ", {{prefix}}w[i]);
  print("Component ", i, ": m = ", {{prefix}}m[i]);
  print("Component ", i, ": s = ", {{prefix}}s[i]);
}
', .open="{{", .close="}}"), position="end", block="tdata")
    }
    prior
}

## User facing function to which one passes all mixture densities one
## wants to use in a given brms model. See the examples below for their
## usage.
##
## Each mixture is translated by mix2brms() into Stan data. The names of
## these data are prefixed by the argument name when one is given
## (mixstanvar(prior=map)). Otherwise the prefix is the argument
## expression itself: mixstanvar(map) gives the prefix `map`. The mixture
## log-densities `mixmvnorm_lpdf` / `mixnorm_lpdf` are added to the Stan
## functions block whenever a mixture of the respective type is passed.
mixstanvar <- function(..., verbose=FALSE) {
    ## capture the argument expressions to derive default variable names
    default_variable_names <- vapply(rlang::enquos(...), rlang::as_label, character(1))
    ## keep glue attached since helper code may call glue() without the
    ## namespace prefix; library() fails loudly if glue is not installed
    library(glue)
    mixpriors <- list(...)
    if(length(mixpriors) == 0) {
        stop("At least one mixture density must be passed to mixstanvar.", call.=FALSE)
    }
    variable_names <- names(mixpriors)
    if(is.null(variable_names)) {
        variable_names <- default_variable_names
    } else {
        not_set <- variable_names == ""
        variable_names[not_set] <- default_variable_names[not_set]
    }
    variable_names <- unname(variable_names)
    ## translate every mixture and combine the resulting stanvars with `+`
    sv <- Reduce(`+`, Map(mix2brms, mixpriors, variable_names,
                          MoreArgs=list(verbose=verbose)))
    includes_density <- function(density) {
        any(vapply(mixpriors, inherits, logical(1), what=density))
    }

    if(includes_density("mvnormMix")) {
        sv <- sv + stanvar(name="mixmvnorm_lpdf", scode="
real mixmvnorm_lpdf(vector y, vector w, vector[] m, matrix[] L) {
  int Nc = rows(w);
  vector[Nc] lp_mix;
  for(i in 1:Nc) {
     lp_mix[i] = multi_normal_cholesky_lpdf(y | m[i], L[i]);
  }
  return log_sum_exp(log(w) + lp_mix);
}", block="functions")
    }
    if(includes_density("normMix")) {
        sv <- sv + stanvar(name="mixnorm_lpdf", scode="
real mixnorm_lpdf(real y, vector w, vector m, vector s) {
  int Nc = rows(w);
  vector[Nc] lp_mix;
  for(i in 1:Nc) {
     lp_mix[i] = normal_lpdf(y | m[i], s[i]);
  }
  return log_sum_exp(log(w) + lp_mix);
}", block="functions")
    }

    sv
}
####### End utility functions which will be RBesT internal #################

## fix the random number generator state so the example is reproducible
set.seed(seed=4365467)

Let’s say we have some continuous covariate “l” measured per trial:

## add a simulated continuous covariate "l" with one standard normal draw
## per study; the number of draws follows the data instead of being
## hard-coded
AS_cov <- bind_cols(AS, l=rnorm(nrow(AS)))
kable(AS_cov)
study n r l
Study 1 107 23 -0.3961712
Study 2 44 12 -0.5931832
Study 3 51 19 -0.0627691
Study 4 39 9 -0.0083993
Study 5 139 39 1.1878420
Study 6 20 6 0.4626915
Study 7 78 9 -0.8153975
Study 8 35 10 0.6444389

let’s say we have 7 historical studies

## treat the last study as the current one and all others as historical
AS_cov_hist <- AS_cov[-nrow(AS_cov),]
AS_cov_current <- AS_cov[nrow(AS_cov),,drop=FALSE]

MAP approach: fit the historical studies and then approximate the predictive distribution of the regression coefficients for a new study.

## Hierarchical logistic regression for the historical studies: intercept
## and slope of the covariate `l` vary by study.
model_hist <- bf(r | trials(n) ~ 1 + l + (1 + l | study),
                 family=binomial)
#get_prior(model_hist, AS_cov_hist)

## Priors on the intercept and the between-study standard deviations; no
## explicit prior is set here for the coefficient of `l`.
## NOTE(review): model_hist does not set center=FALSE, so with the brms
## default centering the Intercept prior presumably refers to the
## intercept at the mean of `l` in the data being fitted. That mean
## differs between AS_cov_hist (this fit) and AS_cov (the joint MAC fit),
## so the MAP and MAC priors may not be exactly identical - confirm.
model_hist_prior <-
    prior(normal(0, 2), class="Intercept") +
    prior(normal(0, 0.5), class="sd", coef="Intercept", group="study") +
    prior(normal(0, 0.25), class="sd", coef="l", group="study")

map_fit <- brm(formula=model_hist,
               data=AS_cov_hist,
               prior=model_hist_prior,
               control=list(adapt_delta=0.99),
               seed=4767,
               refresh=0,
               silent=0)
#> Start sampling
#> Running MCMC with 4 sequential chains...
#> 
#> Chain 1 finished in 0.5 seconds.
#> Chain 2 finished in 0.5 seconds.
#> Chain 3 finished in 0.5 seconds.
#> Chain 4 finished in 0.5 seconds.
#> 
#> All 4 chains finished successfully.
#> Mean chain execution time: 0.5 seconds.
#> Total execution time: 2.5 seconds.

## posterior draws of the historical fit in rvar format
post_rv <- as_draws_rvars(x=map_fit)

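We need the covariance matrix of the random effects for each draw, for which we use the rvar facility of the posterior package. Then we sample the correlated MAP prior directly by drawing, for each posterior draw, from the corresponding multivariate normal (MVN) distribution.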

## Sample the MAP prior directly: for each posterior draw build the
## between-study covariance matrix of (intercept, slope) and draw one new
## study-level coefficient pair from the corresponding bivariate normal.
map_mc <- with(post_rv, {
    sd_diag <- diag(c(sd_study__Intercept, sd_study__l))
    rho <- cor_study__Intercept__l
    corr <- cbind(c(rvar(1), rho), c(rho, rvar(1)))
    cov_study <- sd_diag %**% corr %**% sd_diag
    drop(rdo(rmvnorm(1, c(b_Intercept, b_l), cov_study)))
})

names(map_mc) <- c("intercept", "slope")
mcmc_pairs(draws_of(map_mc, with_chains=TRUE))

Generate an approximate parametric representation:

## approximate the MAP prior samples by a 3 component multivariate normal
## mixture estimated with EM
map <- mixfit(draws_of(map_mc), type="mvnorm", Nc=3)
print(map)
#> EM for Multivariate Normal Mixture Model
#> Log-Likelihood = -3939.767
#> 
#> Multivariate normal mixture
#> Outcome dimension: 2
#> Mixture Components:
#>          comp1      comp2      comp3     
#> w         0.3753268  0.3211428  0.3035304
#> m[1]     -1.1738439 -1.0623154 -1.0479842
#> m[2]      0.4049649  0.2383634  0.4106648
#> s[1]      0.2855314  0.2091845  0.6618486
#> s[2]      0.2984324  0.2936100  0.5979352
#> rho[2,1]  0.2063437 -0.2393784  0.1371571

Now fit the current data using this as the prior:

## model for the current study only; center=FALSE lets the prior below
## act on the entire regression coefficient vector b
model_current <- bf(r | trials(n) ~ 1 + l,
                    family=binomial,
                    center=FALSE)

#get_prior(model_current, AS_cov_current)

The `mixstanvar` command below creates a prior for the entire regression coefficient vector of the linear model (as we have `center=FALSE`). As prior for brms we have to refer to the `mixmvnorm` density (defined by the `mixstanvar` command), to which we pass three arguments referring to the weights, means and Cholesky factors of the mixture components. These three items are defined by the `mixstanvar` command within the context of the Stan model. Their names are given by the object name (`map` in this case), followed by `_` as separator and the respective element (`w`, `m` or `sigma_L`):

## multivariate mixture prior on the whole coefficient vector b; the
## arguments name the Stan data declared by mixstanvar(map)
model_current_prior <- prior(mixmvnorm(map_w, map_m, map_sigma_L),
                             class="b")

## fit the current study with the MAP mixture prior; verbose=TRUE makes
## the Stan model print the mixture components
fit_map <- brm(formula=model_current,
               data=AS_cov_current,
               prior=model_current_prior,
               stanvars=mixstanvar(map, verbose=TRUE),
               control=list(adapt_delta=0.99),
               seed=4767,
               refresh=0,
               silent=0)
#> Loading required package: glue
#> Start sampling
#> Running MCMC with 4 sequential chains...
#> 
#> Chain 1 Mixture prior map 
#> Chain 1 Component 1: w     = 0.375327 
#> Chain 1 Component 1: m     = [-1.17384,0.404965] 
#> Chain 1 Component 1: Sigma = [[0.0815282,0.0175829],[0.0175829,0.0890619]] 
#> Chain 1 Component 2: w     = 0.321143 
#> Chain 1 Component 2: m     = [-1.06232,0.238363] 
#> Chain 1 Component 2: Sigma = [[0.0437582,-0.0147023],[-0.0147023,0.0862069]] 
#> Chain 1 Component 3: w     = 0.30353 
#> Chain 1 Component 3: m     = [-1.04798,0.410665] 
#> Chain 1 Component 3: Sigma = [[0.438044,0.0542789],[0.0542789,0.357527]] 
#> Chain 1 finished in 0.1 seconds.
#> Chain 2 Mixture prior map 
#> Chain 2 Component 1: w     = 0.375327 
#> Chain 2 Component 1: m     = [-1.17384,0.404965] 
#> Chain 2 Component 1: Sigma = [[0.0815282,0.0175829],[0.0175829,0.0890619]] 
#> Chain 2 Component 2: w     = 0.321143 
#> Chain 2 Component 2: m     = [-1.06232,0.238363] 
#> Chain 2 Component 2: Sigma = [[0.0437582,-0.0147023],[-0.0147023,0.0862069]] 
#> Chain 2 Component 3: w     = 0.30353 
#> Chain 2 Component 3: m     = [-1.04798,0.410665] 
#> Chain 2 Component 3: Sigma = [[0.438044,0.0542789],[0.0542789,0.357527]] 
#> Chain 2 finished in 0.1 seconds.
#> Chain 3 Mixture prior map 
#> Chain 3 Component 1: w     = 0.375327 
#> Chain 3 Component 1: m     = [-1.17384,0.404965] 
#> Chain 3 Component 1: Sigma = [[0.0815282,0.0175829],[0.0175829,0.0890619]] 
#> Chain 3 Component 2: w     = 0.321143 
#> Chain 3 Component 2: m     = [-1.06232,0.238363] 
#> Chain 3 Component 2: Sigma = [[0.0437582,-0.0147023],[-0.0147023,0.0862069]] 
#> Chain 3 Component 3: w     = 0.30353 
#> Chain 3 Component 3: m     = [-1.04798,0.410665] 
#> Chain 3 Component 3: Sigma = [[0.438044,0.0542789],[0.0542789,0.357527]] 
#> Chain 3 finished in 0.1 seconds.
#> Chain 4 Mixture prior map 
#> Chain 4 Component 1: w     = 0.375327 
#> Chain 4 Component 1: m     = [-1.17384,0.404965] 
#> Chain 4 Component 1: Sigma = [[0.0815282,0.0175829],[0.0175829,0.0890619]] 
#> Chain 4 Component 2: w     = 0.321143 
#> Chain 4 Component 2: m     = [-1.06232,0.238363] 
#> Chain 4 Component 2: Sigma = [[0.0437582,-0.0147023],[-0.0147023,0.0862069]] 
#> Chain 4 Component 3: w     = 0.30353 
#> Chain 4 Component 3: m     = [-1.04798,0.410665] 
#> Chain 4 Component 3: Sigma = [[0.438044,0.0542789],[0.0542789,0.357527]] 
#> Chain 4 finished in 0.1 seconds.
#> 
#> All 4 chains finished successfully.
#> Mean chain execution time: 0.1 seconds.
#> Total execution time: 0.4 seconds.

## univariate alternative: separate 3 component normal mixtures for the
## marginal MAP samples of intercept and slope
umap_inter <- mixfit(draws_of(map_mc)[, "intercept"], type="norm", Nc=3)
umap_slope <- mixfit(draws_of(map_mc)[, "slope"], type="norm", Nc=3)

For the univariate case we can get nice diagnostic plots to check that the EM worked fine:

## EM diagnostic plots of the univariate mixture fits
plot(umap_inter)[["mix"]]

plot(umap_slope)[["mix"]]


## independent univariate mixture priors for intercept and slope; the
## arguments name the Stan data declared by mixstanvar(umap_inter, umap_slope)
model_current_univ_prior <-
    prior(mixnorm(umap_inter_w, umap_inter_m, umap_inter_s), class="b", coef="Intercept") +
    prior(mixnorm(umap_slope_w, umap_slope_m, umap_slope_s), class="b", coef="l")

fit_umap <- brm(formula=model_current,
                data=AS_cov_current,
                prior=model_current_univ_prior,
                stanvars=mixstanvar(umap_inter, umap_slope),
                control=list(adapt_delta=0.99),
                seed=4767,
                refresh=0,
                silent=0)
#> Start sampling
#> Running MCMC with 4 sequential chains...
#> 
#> Chain 1 finished in 0.0 seconds.
#> Chain 2 finished in 0.1 seconds.
#> Chain 3 finished in 0.1 seconds.
#> Chain 4 finished in 0.1 seconds.
#> 
#> All 4 chains finished successfully.
#> Mean chain execution time: 0.1 seconds.
#> Total execution time: 0.5 seconds.

Now compare the results against a joint fit of all data, which avoids any approximation:

## MAC approach: joint fit of the historical and the current study with
## the same hierarchical model and priors
fit_mac <- brm(formula=model_hist,
               data=AS_cov,
               prior=model_hist_prior,
               control=list(adapt_delta=0.99),
               seed=4767,
               refresh=0,
               silent=0)
#> Start sampling
#> Running MCMC with 4 sequential chains...
#> Chain 1 Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:
#> Chain 1 Exception: binomial_logit_lpmf: Probability parameter[1] is nan, but must be finite! (in '/var/folders/71/2631nvxx07vdw464x8y8nk4h0000gn/T/RtmpawzKRm/model-2af011d2f2d4.stan', line 75, column 4 to column 50)
#> Chain 1 If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,
#> Chain 1 but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.
#> Chain 1
#> Chain 1 finished in 0.5 seconds.
#> Chain 2 finished in 0.4 seconds.
#> Chain 3 finished in 0.5 seconds.
#> Chain 4 finished in 0.6 seconds.
#> 
#> All 4 chains finished successfully.
#> Mean chain execution time: 0.5 seconds.
#> Total execution time: 2.4 seconds.

MAP and MAC approach have to match one another:

## the current study is the last row of AS_cov in the joint (MAC) fit and
## the only row of the data in the MAP fit
fitted(fit_mac)[nrow(AS_cov),]
#>  Estimate Est.Error      Q2.5     Q97.5 
#> 10.120007  1.917051  6.602673 14.259834
fitted(fit_map)[1,]
#>  Estimate Est.Error      Q2.5     Q97.5 
#> 10.093584  1.839139  6.625749 13.926558

… which they do up to MCMC error. The univariate MAP discards the correlations and may not necessarily match the MAC result:

## the current study is the last row of AS_cov in the joint (MAC) fit and
## the only row of the data in the univariate MAP fit
fitted(fit_mac)[nrow(AS_cov),]
#>  Estimate Est.Error      Q2.5     Q97.5 
#> 10.120007  1.917051  6.602673 14.259834
fitted(fit_umap)[1,]
#>  Estimate Est.Error      Q2.5     Q97.5 
#> 10.130354  1.967907  6.460267 14.339082

Stan code generated for the multivariate normal mixture case:

## print the Stan program brms generated for the MAP mixture prior fit
cat(stancode(object=fit_map))
#> // generated with brms 2.19.0
#> functions {
#>   
#> real mixmvnorm_lpdf(vector y, vector w, vector[] m, matrix[] L) {
#>   int Nc = rows(w);
#>   vector[Nc] lp_mix;
#>   for(i in 1:Nc) {
#>      lp_mix[i] = multi_normal_cholesky_lpdf(y | m[i], L[i]);
#>   }
#>   return log_sum_exp(log(w) + lp_mix);
#> }
#> }
#> data {
#>   int<lower=1> N;  // total number of observations
#>   int Y[N];  // response variable
#>   int trials[N];  // number of trials
#>   int<lower=1> K;  // number of population-level effects
#>   matrix[N, K] X;  // population-level design matrix
#>   int prior_only;  // should the likelihood be ignored?
#>   int map_Nc;
#>   int map_p;
#>   vector[map_Nc] map_w;
#>   vector[map_p] map_m[map_Nc];
#>   matrix[map_p, map_p] map_sigma[map_Nc];
#> }
#> transformed data {
#>   matrix[map_p, map_p] map_sigma_L[map_Nc];
#> for (i in 1:map_Nc) {
#>     map_sigma_L[i] = cholesky_decompose(map_sigma[i]);
#> }
#>   print("Mixture prior map");
#> for(i in 1:map_Nc) {
#>   print("Component ", i, ": w     = ", map_w[i]);
#>   print("Component ", i, ": m     = ", map_m[i]);
#>   print("Component ", i, ": Sigma = ", map_sigma[i]);
#> }
#> }
#> parameters {
#>   vector[K] b;  // population-level effects
#> }
#> transformed parameters {
#>   real lprior = 0;  // prior contributions to the log posterior
#>   lprior += mixmvnorm_lpdf(b | map_w, map_m, map_sigma_L);
#> }
#> model {
#>   // likelihood including constants
#>   if (!prior_only) {
#>     // initialize linear predictor term
#>     vector[N] mu = rep_vector(0.0, N);
#>     mu += X * b;
#>     target += binomial_logit_lpmf(Y | trials, mu);
#>   }
#>   // priors including constants
#>   target += lprior;
#> }
#> generated quantities {
#> }

Created on 2023-07-24 with reprex v2.0.2

weberse2 commented 1 year ago

implemented in version 1.7-1 released to CRAN