paul-buerkner / brms

brms R package for Bayesian generalized multivariate non-linear multilevel models using Stan
https://paul-buerkner.github.io/brms/
GNU General Public License v2.0
1.27k stars 182 forks source link

Possible bug in brms/emmeans integration #1654

Closed wlandau closed 2 months ago

wlandau commented 4 months ago

Related: #1630, https://discourse.mc-stan.org/t/trouble-with-brms-emmeans-integration/34664. I am posting here because I think the issue might be a bug in brms, and the comment section in my Stan Discourse post has not been active.

brms integrates with emmeans for marginal mean calculations, but the results seem off. The reprex below uses the mmrm package's FEV1 dataset, a simulation of a clinical trial with treatment groups in ARMCD and discrete time points for repeated measures in AVISIT. The example compares 4 different methods of estimating marginal means for each combination of ARMCD and AVISIT:

  1. Data summaries: compute means and independent frequentist 95% confidence intervals on the raw data.
  2. lm() + emmeans: fit a model with lm() and get marginal means with emmeans.
  3. brms + custom: fit a model with brms and use a custom linear transformation to map model parameters to marginal means.
  4. brms + emmeans: use the native brms/emmeans integration to estimate marginal means from the fitted brms model.

There is reasonable agreement among approaches (1), (2), and (3), and approach (4) gives very different results from all the others. I ran the following on the current development version of brms in the master branch (https://github.com/paul-buerkner/brms/commit/298b947fa9cfb914aeb7cb3aab7974aa179682b1)

# Reprex: compare four estimates of the marginal mean of FEV1_CHG in each
# ARMCD x AVISIT cell -- raw-data summaries, lm() + emmeans, brms + a custom
# linear mapping of the coefficients, and brms + emmeans.
suppressPackageStartupMessages({
  library(brms)
  library(coda)
  library(emmeans)
  library(mmrm)
  library(posterior)
  library(tidyverse)
  library(zoo)
})
# Separator emmeans uses when labelling combinations of factor levels.
emm_options(sep = "|")

packageDescription("brms")$GithubSHA1
#> [1] "298b947fa9cfb914aeb7cb3aab7974aa179682b1"

# FEV data from the mmrm package, using LOCF and then LOCF reversed
# to impute responses. (For this discussion, it is helpful to avoid
# the topic of missingness.)
data(fev_data, package = "mmrm")
data <- fev_data %>%
  mutate(FEV1_CHG = FEV1 - FEV1_BL, USUBJID = as.character(USUBJID)) %>%
  select(-FEV1) %>%
  # One row for every subject x visit combination.
  complete(USUBJID, AVISIT) %>%
  arrange(USUBJID, AVISIT) %>%
  group_by(USUBJID) %>%
  # Copy each subject's OWN baseline covariates into any added rows. A
  # `fill = as.list(.[1L, ...])` argument to complete() would not do this:
  # there `.` is the whole piped data set, so every added row would get the
  # covariates of the first row of the data.
  fill(ARMCD, FEV1_BL, RACE, SEX, WEIGHT, .direction = "downup") %>%
  mutate(FEV1_CHG = na.locf(FEV1_CHG, na.rm = FALSE)) %>%
  mutate(FEV1_CHG = na.locf(FEV1_CHG, na.rm = FALSE, fromLast = TRUE)) %>%
  ungroup() %>%
  filter(!is.na(FEV1_CHG))

# Approach 1: raw cell means with independent normal-approximation 95% CIs.
summary_data <- data %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(
    source = "1_data",
    mean = mean(FEV1_CHG),
    lower = mean(FEV1_CHG) - qnorm(0.975) * sd(FEV1_CHG) / sqrt(n()),
    upper = mean(FEV1_CHG) + qnorm(0.975) * sd(FEV1_CHG) / sqrt(n()),
    .groups = "drop"
  )

# Formula shared by all the models
formula <- FEV1_CHG ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT +
  AVISIT + RACE + SEX + WEIGHT

# Approach 2: lm with emmeans
model_lm <- lm(formula = formula, data = data)
summary_lm_emmeans <- emmeans(
  object = model_lm,
  specs = ~ARMCD:AVISIT,
  wt.nuis = "proportional",
  nuisance = c("USUBJID", "RACE", "SEX")
) %>%
  as.data.frame() %>%
  as_tibble() %>%
  select(ARMCD, AVISIT, emmean, lower.CL, upper.CL) %>%
  rename(mean = emmean, lower = lower.CL, upper = upper.CL) %>%
  mutate(source = "2_lm_emmeans")

# Approach 4: brms with emmeans (HPD intervals)
model_brms <- brm(data = data, formula = brmsformula(formula))
summary_brms_emmeans <- emmeans(
  object = model_brms,
  specs = ~ARMCD:AVISIT,
  wt.nuis = "proportional",
  nuisance = c("USUBJID", "RACE", "SEX")
) %>%
  as.data.frame() %>%
  as_tibble() %>%
  select(ARMCD, AVISIT, emmean, lower.HPD, upper.HPD) %>%
  rename(mean = emmean, lower = lower.HPD, upper = upper.HPD) %>%
  mutate(source = "4_brms_emmeans")

# Approach 3: custom marginal means from brms draws using a custom mapping
# from brms model parameters to marginal means. I would expect the
# emmeans/brms integration to agree with the results below
# (within rounding error + MCMC error), based on what I find with lm()
# (c.f. https://github.com/openpharma/brms.mmrm/issues/53)

# Column means of the SEX/RACE design columns: proportional weights for the
# nuisance factors.
proportional_factors <- brmsformula(FEV1_CHG ~ 0 + SEX + RACE) %>%
  make_standata(data = data) %>%
  .subset2("X") %>%
  colMeans() %>%
  t()
# One row per ARMCD x AVISIT cell, with continuous covariates at their means.
grid <- data %>%
  mutate(FEV1_BL = mean(FEV1_BL), FEV1_CHG = 0, WEIGHT = mean(WEIGHT)) %>%
  distinct(ARMCD, AVISIT, FEV1_BL, WEIGHT, FEV1_CHG)
# Posterior draws of the regression coefficients only (b_*, excluding any
# b_sigma* terms).
draws_parameters <- model_brms %>%
  as_draws_df() %>%
  as_tibble() %>%
  select(starts_with("b_"), -starts_with("b_sigma"))
# Linear map from coefficients to cell means: grid design matrix with the
# proportional SEX/RACE weights appended.
mapping <- brmsformula(
    FEV1_CHG ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT + AVISIT + WEIGHT
  ) %>%
  make_standata(data = grid) %>%
  .subset2("X") %>%
  bind_cols(proportional_factors) %>%
  setNames(paste0("b_", colnames(.)))
stopifnot(all(colnames(draws_parameters) %in% colnames(mapping)))
# Align the mapping's columns with the order of the parameter draws.
mapping <- as.matrix(mapping)[, colnames(draws_parameters), drop = FALSE]
rownames(mapping) <- paste(grid$ARMCD, grid$AVISIT, sep = "|")
draws_custom <- as.matrix(draws_parameters) %*% t(mapping) %>%
  as.data.frame() %>%
  as_tibble()
summary_brms_custom <- draws_custom %>%
  pivot_longer(everything()) %>%
  # Column names were built as "ARMCD|AVISIT" above.
  separate_wider_delim(name, delim = "|", names = c("ARMCD", "AVISIT")) %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(
    source = "3_brms_custom",
    mean = mean(value),
    lower = quantile(value, 0.025),
    upper = quantile(value, 0.975),
    .groups = "drop"
  )

# Compare results. Named summary_all so the base function summary() is not
# shadowed.
summary_all <- bind_rows(
  summary_data,
  summary_lm_emmeans,
  summary_brms_custom,
  summary_brms_emmeans
)
ggplot(summary_all) +
  geom_point(aes(x = source, y = mean, color = source)) +
  geom_errorbar(aes(x = source, ymin = lower, ymax = upper, color = source)) +
  facet_grid(ARMCD ~ AVISIT) +
  theme_gray(16) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1)) +
  ylab("FEV1_CHG")

Screenshot 2024-05-15 at 2 27 51 PM

# Print the R session details (R version, platform, loaded packages) for the
# environment that produced the results above.
sessionInfo()
#> R version 4.4.0 (2024-04-24)
#> Platform: aarch64-apple-darwin20
#> Running under: macOS Sonoma 14.5
#>
#> Matrix products: default
#> BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.0
#>
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> time zone: America/Indiana/Indianapolis
#> tzcode source: internal
#>
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods
#> [7] base
#>
#> other attached packages:
#>  [1] zoo_1.8-12       lubridate_1.9.3  forcats_1.0.0
#>  [4] stringr_1.5.1    dplyr_1.1.4      purrr_1.0.2
#>  [7] readr_2.1.5      tidyr_1.3.1      tibble_3.2.1
#> [10] ggplot2_3.5.1    tidyverse_2.0.0  posterior_1.5.0
#> [13] mmrm_0.3.11      emmeans_1.10.1   coda_0.19-4.1
#> [16] brms_2.21.3      Rcpp_1.0.12      abind_1.4-5
#> [19] drake_7.13.10    testthat_3.2.1.1
#>
#> loaded via a namespace (and not attached):
#>   [1] Rdpack_2.6           txtq_0.2.4           gridExtra_2.3
#>   [4] remotes_2.5.0        inline_0.3.19        rlang_1.1.3
#>   [7] magrittr_2.0.3       matrixStats_1.3.0    compiler_4.4.0
#>  [10] loo_2.7.0            callr_3.7.6          vctrs_0.6.5
#>  [13] profvis_0.3.8        pkgconfig_2.0.3      crayon_1.5.2
#>  [16] fastmap_1.1.1        backports_1.4.1      ellipsis_0.3.2
#>  [19] utf8_1.2.4           promises_1.3.0       tzdb_0.4.0
#>  [22] sessioninfo_1.2.2    ps_1.7.6             waldo_0.5.2
#>  [25] cachem_1.0.8         jsonlite_1.8.8       progress_1.2.3
#>  [28] later_1.3.2          parallel_4.4.0       prettyunits_1.2.0
#>  [31] R6_2.5.1             StanHeaders_2.32.7   stringi_1.8.4
#>  [34] parallelly_1.37.1    pkgload_1.3.4        estimability_1.5
#>  [37] brio_1.1.5           bindr_0.1.1          rstan_2.32.6
#>  [40] usethis_2.2.3        bayesplot_1.11.1     httpuv_1.6.15
#>  [43] Matrix_1.7-0         igraph_2.0.3         timechange_0.3.0
#>  [46] tidyselect_1.2.1     rstudioapi_0.16.0    codetools_0.2-20
#>  [49] miniUI_0.1.1.1       curl_5.2.1           processx_3.8.4
#>  [52] listenv_0.9.1        pkgbuild_1.4.4       lattice_0.22-6
#>  [55] shiny_1.8.1.1        withr_3.0.0          bridgesampling_1.1-2
#>  [58] future_1.33.2        desc_1.4.3           RcppParallel_5.1.7
#>  [61] urlchecker_1.0.1     pillar_1.9.0         fstcore_0.9.18
#>  [64] filelock_1.0.3       tensorA_0.36.2.1     checkmate_2.3.1
#>  [67] renv_1.0.7           stats4_4.4.0         distributional_0.4.0
#>  [70] generics_0.1.3       rprojroot_2.0.4      hms_1.1.3
#>  [73] rstantools_2.4.0     munsell_0.5.1        scales_1.3.0
#>  [76] storr_1.2.5          globals_0.16.3       xtable_1.8-4
#>  [79] base64url_1.4        glue_1.7.0           tools_4.4.0
#>  [82] data.table_1.15.4    fs_1.6.4             mvtnorm_1.2-4
#>  [85] grid_4.4.0           rbibutils_2.2.16     QuickJSR_1.1.3
#>  [88] devtools_2.4.5       colorspace_2.1-0     nlme_3.1-164
#>  [91] cli_3.6.2            fst_0.9.8            fansi_1.0.6
#>  [94] Brobdingnag_1.2-9    V8_4.4.2             gtable_0.3.5
#>  [97] digest_0.6.35        htmlwidgets_1.6.4    memoise_2.0.1
#> [100] htmltools_0.5.8.1    lifecycle_1.0.4      mime_0.12
paul-buerkner commented 4 months ago

Thank you for reporting this issue. I am no emmeans expert so for me it's hard to tell what is going on. @rvlenth do you happen to have an idea perhaps?

rvlenth commented 4 months ago

I have no clue.

I am bothered by the fact that there are two (very) different objects named model in this code.

As for the "custom" code, I disagree that it is what emmeans should be doing, simply because whatever all that stuff is, it shouldn't be that complex.

My suggestion for finding out more is to try this, using the second version of model, the one that was produced by brm().

# Look at what emmeans itself produces for the brms fit (named model_brms in
# the edited reprex), before any tidy post-processing.
emm_itself <- emmeans(
  object = model_brms,
  specs = ~ARMCD:AVISIT,
  wt.nuis = "proportional",
  nuisance = c("USUBJID", "RACE", "SEX")
)

summary(emm_itself)

So far, we are now seeing directly what emmeans is producing. Are the estimates the same as those in the plot? Do the annotations below the summary provide additional information that was never seen because it was swept away by all the "tidy" post-processing? Because the estimate in the summary is the median of the posterior, how about the results in summary(emm_itself, point.est = mean)?

If you still see the serious discrepancies, do this:

# Reference grid of the brms fit (named model_brms in the edited reprex).
newdata <- emmeans::ref_grid(model_brms)@grid

This gives you the grid of all fixed-effects factors, which is the basis for all emmeans calculations. Then use brms functions/methods to obtain predictions from model, with newdata as new data. Average those results together over all but the two primary factors, using appropriate weights. That's what emmeans should be doing.

wlandau commented 4 months ago

I am bothered by the fact that there are two (very) different objects named model in this code.

Edited https://github.com/paul-buerkner/brms/issues/1654#issue-2298560967 to use model_lm and model_brms.

As for the "custom" code, I disagree that it is what emmeans should be doing, simply because whatever all that stuff is, it shouldn't be that complex.

Edited https://github.com/paul-buerkner/brms/issues/1654#issue-2298560967 to clarify that comment.

So far, we are now seeing directly what emmeans is producing. Are the estimates the same as those in the plot?

Yes:

# Build the brms emmeans object and print its summary directly.
emm_itself <- emmeans(
  object = model_brms,
  specs = ~ARMCD:AVISIT,
  wt.nuis = "proportional",
  nuisance = c("USUBJID", "RACE", "SEX")
)
summary(emm_itself)
#>  ARMCD AVISIT emmean lower.HPD upper.HPD
#>  PBO   VIS1   -18.08    -22.29   -13.617
#>  TRT   VIS1   -14.81    -18.28   -11.236
#>  PBO   VIS2   -16.08    -19.48   -12.372
#>  TRT   VIS2   -12.68    -16.50    -9.329
#>  PBO   VIS3   -12.53    -15.84    -8.909
#>  TRT   VIS3    -9.71    -13.31    -6.333
#>  PBO   VIS4    -7.93    -11.54    -4.363
#>  TRT   VIS4    -3.46     -7.11     0.102
#> 
#> Results are averaged over the levels of: 2 nuisance factors 
#> Point estimate displayed: median 
#> HPD interval probability: 0.95 
# The tidied result from the reprex, for comparison.
as.data.frame(summary_brms_emmeans)
#>   ARMCD AVISIT       mean      lower       upper         source
#> 1   PBO   VIS1 -18.083219 -22.287808 -13.6172246 4_brms_emmeans
#> 2   TRT   VIS1 -14.812490 -18.276953 -11.2362895 4_brms_emmeans
#> 3   PBO   VIS2 -16.079485 -19.477840 -12.3717137 4_brms_emmeans
#> 4   TRT   VIS2 -12.679113 -16.503203  -9.3292318 4_brms_emmeans
#> 5   PBO   VIS3 -12.527884 -15.841525  -8.9088424 4_brms_emmeans
#> 6   TRT   VIS3  -9.709981 -13.307893  -6.3334955 4_brms_emmeans
#> 7   PBO   VIS4  -7.928348 -11.537075  -4.3630501 4_brms_emmeans
#> 8   TRT   VIS4  -3.462008  -7.109919   0.1019503 4_brms_emmeans
# Element-wise differences between the two: all zero, so the tidy
# post-processing did not alter the estimates or intervals.
summary(emm_itself)$emmean - summary_brms_emmeans$mean
#> [1] 0 0 0 0 0 0 0 0
summary(emm_itself)$lower.HPD - summary_brms_emmeans$lower
#> [1] 0 0 0 0 0 0 0 0
summary(emm_itself)$upper.HPD - summary_brms_emmeans$upper
#> [1] 0 0 0 0 0 0 0 0

Do the annotations below the summary provide additional information that was never seen because it was swept away by all the "tidy" post-processing?

The summary says the results are averaged over two nuisance variables, whereas the code supplies three. This makes sense: there is no fixed effect for USUBJID, so it is not part of the reference grid, and only RACE and SEX are averaged over.

Because the estimate in the summary is the median of the posterior, how about the results in summary(emm_itself, point.est = mean)?

Only slight differences:

# Re-summarize using the posterior mean (rather than the default median) as
# the point estimate, and report the largest absolute change.
summary_emmeans <- summary(emm_itself, point.est = mean)
max(abs(summary_emmeans$emmean - summary_brms_emmeans$mean))
#> [1] 0.0202332

If you still see the serious discrepancies, do this:

# Reference grid of the brms fit (named model_brms in the edited reprex).
newdata <- emmeans::ref_grid(model_brms)@grid

This gives you the grid of all fixed-effects factors, which is the basis for all emmeans calculations. Then use brms functions/methods to obtain predictions from model, with newdata as new data. Average those results together over all but the two primary factors, using appropriate weights.

When I do that, I see close enough agreement with the native lm()/emmeans integration, but strong disagreement with the brms/emmeans integration.

# Predictions at every row of the brms reference grid (all fixed-effect factor
# combinations, including the nuisance factors RACE and SEX).
new_data <- emmeans::ref_grid(model_brms)@grid
# NOTE(review): predict() summarizes posterior *predictive* draws, which
# include residual noise; fitted() would give the expected response with less
# Monte Carlo error.
predictions <- predict(model_brms, newdata = new_data)
grid <- mutate(new_data, estimate = predictions[, "Estimate"])

# Proportional weights: observed frequency of each RACE x SEX combination.
# mutate() overwrites a `.wgt.` column if the reference grid already carries
# one, whereas rename() would fail on the duplicated name.
weighted_grid <- grid %>%
  left_join(y = count(data, RACE, SEX), by = c("RACE", "SEX")) %>%
  mutate(.wgt. = n)

# Marginal means: weighted average of the predictions within each
# ARMCD x AVISIT cell, returned as an ungrouped tibble.
custom <- weighted_grid %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(mean = sum(estimate * .wgt.) / sum(.wgt.), .groups = "drop") %>%
  arrange(AVISIT, ARMCD)
custom
#> # A tibble: 8 × 3
#>   ARMCD AVISIT   mean
#>   <fct> <fct>   <dbl>
#> 1 PBO   VIS1   -4.67 
#> 2 TRT   VIS1   -1.24 
#> 3 PBO   VIS2   -2.47 
#> 4 TRT   VIS2    0.957
#> 5 PBO   VIS3    1.00 
#> 6 TRT   VIS3    3.78 
#> 7 PBO   VIS4    5.57 
#> 8 TRT   VIS4   10.1  

# Good enough agreement with lm marginal means
summary_lm_emmeans
#> # A tibble: 8 × 6
#>   ARMCD AVISIT   mean  lower  upper source      
#>   <fct> <fct>   <dbl>  <dbl>  <dbl> <chr>       
#> 1 PBO   VIS1   -4.60  -5.98  -3.22  2_lm_emmeans
#> 2 TRT   VIS1   -1.29  -2.76   0.185 2_lm_emmeans
#> 3 PBO   VIS2   -2.54  -3.92  -1.17  2_lm_emmeans
#> 4 TRT   VIS2    0.847 -0.625  2.32  2_lm_emmeans
#> 5 PBO   VIS3    0.984 -0.393  2.36  2_lm_emmeans
#> 6 TRT   VIS3    3.80   2.33   5.27  2_lm_emmeans
#> 7 PBO   VIS4    5.60   4.22   6.98  2_lm_emmeans
#> 8 TRT   VIS4   10.1    8.58  11.5   2_lm_emmeans

max(abs(custom$mean - summary_lm_emmeans$mean))
#> [1] 0.1104108

# Disagreement with the native emmeans/brms integration
max(abs(custom$mean - summary_brms_emmeans$mean))
#> [1] 13.63619
wlandau commented 4 months ago

Also, thanks for explaining the role of emmeans::ref_grid(model_brms)@grid in the weighting technique. This object is basically an expand.grid() over the unique levels of all the factors in the fixed effects, including nuisance factors, with continuous variables set at their observed grand means. Each row in the grid is given a weight, and I guess these weights are used to estimate marginal means as weighted averages over rows of predicted responses in the grid. This is the most direct and edifying explanation I have seen about how exactly the reference grid works and what exactly we mean by a "weight" in emmeans. (I read the help files, https://www.jstatsoft.org/article/view/v069i01, and all the vignettes, but I still missed these concepts.) Very helpful.

But whether we take emmeans' two-step approach of predict() + weighting or my reprex's one-step linear transformation from model coefficients to marginal means, the results appear to agree for the frequentist model.

# Create the reference grid and predict from the lm fit at each of its rows.
new_data <- emmeans::ref_grid(model_lm)@grid
grid <- mutate(new_data, estimate = predict(model_lm, newdata = new_data))

# Apply proportional weights: observed frequency of each RACE x SEX
# combination in the analysis data (overwrites any existing `.wgt.`).
weighted_grid <- grid %>%
  left_join(y = count(data, RACE, SEX), by = c("RACE", "SEX")) %>%
  mutate(.wgt. = n)

# Compute marginal means using the weighted grid; `.groups = "drop"` returns
# an ungrouped tibble.
summary_lm_emmeans_using_grid <- weighted_grid %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(mean = sum(estimate * .wgt.) / sum(.wgt.), .groups = "drop") %>%
  arrange(AVISIT, ARMCD)

# Both approaches agree:
max(abs(summary_lm_emmeans_using_grid$mean - summary_lm_emmeans$mean))
#> [1] 5.329071e-15
rvlenth commented 4 months ago

We can go all over the place looking at examples and trying to guess what is done, but it shouldn't be too difficult to tell by looking at the code.

The emmeans package provides the infrastructure, but what it does to actually estimate things depends on the emm_basis method for that model class, and in this case that method is part of the package code for brms. Here is that code, copied here for convenience:

# Console print-out of the emmeans adapter that brms supplies for brmsfit
# objects. It returns the pieces (X, bhat, V, post.beta, ...) from which
# emmeans forms estimates on the reference grid `grid`.
> brms:::emm_basis.brmsfit

function (object, trms, xlev, grid, vcov., resp = NULL, dpar = NULL, 
    nlpar = NULL, re_formula = NA, epred = FALSE, ...) 
{
    # Deprecated spelling: dpar = "mean" is translated to epred = TRUE.
    if (is_equal(dpar, "mean")) {
        warning2("dpar = 'mean' is deprecated. Please use epred = TRUE instead.")
        epred <- TRUE
        dpar <- NULL
    }
    epred <- as_one_logical(epred)
    bterms <- .extract_par_terms(object, resp = resp, dpar = dpar, 
        nlpar = nlpar, re_formula = re_formula, epred = epred)
    # Posterior draws are evaluated directly at the rows of `grid`:
    # expected values of the response when epred = TRUE, otherwise the
    # linear predictor with transform = FALSE and offset = FALSE.
    if (epred) {
        post.beta <- posterior_epred(object, newdata = grid, 
            re_formula = re_formula, resp = resp, incl_autocor = FALSE, 
            ...)
    }
    else {
        req_vars <- all_vars(bterms$allvars)
        post.beta <- posterior_linpred(object, newdata = grid, 
            re_formula = re_formula, resp = resp, dpar = dpar, 
            nlpar = nlpar, incl_autocor = FALSE, req_vars = req_vars, 
            transform = FALSE, offset = FALSE, ...)
    }
    if (anyNA(post.beta)) {
        stop2("emm_basis.brmsfit created NAs. Please check your reference grid.")
    }
    misc <- bterms$.misc
    # A 3-D draws array (responses in the third dimension) is flattened into
    # a matrix; the response names become the `rep.meas` levels.
    if (length(dim(post.beta)) == 3L) {
        ynames <- dimnames(post.beta)[[3]]
        if (is.null(ynames)) {
            ynames <- as.character(seq_len(dim(post.beta)[3]))
        }
        dims <- dim(post.beta)
        post.beta <- matrix(post.beta, ncol = prod(dims[2:3]))
        misc$ylevs = list(rep.meas = ynames)
    }
    attr(post.beta, "n.chains") <- object$fit@sim$chains
    # The columns of post.beta are already estimates at the grid rows, so the
    # basis X is the identity and X %*% bhat equals bhat.
    X <- diag(ncol(post.beta))
    # Point estimates and covariance come straight from the draws.
    bhat <- apply(post.beta, 2, mean)
    V <- cov(post.beta)
    # NOTE(review): matrix(NA) presumably tells emmeans that every linear
    # function is estimable; confirm against ?emm_basis.
    nbasis <- matrix(NA)
    dfargs <- list()
    # Infinite degrees of freedom for every linear function.
    dffun <- function(k, dfargs) Inf
    # NOTE(review): presumably detaches the closure from this frame so it
    # does not keep the large draws matrix alive; confirm.
    environment(dffun) <- baseenv()
    nlist(X, bhat, nbasis, V, dffun, dfargs, misc, post.beta)
}

In the arguments, object is the model object, trms is a terms component, and grid is a data frame with the factor combinations in the reference grid. The function is supposed to set us up to produce predictions with grid as new data. In the returned list, bhat is the vector of regression coefficients and X is the matrix of linear functions such that X %*% bhat obtains the predictions. (More importantly, post.beta is the posterior sample for bhat.) This particular function has a few optional brmsfit-specific arguments — resp, dpar, nlpar, re_formula, and epred — which, as this isn't my package, I am in no position to explain, but they affect how things get set up. Some of them are mentioned in the help for predict.brmsfit.

This is not a very complex function (it seems simpler than a lot of the code in this issue), and I suggest trying to understand what it does. For example, maybe what you need to do is add the argument epred = TRUE?

rvlenth commented 4 months ago

@wlandau PS -- of course, you should also look at ? emm_basis.brmsfit

bjoernholzhauer commented 2 months ago

There somehow does not seem to be a problem when you use the approach via as.mcmc (which arguably should be what you do by default):

# Alternative: summarize the emmeans posterior draws directly via as.mcmc().
# Note that this call uses `weights = "proportional"` and no `nuisance`
# argument.
summary_brms_mcmc <- model_brms %>%
  emmeans(~ ARMCD | AVISIT, weights = "proportional") %>%
  as.mcmc() %>%
  summarise_draws(~ quantile(.x, probs = c(0.5, 0.025, 0.975)), mean) %>%
  mutate(
    source = "5_brms_mcmc",
    # Recover the factor levels from the draw names.
    ARMCD = str_extract(variable, "PBO|TRT"),
    AVISIT = str_extract(variable, "VIS[0-9]+")
  ) %>%
  rename(lower = `2.5%`, upper = `97.5%`, median = `50%`)

# Compare results. Named summary_all so the base function summary() is not
# shadowed.
summary_all <- bind_rows(
  summary_data,
  summary_lm_emmeans,
  summary_brms_custom,
  summary_brms_emmeans,
  summary_brms_mcmc
)
ggplot(summary_all) +
  geom_point(aes(x = source, y = mean, color = source)) +
  geom_errorbar(aes(x = source, ymin = lower, ymax = upper, color = source)) +
  facet_grid(ARMCD ~ AVISIT) +
  theme_gray(16) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1)) +
  ylab("FEV1_CHG")

produces the expected result: image

wlandau commented 2 months ago

Ah, that's helpful. So it seems the underlying derived posterior samples of marginal means are correct, but the internal summarization of those samples in the emmeans/brms plugin is somehow incorrect. That's good enough for a partial workaround, but still misses much of the convenience because it requires dealing with non-machine-readable column names (as you did with str_extract()), and this is hard to automate reliably in the general case.

Was the original intention to support as.data.frame(emmeans()) for brms fitted models? If not, an informative error at least would be nice.

andrew-bean commented 2 months ago

I did some digging and I think I may be close to pinpointing the issue. It does not relate to whether or not one uses as.mcmc.emmGrid, but rather to the presence or absence of the nuisance argument in the call to ref_grid().

Note that @bjoernholzhauer dropped this argument, hence the issue does not affect his code. However if you replace his emmeans() call with @wlandau 's that has the nuisance argument, you will see the issue persists regardless of whether or not as.mcmc is invoked.

If you compare emm_basis.brmsfit (copied above by @rvlenth) to emm_basis.lm, you'll see they take very different approaches. emm_basis.lm returns bhat as estimates of the regression coefficients and X as the model matrix corresponding to the reference grid. However, in emm_basis.brmsfit, bhat and post.beta are already estimates at the reference-grid rows (posterior_linpred does the legwork for this), and X is just an identity matrix. This is fine when there are no nuisance factors.

However when nuisance factors are present, a subsequent function emmeans:::.basis.nuis() called within emmeans::ref_grid makes an adjustment to the bases to append the means of all levels of the nuisance factors. This function plays well with the lm setup, but does NOT play well with the bases returned by brms:::emm_basis.brmsfit. Here is an illustration of where things go awry


# set up ref_grid under both lm and brms
# NOTE(review): ref_grid() has no `specs` argument; it is absorbed by `...`
# and forwarded to the model-specific methods.
rg_lm <- ref_grid(
  object = model_lm,
  specs = ~ARMCD:AVISIT,
  wt.nuis = "proportional",
  nuisance = c("USUBJID", "RACE", "SEX")
)
rg_brms <- ref_grid(
  object = model_brms,
  specs = ~ARMCD:AVISIT,
  wt.nuis = "proportional",
  nuisance = c("USUBJID", "RACE", "SEX")
)

# the lm method takes linfct as the design matrix corresponding to the
# reference grid
rg_lm@linfct

# and the bhat is just coef(model_lm)
rg_lm@bhat
unname(coef(model_lm))

# brms handles this very differently:
# this is the same as posterior_linpred() at some reference grid levels
rg_brms@bhat
# this does not appear to be a meaningful set of linear combinations of
# those response means
rg_brms@linfct

# reproduction via package internals -------------------------------------------

# initial grid set up in emmeans::ref_grid() in case of nuisance factors
# (covariate values are rounded means copied from the reference grid)
rl <- list(
  FEV1_BL = 40.12532,
  ARMCD = c("PBO", "TRT"),
  AVISIT = c("VIS1", "VIS2", "VIS3", "VIS4"),
  RACE = c("Asian", "Black or African American", "White"),
  SEX = c("Male", "Female"),
  WEIGHT = 0.5175461
)

# this function returns an initial reference grid
nuis.info_brms <- emmeans:::.setup.nuis(
  nuis = c("USUBJID", "RACE", "SEX"),
  levs = rl,
  trms = attr(model_brms$data, "terms"),
  rg.limit = 10000
)
grid_brms <- nuis.info_brms$grid
grid_brms

nuis.info_lm <- emmeans:::.setup.nuis(
  nuis = c("USUBJID", "RACE", "SEX"),
  levs = rl,
  trms = attr(emmeans:::recover_data(model_lm), "terms"),
  rg.limit = 10000
)
grid_lm <- nuis.info_lm$grid
grid_lm

# no problems so far
identical(grid_brms, grid_lm) # TRUE

# next the key adapter function is called using this initial grid
basis_brms <- brms:::emm_basis.brmsfit(
  model_brms,
  trms = attr(emmeans:::recover_data(model_brms), "terms"),
  xlev = rl[c("ARMCD", "AVISIT", "RACE", "SEX")],
  grid = grid_brms
)

basis_lm <- emmeans:::emm_basis.lm(
  model_lm,
  trms = attr(emmeans:::recover_data(model_lm), "terms"),
  xlev = rl[c("AVISIT", "ARMCD", "RACE", "SEX")],
  grid = grid_lm
)

# still no problems yet
m_brms <- basis_brms$X %*% basis_brms$bhat
m_lm <- basis_lm$X %*% basis_lm$bhat
m_brms
m_lm
max(m_brms - m_lm)
# 0.02961797
# the signed maximum above ignores negative differences; this bounds the
# discrepancy in both directions
max(abs(m_brms - m_lm))

# This is where things go awry for brms ----------------------------------------

basis_nuis_brms <- emmeans:::.basis.nuis(
  basis_brms,
  nuis.info_brms,
  "proportional",
  rl,
  emmeans:::recover_data(model_brms),
  grid = grid_brms,
  rl
)

# after this function ...
basis_nuis_brms$bhat #  [1] -6.759959 -3.409183 -4.683985 -1.301822 -1.174313  1.648844  3.451389  7.900958 -6.759959 -6.028935 -1.322218 -6.759959 -5.989218
# is still the same as
colMeans(posterior_linpred(model_brms, newdata = grid_brms)) #  [1] -6.759959 -3.409183 -4.683985 -1.301822 -1.174313  1.648844  3.451389  7.900958 -6.759959 -6.028935 -1.322218 -6.759959 -5.989218
# i.e. it contains (posterior mean) point estimates of the mean response at
# these grid levels
# (knitr is not attached in this session, hence the explicit namespace)
knitr::kable(grid_brms)
# |  FEV1_BL|ARMCD |AVISIT |RACE                      |SEX    |    WEIGHT|
# |--------:|:-----|:------|:-------------------------|:------|---------:|
# | 40.12532|PBO   |VIS1   |Asian                     |Male   | 0.5175461|
# | 40.12532|TRT   |VIS1   |Asian                     |Male   | 0.5175461|
# | 40.12532|PBO   |VIS2   |Asian                     |Male   | 0.5175461|
# | 40.12532|TRT   |VIS2   |Asian                     |Male   | 0.5175461|
# | 40.12532|PBO   |VIS3   |Asian                     |Male   | 0.5175461|
# | 40.12532|TRT   |VIS3   |Asian                     |Male   | 0.5175461|
# | 40.12532|PBO   |VIS4   |Asian                     |Male   | 0.5175461|
# | 40.12532|TRT   |VIS4   |Asian                     |Male   | 0.5175461|
# | 40.12532|PBO   |VIS1   |Asian                     |Male   | 0.5175461|
# | 40.12532|PBO   |VIS1   |Black or African American |Male   | 0.5175461|
# | 40.12532|PBO   |VIS1   |White                     |Male   | 0.5175461|
# | 40.12532|PBO   |VIS1   |Asian                     |Male   | 0.5175461|
# | 40.12532|PBO   |VIS1   |Asian                     |Female | 0.5175461|

# likewise post.beta
colMeans(basis_nuis_brms$post.beta)
# contains the posterior draws of the response at those levels

# and the "X" slot has been transformed from
basis_brms$X
#       [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13]
# [1,]     1    0    0    0    0    0    0    0    0     0     0     0     0
# [2,]     0    1    0    0    0    0    0    0    0     0     0     0     0
# [3,]     0    0    1    0    0    0    0    0    0     0     0     0     0
# [4,]     0    0    0    1    0    0    0    0    0     0     0     0     0
# [5,]     0    0    0    0    1    0    0    0    0     0     0     0     0
# [6,]     0    0    0    0    0    1    0    0    0     0     0     0     0
# [7,]     0    0    0    0    0    0    1    0    0     0     0     0     0
# [8,]     0    0    0    0    0    0    0    1    0     0     0     0     0
# [9,]     0    0    0    0    0    0    0    0    1     0     0     0     0
# [10,]    0    0    0    0    0    0    0    0    0     1     0     0     0
# [11,]    0    0    0    0    0    0    0    0    0     0     1     0     0
# [12,]    0    0    0    0    0    0    0    0    0     0     0     1     0
# [13,]    0    0    0    0    0    0    0    0    0     0     0     0     1
# to
basis_nuis_brms$X
# [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8]      [,9]     [,10]     [,11]     [,12]     [,13]
# [1,]    1    0    0    0    0    0    0    0 0.3553299 0.3756345 0.2690355 0.4670051 0.5329949
# [2,]    0    1    0    0    0    0    0    0 0.3553299 0.3756345 0.2690355 0.4670051 0.5329949
# [3,]    0    0    1    0    0    0    0    0 0.3553299 0.3756345 0.2690355 0.4670051 0.5329949
# [4,]    0    0    0    1    0    0    0    0 0.3553299 0.3756345 0.2690355 0.4670051 0.5329949
# [5,]    0    0    0    0    1    0    0    0 0.3553299 0.3756345 0.2690355 0.4670051 0.5329949
# [6,]    0    0    0    0    0    1    0    0 0.3553299 0.3756345 0.2690355 0.4670051 0.5329949
# [7,]    0    0    0    0    0    0    1    0 0.3553299 0.3756345 0.2690355 0.4670051 0.5329949
# [8,]    0    0    0    0    0    0    0    1 0.3553299 0.3756345 0.2690355 0.4670051 0.5329949

# the result of premultiplying bhat by X is no longer meaningful
basis_nuis_brms$X %*% basis_nuis_brms$bhat
# [1,] -18.131533
# [2,] -14.780756
# [3,] -16.055559
# [4,] -12.673396
# [5,] -12.545887
# [6,]  -9.722730
# [7,]  -7.920185
# [8,]  -3.470616
# ^ this is where the odd point estimates in Will's code come from
# these are odd linear combinations of the mean response at the initial grid_brms,

# here's how things work correctly with lm's

basis_nuis_lm <- emmeans:::.basis.nuis(
  basis_lm,
  nuis.info_lm,
  "proportional",
  rl,
  emmeans:::recover_data(model_lm),
  grid = grid_lm,
  rl
)
basis_nuis_lm$X # correct contrast matrix for the marginal means
basis_nuis_lm$bhat # estimates of the regression coefficients
andrew-bean commented 2 months ago

Long story short, it seems like emm_basis.brmsfit may need some retooling to play nice with ref_grid() when nuisance factors are present. I won't speculate on how this can be done robustly given the massive flexibility of brm(). But in the case of linear models without random effects, this hack seems to work (having emm_basis.brmsfit return inference on the regression coefficients rather than response means, and using the same contrast matrix as emm_basis.lm)

# hacky patch (linear models without random effects only): reuse the lm
# contrast matrix and make the brms basis describe regression coefficients
# instead of response means.
rg_brms_patch <- rg_brms
rg_brms_patch@linfct <- rg_lm@linfct
rg_brms_patch@bhat <- fixef(model_brms)[, "Estimate"]
# Assumes the first n_coef columns of the brms draws are the regression
# coefficients, in the same order as coef(model_lm) -- verify before reuse.
n_coef <- length(coef(model_lm))
rg_brms_patch@post.beta <- as.matrix(model_brms)[, seq_len(n_coef)]

# Compare: original brms grid, patched brms grid, lm grid, custom mapping.
emmeans(rg_brms, specs = ~ARMCD:AVISIT)
emmeans(rg_brms_patch, specs = ~ARMCD:AVISIT)
emmeans(rg_lm, specs = ~ARMCD:AVISIT)
summary_brms_custom
rvlenth commented 2 months ago

Thanks for working through this so carefully. Now I'm concerned about how nuisance variables are handled in several other situations where there is implicit re-gridding to the response or some other scale. Unfortunately I am traveling and cannot look at this for at least a week. But I will get to it when I can.

Russ

rvlenth commented 2 months ago

Thanks to @andrew-bean for pointing me in the right direction regarding the nuisance factors. I found a typo in emmeans:::.basis.nuis() which causes the wrong grid to be returned when there is a multivariate response. It is in line 1129 of ref-grid.R:

    basis$grid = grid[RA == ".main.grid.", , drop = FALSE]

which should be

    basis$grid = grid[ra == ".main.grid.", , drop = FALSE]

I think that you will get much saner results when nuisance is used, once this bugfix is pushed up and emmeans is reinstalled. In the meantime, avoid using nuisance variables when there is a multivariate response.

My apologies for the error, but my thanks for discovering this!

paul-buerkner commented 2 months ago

Thank you all for helping to figure this out!

wlandau commented 2 months ago

Sorry if I am missing something, but when I install https://github.com/rvlenth/emmeans/commit/909821dd40a68ecb43713cbd8ff3a3911e15b70a and run the original reprex in https://github.com/paul-buerkner/brms/issues/1654#issue-2298560967, I still see the same plot. Session info:

R version 4.4.0 (2024-04-24) Platform: aarch64-apple-darwin20 Running under: macOS Sonoma 14.5 Matrix products: default BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib LAPACK: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.12.0 locale: [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8 time zone: America/Indiana/Indianapolis tzcode source: internal attached base packages: [1] stats graphics grDevices utils datasets methods base other attached packages: [1] zoo_1.8-12 lubridate_1.9.3 forcats_1.0.0 stringr_1.5.1 [5] dplyr_1.1.4 purrr_1.0.2 readr_2.1.5 tidyr_1.3.1 [9] tibble_3.2.1 ggplot2_3.5.1 tidyverse_2.0.0 posterior_1.6.0 [13] mmrm_0.3.12.9000 emmeans_1.10.3-090004 coda_0.19-4.1 brms_2.21.0 [17] Rcpp_1.0.13 loaded via a namespace (and not attached): [1] gtable_0.3.5 tensorA_0.36.2.1 QuickJSR_1.3.1 processx_3.8.4 [5] inline_0.3.19 lattice_0.22-6 callr_3.7.6 tzdb_0.4.0 [9] ps_1.7.7 vctrs_0.6.5 tools_4.4.0 Rdpack_2.6 [13] generics_0.1.3 stats4_4.4.0 curl_5.2.1 parallel_4.4.0 [17] sandwich_3.1-0 fansi_1.0.6 pkgconfig_2.0.3 Matrix_1.7-0 [21] checkmate_2.3.1 distributional_0.4.0 RcppParallel_5.1.8 lifecycle_1.0.4 [25] farver_2.1.2 compiler_4.4.0 textshaping_0.4.0 Brobdingnag_1.2-9 [29] munsell_0.5.1 codetools_0.2-20 bayesplot_1.11.1 pillar_1.9.0 [33] MASS_7.3-61 StanHeaders_2.32.10 bridgesampling_1.1-2 abind_1.4-5 [37] multcomp_1.4-26 nlme_3.1-165 rstan_2.32.6 tidyselect_1.2.1 [41] mvtnorm_1.2-5 stringi_1.8.4 labeling_0.4.3 splines_4.4.0 [45] grid_4.4.0 colorspace_2.1-0 cli_3.6.3 magrittr_2.0.3 [49] loo_2.8.0 survival_3.7-0 pkgbuild_1.4.4 utf8_1.2.4 [53] TH.data_1.1-2 withr_3.0.0 scales_1.3.0 backports_1.5.0 [57] timechange_0.3.0 estimability_1.5.1 matrixStats_1.3.0 gridExtra_2.3 [61] ragg_1.3.2 hms_1.1.3 rbibutils_2.2.16 V8_4.4.2 [65] rstantools_2.4.0 rlang_1.1.4 xtable_1.8-4 glue_1.7.0 [69] rstudioapi_0.16.0 jsonlite_1.8.8 R6_2.5.1 
systemfonts_1.1.0
rvlenth commented 2 months ago

I don't know. The fix I put in is for a multivariate response when the nuisance argument is used, as per @andrew-bean's comment. Is that the case? I'm really not following this issue very well, and could easily believe that there are several balls in the air. Back in May, when this issue was first created, I was not able to get the brms code to run, so I am of limited help. I guess there's some key missing piece in my rstan setup or something.

rvlenth commented 2 months ago

There's some other discussion up there about using as.mcmc(). I surmise that with my recent fix, this as.mcmc() approach should now work the same whether or not you specified nuisance factors. I think that's different from just re-running the original plot code.

I also see a remark about as.data.frame(emmeans(...)) which from the code I think is meant to refer to as.data.frame(summary(emmeans(...))). My comments here are:

  1. It shouldn't be necessary to use as.data.frame because summary.emmGrid already produces an object that inherits from data.frame.
  2. The summary.emmGrid() function results include annotations that explain a few things about what is being summarized. Piping that summary into further steps may suppress those annotations, so you never get to see them. And with the newest version of emmeans, there is an on-load message that warns about this. You might have a better understanding of what you have if you actually read those annotations instead of suppressing them. This seems especially pertinent in that the plot seems to show results that are on an entirely different scale than the other results.
rvlenth commented 2 months ago

OK, I finally managed to fit the brms model (I had to unset the brms.backend = "rcmdstan" option). And here is what I get directly from emmeans():

### Your emmeans code ###
> model_emm <- emmeans(
+     object = model_brms,
+     specs = ~ARMCD:AVISIT,
+     wt.nuis = "proportional",
+     nuisance = c("USUBJID", "RACE", "SEX")
+ )
> model_emm
 ARMCD AVISIT emmean lower.HPD upper.HPD
 PBO   VIS1   -18.14    -22.55   -13.622
 TRT   VIS1   -14.78    -18.66   -11.466
 PBO   VIS2   -16.05    -19.63   -12.620
 TRT   VIS2   -12.70    -16.03    -8.974
 PBO   VIS3   -12.53    -16.15    -9.122
 TRT   VIS3    -9.77    -13.16    -5.967
 PBO   VIS4    -7.93    -11.53    -4.480
 TRT   VIS4    -3.49     -6.92     0.186

Results are averaged over the levels of: 2 nuisance factors 
Point estimate displayed: median 
HPD interval probability: 0.95 

These results do appear similar to what is in the plot.

If I bypass the nuisance stuff, I get different results that are more similar to the other estimates in the plot:

> emmeans(
+     object = model_brms,
+     specs = ~ARMCD:AVISIT, 
+     weights = "prop")
 ARMCD AVISIT emmean lower.HPD upper.HPD
 PBO   VIS1   -4.609    -5.905     -3.19
 TRT   VIS1   -1.269    -2.752      0.18
 PBO   VIS2   -2.528    -3.899     -1.18
 TRT   VIS2    0.844    -0.554      2.38
 PBO   VIS3    0.970    -0.454      2.20
 TRT   VIS3    3.779     2.178      5.13
 PBO   VIS4    5.604     4.234      6.96
 TRT   VIS4   10.057     8.623     11.55

Results are averaged over the levels of: RACE, SEX 
Point estimate displayed: median 
HPD interval probability: 0.95 

BTW, notice that the factor USUBJID is not considered in the reference grid and so it was unnecessary to specify it. Note that the point estimate used is the median. If we look at the mean instead, there isn't a whole lot of difference:

> emmeans(model_brms, ~ARMCD:AVISIT, weights = "prop") |> summary(point.est = mean)
 ARMCD AVISIT emmean lower.HPD upper.HPD
 PBO   VIS1   -4.603    -5.905     -3.19
 TRT   VIS1   -1.267    -2.752      0.18
 PBO   VIS2   -2.538    -3.899     -1.18
 TRT   VIS2    0.853    -0.554      2.38
 PBO   VIS3    0.980    -0.454      2.20
 TRT   VIS3    3.778     2.178      5.13
 PBO   VIS4    5.609     4.234      6.96
 TRT   VIS4   10.049     8.623     11.55

Results are averaged over the levels of: RACE, SEX 
Point estimate displayed: mean 
HPD interval probability: 0.95 

I'm not sure what's going on with the nuisance thing -- I'll look into it. But for the lm fit, I get identical results whether or not nuisance is used. And the estimates from model_lm and model_brms are very close:

> emmeans(
+     object = model_lm,
+     specs = ~ARMCD:AVISIT,
+     weights = "proportional")
 ARMCD AVISIT emmean    SE  df lower.CL upper.CL
 PBO   VIS1   -4.600 0.702 772   -5.977   -3.223
 TRT   VIS1   -1.286 0.749 772   -2.757    0.185
 PBO   VIS2   -2.545 0.702 772   -3.922   -1.167
 TRT   VIS2    0.847 0.749 772   -0.625    2.318
 PBO   VIS3    0.984 0.702 772   -0.393    2.361
 TRT   VIS3    3.801 0.750 772    2.329    5.273
 PBO   VIS4    5.601 0.701 772    4.225    6.978
 TRT   VIS4   10.052 0.750 772    8.580   11.524

Results are averaged over the levels of: RACE, SEX 
Confidence level used: 0.95 

Anyway, I guess for now you should still avoid the nuisance specification. Maybe there is another bug, or maybe there is just an explanation.

rvlenth commented 2 months ago

Oh, man -- I realize now what is going on. The coding for the nuisance option relies on some programming trickery, part of which assumes that the columns of the model matrix are associated with the model terms. However, this model (and, I think, all brmsfit models) creates a "re-gridded" basis, where the @linfct slot is just the identity matrix and the @post.beta slot consists of the sample of predictions on the reference grid. Thus the columns of @linfct are not associated with model terms, and the nuisance code is not correct.

It's pretty easy to check for this, and so now we error out in this situation:

> emmeans(
+     object = model_lm,
+     specs = ~ARMCD:AVISIT,
+     wt.nuis = "proportional",
+     nuisance = c("RACE", "SEX"))

Error: Sorry, 'nuisance' specs are not allowed for this situation. Revise the call accordingly.

I apologize for all the grief this may have caused, but I'm glad that the problem has been identified.

wlandau commented 2 months ago

Thanks for finding out the origin of the issue!

I do see the new error message:

# Demonstrates the new guard in emmeans: this call is expected to fail,
# because `nuisance` specs are rejected for this brmsfit model (see the
# error message captured below).
summary_brms_emmeans <- emmeans(
  object = model_brms,
  specs = ~ARMCD:AVISIT,
  wt.nuis = "proportional",
  nuisance = c("RACE", "SEX")
)
#> Error: Sorry, 'nuisance' specs are not allowed for this situation. Revise the call accordingly.

At first I thought this might not bode well for Bayesian models with covariate adjustment, but then I realized that the nuisance argument is not actually necessary to generate a reference grid that only depends on ARMCD and AVISIT. Everything lines up nicely when I just drop nuisance. I was pleasantly surprised that this worked.

# Reprex comparing four estimates of the ARMCD x AVISIT marginal means of
# FEV1_CHG:
#   (1) raw-data summaries,
#   (2) lm() + emmeans,
#   (3) brms + a custom linear mapping of the posterior draws, and
#   (4) brms + emmeans.
# The `nuisance` argument is intentionally not passed to emmeans(), because
# it is rejected for brmsfit models. weights = "proportional" is used
# instead to average over the remaining factors.
suppressPackageStartupMessages({
  library(brms)
  library(coda)
  library(emmeans)
  library(mmrm)
  library(posterior)
  library(tidyverse)
  library(zoo)
})
emm_options(sep = "|")

packageDescription("brms")$GithubSHA1
#> [1] "298b947fa9cfb914aeb7cb3aab7974aa179682b1"

# FEV data from the mmrm package, using LOCF and then LOCF reversed
# to impute responses. (For this discussion, it is helpful to avoid
# the topic of missingness.)
data(fev_data, package = "mmrm")
data <- fev_data %>%
  mutate(FEV1_CHG = FEV1 - FEV1_BL, USUBJID = as.character(USUBJID)) %>%
  select(-FEV1) %>%
  group_by(USUBJID) %>%
  # Add any visits a subject lacks, then copy that subject's own baseline
  # covariates into those rows. Passing `fill = as.list(.[1L, ...])` to
  # complete() does not do this: inside a magrittr pipe, `.` is the whole
  # data frame rather than the current group. Every filled row would then
  # get the covariates of the first subject in the data.
  complete(AVISIT) %>%
  tidyr::fill(ARMCD, FEV1_BL, RACE, SEX, WEIGHT, .direction = "downup") %>%
  ungroup() %>%
  arrange(USUBJID, AVISIT) %>%
  group_by(USUBJID) %>%
  mutate(FEV1_CHG = na.locf(FEV1_CHG, na.rm = FALSE)) %>%
  mutate(FEV1_CHG = na.locf(FEV1_CHG, na.rm = FALSE, fromLast = TRUE)) %>%
  ungroup() %>%
  filter(!is.na(FEV1_CHG))
# (1) Raw-data means with independent normal-approximation 95% intervals.
summary_data <- data %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(
    source = "1_data",
    mean = mean(FEV1_CHG),
    lower = mean(FEV1_CHG) - qnorm(0.975) * sd(FEV1_CHG) / sqrt(n()),
    upper = mean(FEV1_CHG) + qnorm(0.975) * sd(FEV1_CHG) / sqrt(n()),
    .groups = "drop"
  )

# Formula shared by all the models
formula <- FEV1_CHG ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT +
  AVISIT + RACE + SEX + WEIGHT

# lm with emmeans
model_lm <- lm(formula = formula, data = data)
summary_lm_emmeans <- emmeans(
  object = model_lm,
  specs = ~ARMCD:AVISIT,
  weights = "proportional"
) %>%
  as.data.frame() %>%
  as_tibble() %>%
  select(ARMCD, AVISIT, emmean, lower.CL, upper.CL) %>%
  rename(mean = emmean, lower = lower.CL, upper = upper.CL) %>%
  mutate(source = "2_lm_emmeans")

# brms with emmeans
model_brms <- brm(data = data, formula = brmsformula(formula))
summary_brms_emmeans <- emmeans(
  object = model_brms,
  specs = ~ARMCD:AVISIT,
  weights = "proportional"
) %>%
  as.data.frame() %>%
  as_tibble() %>%
  select(ARMCD, AVISIT, emmean, lower.HPD, upper.HPD) %>%
  rename(mean = emmean, lower = lower.HPD, upper = upper.HPD) %>%
  mutate(source = "4_brms_emmeans")

# custom marginal means from brms draws using a custom mapping
# from brms model parameters to marginal means. I would expect the
# emmeans/brms integration to agree with the results below
# (within rounding error + MCMC error), based on what I find with lm()
# (c.f. https://github.com/openpharma/brms.mmrm/issues/53)
proportional_factors <- brmsformula(FEV1_CHG ~ 0 + SEX + RACE) %>%
  make_standata(data = data) %>%
  .subset2("X") %>%
  colMeans() %>%
  t()
grid <- data %>%
  mutate(FEV1_BL = mean(FEV1_BL), FEV1_CHG = 0, WEIGHT = mean(WEIGHT)) %>%
  distinct(ARMCD, AVISIT, FEV1_BL, WEIGHT, FEV1_CHG)
draws_parameters <- model_brms %>%
  as_draws_df() %>%
  as_tibble() %>%
  select(starts_with("b_"), -starts_with("b_sigma"))
mapping <- brmsformula(
    FEV1_CHG ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT + AVISIT + WEIGHT
  ) %>%
  make_standata(data = grid) %>%
  .subset2("X") %>%
  bind_cols(proportional_factors) %>%
  setNames(paste0("b_", colnames(.)))
stopifnot(all(colnames(draws_parameters) %in% colnames(mapping)))
mapping <- as.matrix(mapping)[, colnames(draws_parameters)]
rownames(mapping) <- paste(grid$ARMCD, grid$AVISIT, sep = "|")
draws_custom <- as.matrix(draws_parameters) %*% t(mapping) %>%
  as.data.frame() %>%
  as_tibble()
summary_brms_custom <- draws_custom %>%
  pivot_longer(everything()) %>%
  separate("name", c("ARMCD", "AVISIT")) %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(
    source = "3_brms_custom",
    mean = mean(value),
    lower = quantile(value, 0.025),
    upper = quantile(value, 0.975),
    .groups = "drop"
  )

# Compare results. Join on ARMCD and AVISIT rather than relying on both
# summaries happening to be sorted in the same order. The key columns may
# differ in type (character vs. factor), so convert before joining.
comparison <- inner_join(
  select(summary_brms_custom, ARMCD, AVISIT, custom_mean = mean),
  summary_brms_emmeans %>%
    mutate(across(c(ARMCD, AVISIT), as.character)) %>%
    select(ARMCD, AVISIT, emmeans_mean = mean),
  by = c("ARMCD", "AVISIT")
)
stopifnot(nrow(comparison) == nrow(summary_brms_custom))
max(abs(comparison$custom_mean - comparison$emmeans_mean))
#> [1] 0.01103878

summary <- bind_rows(
  summary_data,
  summary_lm_emmeans,
  summary_brms_custom,
  summary_brms_emmeans
)
ggplot(summary) +
  geom_point(aes(x = source, y = mean, color = source)) +
  geom_errorbar(aes(x = source, ymin = lower, ymax = upper, color = source)) +
  facet_grid(ARMCD ~ AVISIT) +
  theme_gray(16) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1)) +
  ylab("FEV1_CHG")

Screenshot 2024-07-22 at 4 10 29 PM

rvlenth commented 2 months ago

Exactly. In fact, nuisance factors are not supported for any brms models, because they are all implemented in that regridded way. Unfortunately, this could lead to memory-use issues with big models and large numbers of MCMC draws, but there is no avoiding it.

Again, I'm sorry it took so long to figure this out. I was stymied by having an old backend option that no longer worked, keeping me from being able to fit the model.