openpharma / brms.mmrm

R package to run Bayesian MMRMs using {brms}
https://openpharma.github.io/brms.mmrm/

Simulation-based calibration #56

Closed. wlandau closed this issue 11 months ago.

wlandau commented 1 year ago

As part of #12, it would be good to run a simulation-based calibration study. Looking over the brm_simulate() function, I think it needs some work first:

As for simulation scenarios, I think it would be good to try:

Any other simulation scenarios come to mind?

wlandau commented 1 year ago

https://github.com/openpharma/brms.mmrm/issues/3#issuecomment-1698088230 is relevant and makes me think that every part of brm_simulate() needs a redesign before we work on SBC.

wlandau commented 1 year ago

I think I like brms::brm(sample_prior = "only") for brms.mmrm::brm_simulate() and custom R code for simulations in SBC. Before running the fully-scaled SBC pipeline, we could first compare the custom R code against brms::brm(sample_prior = "only") and check if they agree. That could eliminate potential bugs early on.
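Here is a minimal sketch of that comparison step. It assumes the custom draws and the `sample_prior = "only"` draws have both been collected into data frames with matching column names (on the brms side, for example, via `posterior::as_draws_df()`). The helper `compare_prior_draws()` is hypothetical and is not part of brms or brms.mmrm. It runs a two-sample Kolmogorov-Smirnov test per parameter, and the toy data below stand in for real draws.

```r
# Hypothetical helper: compare prior draws from custom R code against
# reference draws (e.g. from brms::brm(sample_prior = "only")) one
# parameter at a time with a two-sample Kolmogorov-Smirnov test.
compare_prior_draws <- function(custom, reference) {
  shared <- intersect(names(custom), names(reference))
  stopifnot(length(shared) > 0L)
  p_values <- vapply(
    shared,
    function(name) ks.test(custom[[name]], reference[[name]])$p.value,
    numeric(1)
  )
  # Small p-values flag parameters whose two sets of draws disagree.
  data.frame(parameter = shared, p_value = unname(p_values))
}

# Toy usage: both sources draw the intercept from normal(0, 1), but the
# "custom" sigma coefficient mistakenly uses sd = 2.
set.seed(1)
custom <- data.frame(
  b_Intercept = rnorm(4000),
  b_sigma_time = rnorm(4000, sd = 2)
)
reference <- data.frame(
  b_Intercept = rnorm(4000),
  b_sigma_time = rnorm(4000)
)
comparison <- compare_prior_draws(custom, reference)
print(comparison)
```

A distribution-level test like this only catches marginal disagreements, so it works best as a quick screen before the full SBC pipeline.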

wlandau commented 1 year ago

I have made progress on this in the targets pipeline in https://github.com/openpharma/brms.mmrm/tree/56/sbc. The SBC rank statistics appear uniform for each scalar parameter, which is good. Going forward, I need to:

  1. Debug a rank deficiency issue in one of the simulation reps.
  2. Use custom R code to simulate data. (Currently we are simulating with brms and analyzing with brms, which is a good internal consistency check, but ultimately too circular for validation purposes.)
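For context, here is a minimal base-R sketch of what an SBC rank statistic is. It uses a conjugate normal model so that exact posterior draws can replace MCMC. The rank counts how many posterior draws fall below the true value that generated the data. When the simulator and the posterior agree, ranks are uniform over 0 to L, where L is the number of draws. This illustrates the idea only. It is not the implementation inside `SBC::calculate_ranks_draws_matrix()`, and the model and sizes are assumptions chosen for illustration.

```r
# Rank statistic: how many posterior draws fall below the true value.
sbc_rank <- function(truth, draws) sum(draws < truth)

# One SBC replication for a conjugate model:
#   mu ~ normal(0, 1) and y[i] ~ normal(mu, 1) for i = 1..n_obs.
# The posterior of mu is normal with precision 1 + n_obs and mean
# sum(y) / (1 + n_obs), so exact posterior draws stand in for MCMC.
one_replication <- function(n_obs = 20L, n_draws = 99L) {
  mu <- rnorm(1L, mean = 0, sd = 1)      # draw the truth from the prior
  y <- rnorm(n_obs, mean = mu, sd = 1)   # simulate data given the truth
  post_mean <- sum(y) / (1 + n_obs)
  post_sd <- 1 / sqrt(1 + n_obs)
  draws <- rnorm(n_draws, mean = post_mean, sd = post_sd)
  sbc_rank(mu, draws)
}

set.seed(2023)
ranks <- replicate(2000L, one_replication())

# With 99 draws, ranks take values 0..99. If calibration holds, each of
# 10 bins of 10 values should hold roughly 200 replications.
counts <- table(factor(ranks %/% 10, levels = 0:9))
counts
```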
wlandau commented 1 year ago

I have made progress on simulation-based calibration. I implemented a fully cloud-native targets pipeline (running with https://wlandau.github.io/crew.aws.batch/ and https://books.ropensci.org/targets/cloud-storage.html) which runs the simulations and writes the SBC rank statistics to small compressed files in the vignettes folder. Then the sbc.Rmd package vignette shows the results. The pipeline is at https://github.com/openpharma/brms.mmrm/tree/56/sbc, and the vignette is at https://github.com/openpharma/brms.mmrm/blob/56/vignettes/sbc.Rmd. Here is a rendered copy: sbc.pdf

When I simulate data using brms and analyze it using brms, calibration looks fantastic. But when I write custom code to simulate from the model, calibration is terrible. The rendered vignette shows all this.

I think the next step is to use much simpler models. On my next attempt, I plan to use an intercept-only model, where the intercept is virtually a point mass (low variance, strict lower and upper bounds).

I am about to leave for a two-week vacation, and I will return on November 6. I will resume work on this sometime after I get back.

wlandau commented 1 year ago

I have a simplified intercept-only model ready to test when I return: https://github.com/openpharma/brms.mmrm/tree/56-simple/sbc. After running it, we will know whether the problem comes from fixed effects or from the variances/correlations.

wlandau commented 1 year ago

Just finished running the pipeline at https://github.com/openpharma/brms.mmrm/tree/c5fd72f32922f63bf693d1851b0732072466e963/sbc. The intercept looks calibrated:

[Screenshot, 2023-11-07: SBC ranks for the intercept]

But the standard deviation parameters are way off:

[Screenshot, 2023-11-07: SBC ranks for the standard deviation parameters]

So are the correlations:

[Screenshot, 2023-11-07: SBC ranks for the correlations]

Maybe I can isolate it further if I try a simpler correlation structure. For example, it could easily be a problem with simulating from the LKJ distribution: maybe Stan disagrees with trialr::rlkjcorr().
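One way to test that suspicion without Stan is to use a known property of the LKJ(eta) distribution over K x K correlation matrices. Each off-diagonal correlation r has the marginal (r + 1) / 2 ~ Beta(eta - 1 + K/2, eta - 1 + K/2). The sketch below compares simulated draws against that marginal with Kolmogorov-Smirnov tests. The helper names are hypothetical. The sketch assumes the draws arrive as an n x K x K array, which is how I understand `trialr::rlkjcorr(n, K, eta)` returns them when n > 1; check that against the package documentation before relying on it.

```r
# Marginal CDF of one off-diagonal element of an LKJ(eta) correlation
# matrix of dimension K: (r + 1) / 2 ~ Beta(b, b) with b = eta - 1 + K / 2.
lkj_marginal_cdf <- function(r, K, eta) {
  b <- eta - 1 + K / 2
  pbeta((r + 1) / 2, shape1 = b, shape2 = b)
}

# Kolmogorov-Smirnov p-value for every off-diagonal element of an
# n x K x K array of simulated correlation matrices.
check_lkj_marginals <- function(draws, eta) {
  K <- dim(draws)[2]
  pairs <- which(upper.tri(diag(K)), arr.ind = TRUE)
  p_values <- apply(pairs, 1, function(ij) {
    ks.test(draws[, ij[1], ij[2]], lkj_marginal_cdf, K = K, eta = eta)$p.value
  })
  names(p_values) <- sprintf("r_%d_%d", pairs[, 1], pairs[, 2])
  p_values
}

# Toy check with K = 2, where LKJ(eta) reduces to (r + 1) / 2 ~ Beta(eta, eta).
make_draws <- function(r) {
  draws <- array(1, dim = c(length(r), 2L, 2L))
  draws[, 1, 2] <- r
  draws[, 2, 1] <- r
  draws
}
set.seed(3)
n <- 4000L
good <- make_draws(2 * rbeta(n, 3, 3) - 1)  # correct for eta = 3
bad <- make_draws(2 * rbeta(n, 1, 1) - 1)   # uniform, wrong for eta = 3
check_lkj_marginals(good, eta = 3)
check_lkj_marginals(bad, eta = 3)
```

For the model in this issue, with K = 4 and eta = 1, each correlation should marginally follow 2 * Beta(2, 2) - 1, which has mean 0 and variance 0.2. Something like `check_lkj_marginals(trialr::rlkjcorr(n = 4000, K = 4, eta = 1), eta = 1)` should therefore not give systematically small p-values.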

wlandau commented 1 year ago

I took out the correlation matrix entirely:

> formula
response ~ 1 
sigma ~ 0 + time
> prior
        prior     class       coef group resp  dpar nlpar   lb   ub source
 normal(0, 1) Intercept                                   <NA> <NA>   user
 normal(0, 1)         b timetime_1            sigma       <NA> <NA>   user
 normal(0, 1)         b timetime_2            sigma       <NA> <NA>   user
 normal(0, 1)         b timetime_3            sigma       <NA> <NA>   user
 normal(0, 1)         b timetime_4            sigma       <NA> <NA>   user
> brms::make_stancode(formula = formula, data = data, prior = prior)
// generated with brms 2.20.4
functions {

}
data {
  int<lower=1> N; // total number of observations
  vector[N] Y; // response variable
  int<lower=1> K_sigma; // number of population-level effects
  matrix[N, K_sigma] X_sigma; // population-level design matrix
  int prior_only; // should the likelihood be ignored?
}
transformed data {

}
parameters {
  real Intercept; // temporary intercept for centered predictors
  vector[K_sigma] b_sigma; // regression coefficients
}
transformed parameters {
  real lprior = 0; // prior contributions to the log posterior
  lprior += normal_lpdf(Intercept | 0, 1);
  lprior += normal_lpdf(b_sigma[1] | 0, 1);
  lprior += normal_lpdf(b_sigma[2] | 0, 1);
  lprior += normal_lpdf(b_sigma[3] | 0, 1);
  lprior += normal_lpdf(b_sigma[4] | 0, 1);
}
model {
  // likelihood including constants
  if (!prior_only) {
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    // initialize linear predictor term
    vector[N] sigma = rep_vector(0.0, N);
    mu += Intercept;
    sigma += X_sigma * b_sigma;
    sigma = exp(sigma);
    target += normal_lpdf(Y | mu, sigma);
  }
  // priors including constants
  target += lprior;
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = Intercept;
}

And it looks like the intercept is calibrated fine:

[Screenshot, 2023-11-08: SBC ranks for the intercept]

but the standard deviations are way off:

[Screenshot, 2023-11-08: SBC ranks for the standard deviation parameters]

Glad to isolate this down. Still a mystery.
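For reference, here is a minimal sketch of custom R code that simulates one dataset from the prior predictive distribution of the model above. The model has intercept ~ normal(0, 1), one sigma coefficient per time point ~ normal(0, 1), and no correlations. The function name and dataset sizes are assumptions for illustration, not the code in the branch. Because brms models sigma on the log scale by default, the standard deviation at each time point is exp(b_sigma).

```r
# Hypothetical simulator for: response ~ 1, sigma ~ 0 + time,
# with normal(0, 1) priors on the intercept and on every b_sigma coefficient.
simulate_prior_predictive <- function(n_patient = 100L, n_time = 4L) {
  intercept <- rnorm(1L, mean = 0, sd = 1)
  b_sigma <- rnorm(n_time, mean = 0, sd = 1)
  names(b_sigma) <- paste0("b_sigma_timetime_", seq_len(n_time))
  data <- expand.grid(
    time = paste0("time_", seq_len(n_time)),
    patient = paste0("patient_", seq_len(n_patient)),
    stringsAsFactors = FALSE
  )
  # Look up each row's coefficient by its time label rather than by position.
  time_index <- match(data$time, paste0("time_", seq_len(n_time)))
  sigma <- exp(b_sigma[time_index])  # brms uses a log link for sigma
  data$response <- rnorm(nrow(data), mean = intercept, sd = sigma)
  list(data = data, parameters = c(b_Intercept = intercept, b_sigma))
}

set.seed(4)
simulation <- simulate_prior_predictive()
str(simulation$parameters)
```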

wlandau commented 1 year ago

This is interesting: when I plot posterior medians against the truth, here is what I see:

[Screenshot, 2023-11-14: posterior medians versus true values]

The Stan code (edited above) does not claim to bound b_sigma from below...

wlandau commented 1 year ago

When I sample from the prior, I get standard normal draws.

model <- brms::brm(data = data, formula = formula, prior = prior, sample_prior = "only")
draws <- posterior::as_draws_df(model)
summary(draws$b_sigma_timetime_1)
#>      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
#> -3.731520 -0.661535  0.007676  0.027759  0.696233  3.334310 
wlandau commented 1 year ago

Also: convergence looks good, and I got no warnings about divergent transitions after warmup.

In practice, it may make little difference whether b_sigma is -2 or -3, because the corresponding standard deviations exp(-2) ≈ 0.135 and exp(-3) ≈ 0.050 are both very small. Maybe that explains it? Maybe we should only simulate over a more sensible version of the prior?
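To put numbers on "a more sensible version of the prior": under a normal(0, 1) prior on b_sigma, the implied prior on sigma = exp(b_sigma) is lognormal. It has a median of 1 and a central 95% interval of roughly 0.14 to 7.1. The quick base-R check below exponentiates normal quantiles and cross-checks them against `qlnorm()`.

```r
# Implied prior quantiles of sigma = exp(b_sigma) when b_sigma ~ normal(0, 1).
probs <- c(0.025, 0.5, 0.975)
sigma_quantiles <- exp(qnorm(probs))
round(sigma_quantiles, 3)  # about 0.141, 1, and 7.099

# Same thing via the lognormal distribution, as a cross-check.
all.equal(sigma_quantiles, qlnorm(probs, meanlog = 0, sdlog = 1))

# The two values from the comment: both are small standard deviations.
round(exp(c(-2, -3)), 3)  # about 0.135 and 0.050
```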

I think the next step is to simplify this example even further and bring it to the Stan discourse.

wlandau commented 1 year ago

I think I simplified it down too far. This works as it should:

# Simulate a single dataset from the prior predictive distribution,
# run the model on the dataset,
# and calculate simulation-based calibration ranks.
simulate_once <- function(rep) {
  # Define a simple model.
  formula <- brms::brmsformula(y ~ 1, sigma ~ 0 + time)
  prior <- brms::set_prior(prior = "normal(0, 1)", class = "Intercept") +
    brms::set_prior(prior = "normal(0, 1)", class = "b", dpar = "sigma")
  # Simulate a dataset from the prior predictive distribution of the model.
  intercept <- rnorm(n = 1)
  b_sigma <- rnorm(n = 1)
  sigma <- exp(b_sigma)
  y <- rnorm(n = 100, mean = intercept, sd = sigma)
  data <- tibble::tibble(y = y, time = 1L)
  # Run the model.
  model <- brms::brm(formula = formula, data = data, prior = prior)
  # Get SBC ranks and other summaries.
  names <- c("b_Intercept", "b_sigma_time")
  draws <- posterior::as_draws_matrix(model)[, names]
  truth <- c(b_Intercept = intercept, b_sigma_time = b_sigma)[names]
  ranks <- SBC::calculate_ranks_draws_matrix(variables = truth, dm = draws)
  summary <- posterior::summarize_draws(draws)
  tibble::tibble(summary, truth = truth, rank = ranks, rep = rep)
}

# Run 1000 simulations using {crew} for parallel computing.
controller <- crew::crew_controller_local(workers = 4)
controller$start()
tasks <- controller$map(
  command = simulate_once(rep),
  iterate = list(rep = seq_len(1000)),
  data = list(simulate_once = simulate_once)
)
controller$terminate()
results <- dplyr::bind_rows(tasks$result)

# Plot the SBC ranks.
library(dplyr)
library(ggplot2)
plot_ranks <- function(results, parameter) {
  subset <- results %>%
    filter(variable == parameter)
  ggplot(subset) +
    geom_histogram(
      mapping = aes(x = rank),
      breaks = seq(from = 0, to = max(subset$rank), length.out = 10)
    ) +
    theme_gray(16)
}
plot_ranks(results, "b_Intercept")

[Screenshot, 2023-11-14: SBC ranks for b_Intercept]

plot_ranks(results, "b_sigma_time")

[Screenshot, 2023-11-14: SBC ranks for b_sigma_time]
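Besides eyeballing rank histograms, you can put a rough number on uniformity. The sketch below assumes `ranks` is a vector of SBC ranks computed from `n_draws` posterior draws, so the ranks lie in 0..n_draws. It bins the ranks into equal-width bins and runs a chi-squared goodness-of-fit test, with each bin's expected probability set by how many possible rank values it holds. The helper name is hypothetical, and the test is a coarse screen rather than a replacement for the plots.

```r
# Hypothetical helper: chi-squared test of SBC rank uniformity.
# `ranks` come from `n_draws` posterior draws, so they lie in 0..n_draws.
rank_uniformity_test <- function(ranks, n_draws, n_bins = 10L) {
  stopifnot(all(ranks >= 0 & ranks <= n_draws))
  # Assign every possible rank value 0..n_draws to one of n_bins bins.
  values <- 0:n_draws
  bin_of_value <- cut(values, breaks = n_bins, labels = FALSE)
  # Expected bin probability = share of possible rank values in the bin.
  expected <- tabulate(bin_of_value, nbins = n_bins) / length(values)
  observed <- tabulate(bin_of_value[ranks + 1], nbins = n_bins)
  chisq.test(observed, p = expected, rescale.p = TRUE)
}

set.seed(5)
n_draws <- 8000  # e.g. 4 chains x 2000 post-warmup iterations
calibrated <- sample(0:n_draws, size = 1000, replace = TRUE)
miscalibrated <- sample(c(0, n_draws), size = 1000, replace = TRUE)
rank_uniformity_test(calibrated, n_draws)$p.value
rank_uniformity_test(miscalibrated, n_draws)$p.value
```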

wlandau commented 1 year ago

It turns out the bug in https://github.com/openpharma/brms.mmrm/issues/56#issuecomment-1802625388 was in how my simulation code handled the intercept parameters. I will see if there is any way to increase the complexity of the models back to what it was.

wlandau commented 1 year ago

Rerunning with the full correlation matrix and only the intercept for fixed effects.

wlandau commented 1 year ago

Odd... the results in https://github.com/openpharma/brms.mmrm/commit/10d4af504cb5495b7bad0d2a1bd54fc5952bebf0 seem to be well calibrated. There, the model has only an intercept for fixed effects, and it has the full unstructured covariance matrix. So I guess I should try to figure out what is happening with the more complicated fixed effects in the full model.

wlandau commented 1 year ago

Now trying to gradually add back fixed effects in https://github.com/openpharma/brms.mmrm/tree/56-fixed-effects.

wlandau commented 1 year ago

But then removing the covariance matrix makes the fixed effects calibrated again...

I don't know where, but I am convinced there is a bug in my simulation code. I tried to isolate it with simpler models, but that doesn't seem to be working. Instead, I should just discard the current simulation code and start from scratch using a completely different approach.

wlandau commented 11 months ago

I attempted another reproducible example. Fortunately, I was able to make it simple enough to share while still reproducing the poor calibration I am seeing. It's odd: the response ~ 0 + group + unstr(time = time, gr = patient) model is poorly calibrated, but the calibration of response ~ unstr(time = time, gr = patient) and of response ~ 0 + group looks just fine. I am currently working primarily in branch 56-reprex. Code:

library(brms)
library(dplyr)
library(tibble)
library(tidyr)

#############
# FUNCTIONS #
#############

one_sbc_replication <- function(
  chains = 4L,
  warmup = 2000L,
  iter = 4000L
) {
  prior <- set_prior("lkj_corr_cholesky(1)", class = "Lcortime") +
     set_prior("normal(0, 1)", class = "b", coef = "groupgroup_1") +
     set_prior("normal(0, 1)", class = "b", coef = "groupgroup_2") +
     set_prior("normal(0, 1)", class = "b", coef = "groupgroup_3") +
     set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_1") +
     set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_2") +
     set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_3") +
     set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_4")
  formula <- brmsformula(
    formula = response ~ 0 + group + unstr(time = time, gr = patient),
    sigma ~ 0 + time
  )
  simulation <- simulate_data(formula = formula, prior = prior)
  options(brms.backend = "rstan")
  model <- brm(
    data = simulation$data,
    formula = formula,
    prior = prior,
    chains = chains,
    cores = chains,
    iter = iter,
    warmup = warmup
  )
  get_sbc_ranks(model, simulation)
}

simulate_data <- function(formula, prior) {
  n_group <- 3L
  n_patient <- 100L
  n_time <- 4L
  patients <- tibble(
    group = paste0("group_", rep(seq_len(n_group), each = n_patient)),
    patient = paste0("patient_", seq_len(n_group * n_patient))
  )
  data <- expand_grid(patients, time = paste0("time_", seq_len(n_time)))
  data$response <- 0
  x <- make_standata(formula, data, prior = prior)$X
  beta <- rnorm(n = n_group, mean = 0, sd = 1)
  names(beta) <- paste0("b_", colnames(x))
  b_sigma <- rnorm(n = n_time, mean = 0, sd = 1)
  names(b_sigma) <- paste0("b_sigma_timetime_", seq_len(n_time))
  sigma <- exp(b_sigma)
  correlation <- trialr::rlkjcorr(n = 1L, K = n_time, eta = 1)
  i <- rep(seq_len(n_time), each = n_time)
  j <- rep(seq_len(n_time), times = n_time)
  cortime <- as.numeric(correlation)[j > i]
  names(cortime) <- sprintf("cortime__time_%s__time_%s", i[j > i], j[j > i])
  covariance <- diag(sigma) %*% correlation %*% diag(sigma)
  data <- data |>
    mutate(mu = as.numeric(x %*% beta)) |>
    mutate(index_patient = rep(seq_len(n_patient * n_group), each = n_time)) |>
    group_by(index_patient) |>
    mutate(response = MASS::mvrnorm(mu = mu, Sigma = covariance)) |>
    ungroup() |>
    select(-index_patient, -mu)
  parameters <- c(beta, b_sigma, cortime)
  stopifnot(!anyDuplicated(names(parameters)))
  list(data = data, parameters = parameters)
}

get_sbc_ranks <- function(model, simulation) {
  draws <- posterior::as_draws_matrix(model)
  draws <- draws[, setdiff(colnames(draws), c("lprior", "lp__"))]
  truth <- simulation$parameters
  stopifnot(all(sort(names(truth)) == sort(colnames(draws))))
  draws <- draws[, names(truth)]
  ranks <- SBC::calculate_ranks_draws_matrix(variables = truth, dm = draws)
  tibble::as_tibble(as.list(ranks))
}

##############
# SIMULATION #
##############

# I used an SGE cluster, so this is how I set up the crew controller:
controller <- crew.cluster::crew_controller_sge(
  name = "brms-mmrm-sbc",
  workers = 100L,
  seconds_idle = 30,
  seconds_launch = 1800,
  launch_max = 3L,
  script_lines = "module load R/4.2.2",
  sge_cores = 4L
)

# But if you have different resources, you may want to choose
# a different crew launcher plugin, e.g.:
# controller <- crew::crew_controller_local()

# Run the simulations:
controller$start()
tasks <- controller$map(
  command = one_sbc_replication(chains = 4L, warmup = 2000L, iter = 4000L),
  iterate = list(index = seq_len(100L)),
  globals = as.list(globalenv()),
  packages = c("brms", "dplyr", "tibble", "tidyr")
)
controller$terminate()
simulations <- bind_rows(tasks$result)

###########
# RESULTS #
###########

library(tidyr)
results <- pivot_longer(
  simulations,
  cols = everything(),
  names_to = "parameter",
  values_to = "rank"
)

library(ggplot2)
plot <- ggplot(results) +
  geom_histogram(
    aes(x = rank),
    breaks = seq(from = 0, to = max(results$rank), length.out = 10)
  ) +
  facet_wrap(~parameter) +
  theme_gray(16)
ggsave("plot.png", plot, width = 12, height = 10)

SBC rank histograms:

[Image: SBC rank histograms from the reprex]

wlandau commented 11 months ago

Just posted this reprex to https://discourse.mc-stan.org/t/trouble-validating-a-bayesian-mmrm-implemented-with-brms/33564.

wlandau commented 11 months ago

Andrew Johnson's reply on the Stan discourse made me think I should look at a single simulated dataset. I picked a seed which gave terrible rank statistics in the simulation, and I plotted the marginal posteriors (red) against the true parameters (blue). It looks like the treatment group labels in the data are getting switched around. I see this same pattern for multiple seeds. Is there a reason brms might reorder character labels?

[Screenshot, 2023-12-13: group effect estimates (red) versus true parameter values (blue)]

Relative to the parameters, it looks like the model wants to move group 1 to group 3, group 2 to group 1, and group 3 to group 2. Does anyone with more brms experience know why this permutation is happening?

Here is a quick reprex. It runs in about 1.3 minutes on my local machine, and the built-in convergence diagnostics look great.

library(brms)
library(dplyr)
library(ggplot2)
library(posterior)
library(tibble)
library(tidyr)

# Define the model.
prior <- set_prior("lkj_corr_cholesky(1)", class = "Lcortime") +
  set_prior("normal(0, 1)", class = "b", coef = "groupgroup_1") +
  set_prior("normal(0, 1)", class = "b", coef = "groupgroup_2") +
  set_prior("normal(0, 1)", class = "b", coef = "groupgroup_3") +
  set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_1") +
  set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_2") +
  set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_3") +
  set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_4")
formula <- brmsformula(
  formula = response ~ 0 + group + unstr(time = time, gr = patient),
  sigma ~ 0 + time
)

# Simulate a dataset from the prior.
set.seed(seed = 8L, kind = "Mersenne-Twister")
n_group <- 3L
n_patient <- 100L
n_time <- 4L
patients <- tibble(
  group = paste0("group_", rep(seq_len(n_group), each = n_patient)),
  patient = paste0("patient_", seq_len(n_group * n_patient))
)
data <- expand_grid(patients, time = paste0("time_", seq_len(n_time)))
data$response <- 0
x <- make_standata(formula, data, prior = prior)$X
beta <- rnorm(n = n_group, mean = 0, sd = 1)
names(beta) <- paste0("b_", colnames(x))
b_sigma <- rnorm(n = n_time, mean = 0, sd = 1)
names(b_sigma) <- paste0("b_sigma_timetime_", seq_len(n_time))
sigma <- exp(b_sigma)
correlation <- trialr::rlkjcorr(n = 1L, K = n_time, eta = 1)
i <- rep(seq_len(n_time), each = n_time)
j <- rep(seq_len(n_time), times = n_time)
cortime <- as.numeric(correlation)[j > i]
names(cortime) <- sprintf("cortime__time_%s__time_%s", i[j > i], j[j > i])
covariance <- diag(sigma) %*% correlation %*% diag(sigma)
data <- data |>
  mutate(mu = as.numeric(x %*% beta)) |>
  mutate(index_patient = rep(seq_len(n_patient * n_group), each = n_time)) |>
  group_by(index_patient) |>
  mutate(response = MASS::mvrnorm(mu = mu, Sigma = covariance)) |>
  ungroup() |>
  select(-index_patient, -mu)
parameters <- c(beta, b_sigma, cortime)

# Run the model.
options(brms.backend = "rstan")
model <- brm(
  data = data,
  formula = formula,
  prior = prior,
  seed = 1L,
  chains = 4L,
  cores = 4L,
  iter = 4000L,
  warmup = 2000L,
  refresh = 10L
)

# Visualize the fixed effect marginal posteriors against the data.
summary_model <- summarize_draws(model)
summary_fixed_model <- summary_model |>
  select(variable, mean, q5, q95) |>
  filter(grepl("group", variable))
z <- qnorm(p = 0.95) # mean +/- z * se spans the 5% to 95% range, matching q5 and q95
summary_fixed_data <- data |>
  group_by(group) |>
  summarize(
    mean = mean(response),
    q5 = mean - z * sd(response) / sqrt(n()),
    q95 = mean + z * sd(response) / sqrt(n()),
    .groups = "drop"
  ) |>
  rename(variable = group) |>
  mutate(variable = paste0("b_group", variable))
summary_fixed <- dplyr::bind_rows(
  model = summary_fixed_model,
  data = summary_fixed_data,
  .id = "source"
)
summary_parameters <- tibble(
  variable = names(parameters),
  value = unname(parameters)
)
summary_parameters_fixed <- summary_parameters |>
  filter(grepl("group", variable))

ggplot(summary_fixed_data) +
  geom_point(
    aes(x = variable, y = mean),
    color = "red",
    position = position_dodge(width = 0.5)
  ) +
  geom_errorbar(
    aes(x = variable, ymin = q5, ymax = q95),
    color = "red",
    position = position_dodge(width = 0.5)
  ) +
  geom_point(
    data = summary_parameters_fixed,
    mapping = aes(x = variable, y = value),
    color = "blue",
    position = position_dodge(width = 0.5)
  ) +
  ylab("value") +
  theme_gray(24)
wlandau commented 11 months ago

I think I figured it out: make_standata() was giving me a model matrix whose rows do not match the original data. When I recovered the original row order, the results started making sense. I think this should fix the SBC study. Details are at https://discourse.mc-stan.org/t/trouble-validating-a-bayesian-mmrm-implemented-with-brms/33564/5.
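Here is a minimal base-R sketch of why misaligned rows show up as permuted group labels, and of one way to sidestep the problem: build the design matrix from the simulated data itself with `model.matrix()`, so that its rows match `data` by construction. This is an illustrative alternative and not necessarily the fix described on the Stan forum. The reordering step is an arbitrary rotation of the group order chosen only for illustration; it is an assumption and does not reproduce brms's actual internal ordering.

```r
# Toy data laid out like the reprex: one row per patient and time point.
n_group <- 3L
n_patient <- 4L
n_time <- 2L
data <- data.frame(
  group = rep(paste0("group_", seq_len(n_group)), each = n_patient * n_time),
  patient = rep(paste0("patient_", seq_len(n_group * n_patient)), each = n_time),
  time = rep(paste0("time_", seq_len(n_time)), times = n_group * n_patient)
)
beta <- c(groupgroup_1 = -1, groupgroup_2 = 0, groupgroup_3 = 1)

# Misaligned: the design matrix comes from a reordered copy of the data
# (an arbitrary rotation of the group order, purely for illustration)
# but is multiplied into the rows of the original data.
rotation <- factor(data$group, levels = c("group_2", "group_3", "group_1"))
reordered <- data[order(rotation), ]
x_wrong <- model.matrix(~ 0 + group, data = reordered)
mu_wrong <- as.numeric(x_wrong %*% beta[colnames(x_wrong)])
tapply(mu_wrong, data$group, mean)  # group_1: 0, group_2: 1, group_3: -1

# Aligned: build the design matrix from `data` itself, so row k of `x`
# describes row k of `data`.
x <- model.matrix(~ 0 + group, data = data)
mu <- as.numeric(x %*% beta[colnames(x)])
tapply(mu, data$group, mean)  # group_1: -1, group_2: 0, group_3: 1
```

In the misaligned case each true group effect lands under another group's label, which matches the kind of label permutation seen in the posterior plot.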

wlandau commented 11 months ago

So the row reordering definitely fixes https://github.com/openpharma/brms.mmrm/issues/56#issuecomment-1852845276, but unfortunately the original SBC study appears poorly calibrated. The standard deviations and correlations look great, but the fixed effects are still way off (all ranks either equal to 0 or the number of posterior draws). Looks like another R code error.
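Ranks piled up at exactly 0 or at the maximum mean that the true value falls outside every posterior draw in that replication. That pattern points to a systematic mismatch between simulator and model, such as a mislabeled parameter, rather than mild miscalibration. The hypothetical helper below flags it per parameter. It assumes a data frame with one column of ranks per parameter, like the `simulations` object in the reprex, and 8000 posterior draws (4 chains of 2000 post-warmup iterations). The toy ranks are made up.

```r
# Hypothetical helper: fraction of replications whose rank sits at either
# extreme (0 or n_draws) for each parameter. Under calibration this should
# be close to 2 / (n_draws + 1).
extreme_rank_fraction <- function(ranks, n_draws) {
  vapply(ranks, function(r) mean(r == 0 | r == n_draws), numeric(1))
}

set.seed(7)
n_draws <- 8000
ranks <- data.frame(
  b_sigma_timetime_1 = sample(0:n_draws, 100, replace = TRUE),
  b_groupgroup_1 = sample(c(0, n_draws), 100, replace = TRUE)
)
round(extreme_rank_fraction(ranks, n_draws), 3)
```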

wlandau commented 11 months ago

Solved it! In the data simulation code, I just needed to make the row order in the brms prior match the columns of the model matrix from brms::make_standata(). PR forthcoming.
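Here is a minimal sketch of that kind of alignment, assuming the simulated coefficients are stored in a named vector. The idea is to index the vector by the model matrix's column names before multiplying, so that an ordering mismatch cannot silently permute effects. The coefficient values and the toy `x` matrix are made up for illustration. In the reprex, `x` would come from `make_standata()` and the names would be the `b_`-prefixed parameter names.

```r
# Coefficients named as in the brms output, deliberately listed in an order
# that differs from the model matrix columns.
beta <- c(b_groupgroup_3 = 1, b_groupgroup_1 = -1, b_groupgroup_2 = 0)

# Toy model matrix with one indicator column per group.
group <- rep(c("group_1", "group_2", "group_3"), each = 2)
x <- model.matrix(~ 0 + group, data = data.frame(group = group))

# Positional multiplication silently pairs column k with coefficient k.
mu_positional <- as.numeric(x %*% beta)

# Name-based lookup pairs each column with the coefficient of the same name.
beta_aligned <- beta[paste0("b_", colnames(x))]
stopifnot(!anyNA(beta_aligned))  # fail loudly if a name is missing
mu_aligned <- as.numeric(x %*% beta_aligned)

tapply(mu_positional, group, mean)  # wrong: group_1 gets the group_3 value
tapply(mu_aligned, group, mean)     # right: -1, 0, 1
```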

wlandau commented 11 months ago

SBC on the non-subgroup model is complete and successful, and the results are in a new vignette at https://openpharma.github.io/brms.mmrm/articles/sbc.html.