vincentarelbundock / marginaleffects

R package to compute and plot predictions, slopes, marginal means, and comparisons (contrasts, risk ratios, odds, etc.) for over 100 classes of statistical and ML models. Conduct linear and non-linear hypothesis tests, or equivalence tests. Calculate uncertainty estimates using the delta method, bootstrapping, or simulation-based inference
https://marginaleffects.com
Other

Very large CIs with brms mixed multinomial models #219

Closed: DominiqueMakowski closed this issue 2 years ago

DominiqueMakowski commented 2 years ago

I've been pulling my hair out for weeks now in my own analysis using multinomial models because I get super mega large CIs. I thought something was wrong with my model, my data, or marginaleffects, but even recomputing the CIs myself led to the same result.

I finally managed to reproduce the issue. It turns out that the CIs become very large (as in through the roof) when the model is mixed.

Non-mixed multinomial model

library(brms)
library(patchwork)
library(ggplot2)

data <- mtcars
data$cyl <- as.character(data$cyl)
data$vs <- as.character(data$vs)
data$carb <- as.character(data$carb)

model <- brm(cyl ~ mpg * vs, 
             data=data, 
             chains = 6,
             iter = 3000,
             algorithm = "meanfield",
             family=categorical(link="logit", refcat = "4")) 

newdata <- insight::get_datagrid(model, at=c("mpg", "vs"), length=100, preserve_range = FALSE)

p1 <- modelbased::estimate_relation(model, at=c("mpg", "vs"), length=100, preserve_range = FALSE) |> 
  ggplot(aes(x = mpg, y = Predicted)) +
  geom_ribbon(aes(ymin = CI_low, ymax= CI_high, fill = Response, group = interaction(Response, vs)), alpha = 0.25) +
  geom_line(aes(color = Response, linetype = vs))

p2 <- marginaleffects::marginaleffects(model, variables = "mpg", newdata = newdata) |> 
  ggplot(aes(x = mpg, y = dydx)) +
  geom_hline(yintercept = 0, linetype = "dashed") +
  geom_ribbon(aes(ymin=conf.low, ymax=conf.high, fill = group, group = interaction(group , vs)), alpha = 0.2) +
  geom_line(aes(color = group, linetype = vs))

p1 / p2

So far so good. Now the same thing, but with a random intercept added:

Mixed multinomial


model <- brm(cyl ~ mpg * vs + (1|carb), 
             data=data, 
             chains = 6,
             iter = 3000,
             algorithm = "meanfield",
             family=categorical(link="logit", refcat = "4")) 

newdata <- insight::get_datagrid(model, at=c("mpg", "vs"), length=100, preserve_range = FALSE)

p1 <- modelbased::estimate_relation(model, at=c("mpg", "vs"), length=100, preserve_range = FALSE) |> 
  ggplot(aes(x = mpg, y = Predicted)) +
  geom_ribbon(aes(ymin = CI_low, ymax= CI_high, fill = Response, group = interaction(Response, vs)), alpha = 0.25) +
  geom_line(aes(color = Response, linetype = vs))

p2 <- marginaleffects::marginaleffects(model, variables = "mpg", newdata = newdata) |> 
  ggplot(aes(x = mpg, y = dydx)) +
  geom_hline(yintercept = 0, linetype = "dashed") +
  geom_ribbon(aes(ymin=conf.low, ymax=conf.high, fill = group, group = interaction(group , vs)), alpha = 0.2) +
  geom_line(aes(color = group, linetype = vs))

p1 / p2

Created on 2022-03-03 by the reprex package (v2.0.1)

Although the predictions plot looks about the same, the CIs of the derivatives (the slopes of mpg) go bonkers. Before moving the issue to brms, I first wanted to check here whether you suspect it might be related to marginaleffects.

DominiqueMakowski commented 2 years ago

Could it be a transformation issue? 🤔

vincentarelbundock commented 2 years ago

I don't have a final answer, but I do have a theory. First, though, I'll note that the intervals for the slopes should be quite large, but obviously not this large...

To calculate dydx in brms models we do this:

  1. predict the outcome with the baseline data
  2. predict the outcome with the baseline data, with the variable of interest (here mpg) incremented by 1e-5
  3. subtract the draws obtained in 1 from the draws obtained in 2, and divide by 1e-5
  4. compute median, quantiles, hdi, etc.
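The four steps above can be sketched in plain R, with no brms model involved. This is a toy illustration, not marginaleffects' internal code: slope_draws(), summarize_draws() and toy_predict() are hypothetical helpers, and toy_predict() is a made-up stand-in for a call like posterior_epred() that returns one row per draw and one column per row of the data grid. The "posterior draws" are simulated coefficients of a logistic curve, so the finite-difference slopes can be checked against the exact derivative.

```r
# Toy finite-difference slopes over posterior draws (no brms needed).
# slope_draws(), summarize_draws() and toy_predict() are hypothetical helpers.

# Steps 1-3: predict at x and at x + eps, then difference and divide by eps.
# predict_draws(x) must return a matrix: one row per draw, one column per x value.
slope_draws <- function(predict_draws, x, eps = 1e-5) {
  draws0 <- predict_draws(x)        # step 1: baseline predictions
  draws1 <- predict_draws(x + eps)  # step 2: predictions with x incremented
  (draws1 - draws0) / eps           # step 3: finite difference per draw
}

# Step 4: median and 95% quantile interval for each column
summarize_draws <- function(d) {
  t(apply(d, 2, function(col) c(
    estimate  = median(col),
    conf.low  = unname(quantile(col, 0.025)),
    conf.high = unname(quantile(col, 0.975))
  )))
}

# Simulated "posterior" draws for a logistic curve p = plogis(alpha + beta * x)
set.seed(123)
n_draws <- 2000
alpha <- rnorm(n_draws, -6, 0.5)
beta  <- rnorm(n_draws, 0.3, 0.05)
toy_predict <- function(x) {
  # alpha is recycled down the rows of the n_draws x length(x) matrix
  plogis(alpha + outer(beta, x))
}

x <- c(15, 20, 25)
slopes <- summarize_draws(slope_draws(toy_predict, x))

# Exact derivative for comparison: dp/dx = beta * p * (1 - p)
p <- toy_predict(x)
exact <- summarize_draws(beta * p * (1 - p))

print(slopes)
print(exact)
```

Because both prediction calls reuse the same coefficient draws, the per-draw differences stay tiny and dividing by 1e-5 is harmless in this toy setting.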

My current hunch is that some numerical precision issue is causing this. Here's some code that implements steps 1-4 above; it's essentially a skeleton of what marginaleffects does behind the scenes.

In this example, increment is set to 1. When I run it, the intervals are large, but not craaaaazy large. In contrast, if you set increment to 1e-5 instead, you'll see massive intervals.

library(brms)
library(ggplot2)

data <- mtcars
data$cyl <- as.character(data$cyl)
data$vs <- as.character(data$vs)
data$carb <- as.character(data$carb)
model <- brm(cyl ~ mpg * vs + (1|carb),
             data = data,
             chains = 6,
             iter = 3000,
             algorithm = "meanfield",
             family = categorical(link = "logit", refcat = "4"))

increment <- 1
newdata0 <- newdata1 <- insight::get_datagrid(model, at = c("mpg", "vs"), length = 100, preserve_range = FALSE)
newdata1$mpg <- newdata0$mpg + increment

draws0 <- rstantools::posterior_epred(model, newdata = newdata0)
draws1 <- rstantools::posterior_epred(model, newdata = newdata1)
dydx <- (draws1 - draws0) / increment
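# stack the three response levels side by side: draws x (rows * levels)
dydx <- do.call("cbind", lapply(1:3, \(i) dydx[, , i]))
# median and 95% interval (2.5% and 97.5% quantiles) for each column
dydx <- t(apply(dydx, 2, function(x) c(median(x), quantile(x, c(.025, .975)))))
colnames(dydx) <- c("med", "conf.low", "conf.high")
dydx <- cbind(do.call("rbind", lapply(1:3, \(i) newdata0)), dydx)
dydx$response <- rep(c("4", "6", "8"), each = nrow(newdata0))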

ggplot(dydx, aes(x = mpg,
                 y = med,
                 ymin = conf.low,
                 ymax = conf.high,
                 fill = factor(response),
                 linetype = factor(vs))) +
    geom_ribbon(alpha = .1) +
    geom_line(aes(color = factor(response)))
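The point about increment size can be checked in a toy setting (simulated logistic draws, all values hypothetical; predict_draws() is a made-up stand-in for posterior_epred()). As long as both prediction calls use the same draws, shrinking the increment from 1 to 1e-5 makes the finite difference more accurate, not less, so a small increment on its own should not produce exploding intervals.

```r
# Toy check of increment size when both calls share the same draws.
# All values are hypothetical; predict_draws() stands in for posterior_epred().
set.seed(123)
n_draws <- 2000
alpha <- rnorm(n_draws, -6, 0.5)
beta  <- rnorm(n_draws, 0.3, 0.05)
predict_draws <- function(x) plogis(alpha + beta * x)  # one value per draw

x <- 20
p <- predict_draws(x)
exact <- beta * p * (1 - p)  # exact derivative for each draw

# Forward difference with a given increment, reusing the same draws
fd <- function(increment) (predict_draws(x + increment) - predict_draws(x)) / increment

err_1     <- max(abs(fd(1) - exact))
err_small <- max(abs(fd(1e-5) - exact))
c(increment_1 = err_1, increment_1e_5 = err_small)
```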
DominiqueMakowski commented 2 years ago

My current hunch is that some numerical precision issue is causing.

But if that were the case, shouldn't it happen with both the non-mixed and the mixed models?

vincentarelbundock commented 2 years ago

If you run the exact same code as above with a non-mixed model, the results look "reasonable". This suggests to me that something is happening in rstantools rather than marginaleffects, because there is nothing specific to my package in the minimal example shown above.

Question is: what is that "something" that is happening?

First possibility: there is almost perfect separation in that model. If you run hist(draws0), you'll see it's almost all 0s and 1s. This may exacerbate any numerical instability.
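For context on the separation point: in a categorical-logit model with refcat = "4", the reference category's linear predictor is fixed at 0 and the probabilities are a softmax over all categories. The sketch below uses made-up linear predictor values to show how large predictors push the probabilities toward 0 and 1, the pattern hist(draws0) reveals. categorical_probs() is a hypothetical helper, not a brms function.

```r
# Softmax for a categorical-logit model; the reference category's
# linear predictor is fixed at 0 (here category "4", as with refcat = "4").
# categorical_probs() is a hypothetical helper for illustration only.
categorical_probs <- function(eta_nonref) {
  eta <- c(0, eta_nonref)   # prepend the reference category
  p <- exp(eta - max(eta))  # subtract the max for numerical stability
  p / sum(p)
}

# Made-up linear predictors for categories "6" and "8"
moderate  <- categorical_probs(c(0.5, -0.5))
saturated <- categorical_probs(c(15, 5))
names(moderate) <- names(saturated) <- c("4", "6", "8")

rbind(moderate = round(moderate, 3), saturated = signif(saturated, 3))
```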

Second hint (my favorite!!!): your newdata grid does not include all values of the random intercept variable carb. If you include them, the results look way more reasonable when calling marginaleffects() directly:

library(brms)
library(ggplot2)
library(marginaleffects)

data <- mtcars
data$cyl <- as.character(data$cyl)
data$vs <- as.character(data$vs)
data$carb <- as.character(data$carb)
model <- brm(cyl ~ mpg * vs + (1|carb),
             data = data,
             chains = 6,
             iter = 3000,
             algorithm = "meanfield",
             family = categorical(link = "logit", refcat = "4"))

# one row per combination of carb level, vs value, and mpg value
datplot <- marginaleffects(model,
                           variables = "mpg",
                           newdata = datagrid(carb = unique(data$carb),
                                              vs = 0:1,
                                              mpg = seq(min(data$mpg), max(data$mpg), length.out = 100)))

ggplot(datplot, aes(x = mpg, y = dydx)) +
  geom_hline(yintercept = 0, linetype = "dashed") +
  geom_ribbon(aes(ymin=conf.low, ymax=conf.high, fill = group, group = interaction(group , vs)), alpha = 0.2) +
  geom_line(aes(color = group, linetype = factor(vs))) +
  facet_wrap(~carb)

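A grid with the same structure can also be built in base R. This is a sketch: it crosses every observed level of the grouping variable carb with both values of vs and 100 mpg values. Unlike the datagrid() call above, it stores vs as "0"/"1" to match the character column in data; whether it matches datagrid() column for column is an assumption, but a data frame like this could presumably be passed as newdata.

```r
data <- mtcars
data$carb <- as.character(data$carb)
data$vs <- as.character(data$vs)

# Every observed level of the grouping variable, crossed with vs and an mpg sequence
grid <- expand.grid(
  carb = sort(unique(data$carb)),
  vs = c("0", "1"),
  mpg = seq(min(data$mpg), max(data$mpg), length.out = 100),
  stringsAsFactors = FALSE
)

nrow(grid)  # 1200: 6 carb levels x 2 vs values x 100 mpg values
```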

DominiqueMakowski commented 2 years ago

I'd lean in favour of the 2nd theory too... I think it's worth opening an issue on brms. Do you have a minimal reproducible example in mind to illustrate the issue? I was thinking of something like the following, but it doesn't seem to reproduce the problem (I think I've missed the part where it goes wrong):

library(brms)
library(ggplot2)

data <- mtcars
data$cyl <- as.character(data$cyl)
data$vs <- as.character(data$vs)
data$carb <- as.character(data$carb)

model <- brm(cyl ~ mpg * vs + (1|carb),
             data = data,
             chains = 6,
             iter = 3000,
             algorithm = "meanfield",
             refresh=0,
             family = categorical(link = "logit", refcat = "4"))

newdata0 <- newdata1 <- newdata2 <- insight::get_datagrid(model, at = c("mpg", "vs"), length = 100, preserve_range = FALSE)
head(newdata0)
#>        mpg vs carb
#> 1 10.40000  0   NA
#> 2 10.63737  0   NA
#> 3 10.87475  0   NA
#> 4 11.11212  0   NA
#> 5 11.34949  0   NA
#> 6 11.58687  0   NA
newdata1$mpg <- newdata0$mpg + 1
newdata2$mpg <- newdata0$mpg + 1e-5

draws0 <- rstantools::posterior_epred(model, newdata = newdata0)
draws1 <- rstantools::posterior_epred(model, newdata = newdata1)
draws2 <- rstantools::posterior_epred(model, newdata = newdata2)

quantile(draws1 - draws0, c(.0275, .975))
#>      2.75%      97.5% 
#> -0.5696543  0.6138213
quantile(draws2 - draws0, c(.0275, .975))
#>      2.75%      97.5% 
#> -0.2871597  0.3404077

Created on 2022-03-05 by the reprex package (v2.0.1)

vincentarelbundock commented 2 years ago

Here’s a minimal example. Note that I set carb = NA, as in the output of your get_datagrid() call.

library(brms)

data <- mtcars
data$cyl <- as.character(data$cyl)

model <- brm(cyl ~ mpg * vs + (1 | carb),
             data = data,
             chains = 6,
             iter = 3000,
             algorithm = "meanfield",
             refresh=0,
             family = categorical(link = "logit", refcat = "4"))

newdata0 <- newdata1 <- data.frame(mpg = 20, vs = 0, carb = NA)
newdata1$mpg <- newdata1$mpg + 1e-5
pred0 <- posterior_epred(model, newdata = newdata0)
pred1 <- posterior_epred(model, newdata = newdata1)
dydx <- (pred1 - pred0) / 1e-5

apply(pred0[, 1, ], 2, quantile, probs = c(.025, .5, .975))
##                  4            6            8
## 2.5%  7.188320e-09 1.704124e-05 4.759571e-06
## 50%   7.844408e-06 7.035547e-01 2.964124e-01
## 97.5% 1.693925e-03 9.999902e-01 9.999779e-01
apply(pred1[, 1, ], 2, quantile, probs = c(.025, .5, .975))
##                  4            6            8
## 2.5%  7.189716e-09 1.764972e-05 2.813632e-06
## 50%   8.538644e-06 7.003035e-01 2.986317e-01
## 97.5% 2.001993e-03 9.999908e-01 9.999708e-01
apply(dydx[, 1, ], 2, quantile, probs = c(.025, .5, .975))
##                   4             6             8
## 2.5%  -7.652447e+01 -9.254042e+04 -8.943085e+04
## 50%    3.182602e-05  3.959162e-02 -6.129972e-02
## 97.5%  8.874331e+01  8.943082e+04  9.254030e+04
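One possible reading of these numbers (a hypothesis, not something established in this thread): if brms treats carb = NA as a new level of the grouping factor and draws its random intercept independently in each posterior_epred() call, then pred0 and pred1 are computed with different intercepts. Their difference, which can approach 1 when probabilities sit near 0 and 1, is then divided by 1e-5, giving values like the ±9e4 quantiles above. The toy simulation below (binary logit, simulated draws, every value hypothetical) contrasts three cases.

```r
# Toy illustration: reused vs freshly drawn random intercepts in a finite difference.
# Binary logit with simulated draws; every value here is hypothetical.
set.seed(42)
n_draws <- 4000
beta  <- rnorm(n_draws, 0.3, 0.05)  # simulated slope draws
sd_re <- 1                          # hypothetical random-intercept SD
x <- 20
h <- 1e-5
prob <- function(x, u) plogis(-6 + beta * x + u)

# (a) no random intercept at all, in the spirit of re.form = NA
pop <- (prob(x + h, 0) - prob(x, 0)) / h

# (b) the same random-intercept draw reused in both predictions
u <- rnorm(n_draws, 0, sd_re)
shared <- (prob(x + h, u) - prob(x, u)) / h

# (c) a different random-intercept draw in each prediction
u0 <- rnorm(n_draws, 0, sd_re)
u1 <- rnorm(n_draws, 0, sd_re)
fresh <- (prob(x + h, u1) - prob(x, u0)) / h

# (a) and (b) stay on the scale of an ordinary slope;
# (c) is the gap between two unrelated predictions divided by 1e-5
sapply(list(pop = pop, shared = shared, fresh = fresh),
       quantile, probs = c(0.025, 0.5, 0.975))
```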
DominiqueMakowski commented 2 years ago

It works by adding re.form = NA (i.e., predicting without the random effects) :) We adjusted insight::get_predicted to make it compatible with all that (https://github.com/easystats/insight/pull/523).

Thanks @vincentarelbundock !

library(brms)
library(patchwork)
library(ggplot2)

data <- mtcars
data$cyl <- as.character(data$cyl)
data$vs <- as.character(data$vs)
data$carb <- as.character(data$carb)

model <- brm(cyl ~ mpg * vs + (1|carb), 
             data=data, 
             chains = 6,
             iter = 3000,
             algorithm = "meanfield",
             family=categorical(link="logit", refcat = "4")) 

newdata <- insight::get_datagrid(model, at=c("mpg", "vs"), length=100, preserve_range = FALSE)

p2 <- marginaleffects::marginaleffects(model, variables = "mpg", newdata = newdata, re.form = NA) |> 
  ggplot(aes(x = mpg, y = dydx)) +
  geom_hline(yintercept = 0, linetype = "dashed") +
  geom_ribbon(aes(ymin=conf.low, ymax=conf.high, fill = group, group = interaction(group , vs)), alpha = 0.2) +
  geom_line(aes(color = group, linetype = vs))

p2

Created on 2022-03-08 by the reprex package (v2.0.1)

vincentarelbundock commented 2 years ago

Wow, obvious in retrospect!