vincentarelbundock / marginaleffects

R package to compute and plot predictions, slopes, marginal means, and comparisons (contrasts, risk ratios, odds, etc.) for over 100 classes of statistical and ML models. Conduct linear and non-linear hypothesis tests, or equivalence tests. Calculate uncertainty estimates using the delta method, bootstrapping, or simulation-based inference
https://marginaleffects.com

Change default `plot_predictions(re.form = NA)`? #725

Closed mattansb closed 1 year ago

mattansb commented 1 year ago

When plotting predictions from a mixed model with plot_predictions(), the default behavior is to plot predicted data for only the most frequent level of the random grouping variable(s).

library(marginaleffects)
library(ggplot2)
library(patchwork)

gm1 <- lme4::glmer(cbind(incidence, size - incidence) ~ period + (1 | herd),
                   data = lme4::cbpp, family = binomial)

p1 <- plot_predictions(gm1, condition = list("period")) + 
  coord_cartesian(ylim = c(0, 0.4))

# Is actually:
p2 <- plot_predictions(gm1, condition = list("period", herd = 1)) + 
  coord_cartesian(ylim = c(0, 0.4))

p1 | p2

However, I don’t think this is a good default. Users probably expect either the “typical” predictions (with re.form = NA), or some manner of average predictions across levels of the random grouping variable:

p_re.form_NA <- plot_predictions(gm1, condition = list("period"),
                                 re.form = NA) + 
  coord_cartesian(ylim = c(0, 0.4))

p_avg1 <- avg_predictions(gm1, variables = c(period = unique)) |> 
  ggplot(aes(period, estimate)) + 
  geom_pointrange(aes(ymin = conf.low, ymax = conf.high)) + 
  coord_cartesian(ylim = c(0, 0.4))

# Or:
p_avg2 <- avg_predictions(gm1, variables = c(period = unique, herd = unique),
                by = "period") |> 
  ggplot(aes(period, estimate)) + 
  geom_pointrange(aes(ymin = conf.low, ymax = conf.high)) + 
  coord_cartesian(ylim = c(0, 0.4))

p_re.form_NA | p_avg1 | p_avg2

Created on 2023-03-16 with reprex v2.0.2


Some options / ideas / thoughts / questions:

  1. Change re.form = NA to the default?
  2. Give a message / warning ("Hey, you might be interested in plotting the average predictions across random levels...") with some options for how to do this?
  3. Why and how does p_avg1 differ from p_avg2?
  4. Is there a way to get p_avg1/2 automagically?

vincentarelbundock commented 1 year ago

Thanks for raising this issue.

I pushed a commit to improve the documentation, and mention this potential problem at the very top of the help files for all plot_*() functions.

I’m not sure I want to make re.form = NA the default, because it wouldn’t make sense as a default for other functions like slopes() and predictions(). Also, since this is kind of a “hidden” setting – only available for a restricted number of models, and only by pushing to lme4::predict via ... – I think it would be confusing to have different defaults in different functions. So I’m hoping that clearer documentation will work here.

So far, the package never raises messages related to the interpretation of models. I think I like the idea of having good documentation, instead of relying on a mix of messages and docs. Centralization, etc. Also, issuing messages with recommendations on best statistical practice feels like a rabbit hole I don’t want to get into. What deserves a warning? People do plenty of stuff I don’t agree with, and I can give some hints in the vignettes, but I think my job is to clearly define the available quantities and to give tools. I realize this is very different from the easystats philosophy…

p_avg1 replicates the original dataset 4 times: once for each unique value of period. p_avg2 replicates the full dataset once for every unique combination of period and herd. Make the same calls without the avg_ and inspect the rowid and rowidcf columns. I think that’ll make it clear.
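
To make the replication idea concrete, here is a rough base R sketch on a made-up toy data frame. It is not the package's internal code: `toy`, `replicate_for()`, `orig_row` and `cf_id` are hypothetical names used only for illustration. It just shows why the two grids end up with different sizes.

```r
# Hypothetical toy data: 6 observations, 2 periods, 3 herds
toy <- data.frame(
  period = factor(c(1, 1, 2, 2, 1, 2)),
  herd   = factor(c("a", "a", "b", "b", "c", "c")),
  y      = c(2, 3, 1, 0, 4, 1)
)

# Replicate the full dataset once per combination of the unique values of
# the variables in `vars`, overwriting those columns with the counterfactual
# values. `orig_row` tracks the source row, `cf_id` the combination.
replicate_for <- function(data, vars) {
  combos <- expand.grid(lapply(data[vars], unique), stringsAsFactors = FALSE)
  data$orig_row <- seq_len(nrow(data))
  pieces <- lapply(seq_len(nrow(combos)), function(i) {
    out <- data
    for (v in vars) out[[v]] <- combos[[v]][i]
    out$cf_id <- i
    out
  })
  do.call(rbind, pieces)
}

grid1 <- replicate_for(toy, "period")             # one copy per period
grid2 <- replicate_for(toy, c("period", "herd"))  # one copy per period x herd

nrow(grid1)  # 6 rows x 2 periods
nrow(grid2)  # 6 rows x 2 periods x 3 herds
```

In the first grid every observation keeps its own herd; in the second, every observation is also moved through every herd, so the averages are taken over a different set of rows.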

You can get pretty close to those results automagically with the by argument. But this takes averages over the empirical distribution of the data; it does not create a “counterfactual/balanced” grid of data by replicating the full dataset:

library(ggplot2)
library(patchwork)
library(marginaleffects)

gm1 <- lme4::glmer(cbind(incidence, size - incidence) ~ period + (1 | herd),
                   data = lme4::cbpp, family = binomial)

plot_predictions(gm1, by = "period")
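
For intuition, the difference between the two kinds of averages can be sketched by hand in base R. This is only an illustration under assumptions: it uses a Poisson glm on mtcars instead of the mixed model, averages on the response scale, and does not reproduce what marginaleffects does internally, so the numbers need not match package output.

```r
# Illustrative model (an assumption, not the mixed model discussed above)
mod <- glm(gear ~ hp * factor(cyl) + am, family = poisson("log"), data = mtcars)

# (1) Empirical averages: predict on the observed rows, then average within
#     the observed groups. Each group keeps its own hp and am values.
fitted_obs <- predict(mod, type = "response")
avg_empirical <- tapply(fitted_obs, mtcars$cyl, mean)

# (2) Counterfactual averages: set cyl to each value for every row,
#     predict, and average over the full (replicated) dataset.
cyl_values <- sort(unique(mtcars$cyl))
avg_counterfactual <- sapply(cyl_values, function(v) {
  cf <- transform(mtcars, cyl = v)
  mean(predict(mod, newdata = cf, type = "response"))
})
names(avg_counterfactual) <- cyl_values

rbind(avg_empirical, avg_counterfactual)
```

The two rows differ because the first averages over the hp and am values actually observed within each cyl group, while the second averages over the hp and am values of the whole dataset.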

mattansb commented 1 year ago

Thanks Vincent. Gold badge as usual 🥇!

I think perhaps the source of my confusion is that while (avg_)predictions() has a variables argument, plot_predictions() does not? That is, I can easily generate counterfactual average predictions with avg_predictions(), but I cannot plot them with plot_predictions().

# by
avg_predictions(gm1, by = "period")

plot_predictions(gm1, by = "period")

# datagrid
avg_predictions(gm1, newdata = datagrid(period = unique), by = "period")

plot_predictions(gm1, condition = "period")

# counterfactual
avg_predictions(gm1, variables = "period") 

plot_predictions(gm1, ???)

I guess I'm expecting the plot_*() functions to be convenient wrappers around avg_*() |> ggplot() + ....


Also, I think maybe (avg_)predictions() is the only place where the variables argument is used for generating a counterfactual grid? Elsewhere it is used to define the focal variable(s). Is this an inconsistency? Or something about the marginal-effects approach that I am missing or not understanding?

vincentarelbundock commented 1 year ago

> That is, I can easily generate counterfactual average predictions with avg_predictions(), but I cannot plot them with plot_predictions().

Yes, that's a problem. I have not been able to come up with a good user interface, but will keep thinking about it.

> Also, I think maybe (avg_)predictions() is the only place where the variables argument is used for generating a counterfactual grid? Elsewhere it is used to define the focal variable(s). Is this an inconsistency?

We had long discussions with Daniel and Noah about this. The TLDR is that this command computes counterfactual predicted values when am is 0 or 1:

predictions(mod, variables = list(am = 0:1))

And this command computes the differences between counterfactual predicted values when am is 0 or 1:

comparisons(mod, variables = list(am = 0:1))

So in that respect the syntax is highly consistent. In both cases, the "focal" variable is am. In the predictions() case, you just have to interpret "focal" in a counterfactual prediction way.
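
A hand-rolled base R version of the same idea may help. This is a sketch under assumptions: the thread does not show how `mod` was fitted, so the linear model below is hypothetical, and the code mimics the logic rather than calling marginaleffects.

```r
# Hypothetical model: the thread does not show how `mod` was fitted.
mod <- lm(mpg ~ hp + am, data = mtcars)

# Counterfactual predictions: every row is predicted twice,
# once with am forced to 0 and once with am forced to 1.
pred_am0 <- predict(mod, newdata = transform(mtcars, am = 0))
pred_am1 <- predict(mod, newdata = transform(mtcars, am = 1))

# "predictions"-style quantity: the two sets of counterfactual predictions.
head(cbind(pred_am0, pred_am1))

# "comparisons"-style quantity: the row-wise difference between them.
contrast_am <- pred_am1 - pred_am0

# In this additive linear model every row has the same difference,
# which is simply the coefficient on am.
all.equal(unname(contrast_am), rep(unname(coef(mod)["am"]), nrow(mtcars)))
```

In both cases am is the variable being manipulated; the only difference is whether the two counterfactual predictions are reported side by side or subtracted.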

mattansb commented 1 year ago

> We had long discussions with Daniel and Noah about this. [...]

Ah, I see! Okay, this makes sense to me - thanks!


Perhaps you can add a plot_predictions(variables = ) argument such that length(variables) + length(condition) <= 3 or length(variables) + length(by) <= 3,

So you have -

Is that too much of a headache? Yes.... I think it is.... Just throwing stuff into the air, see if I can inspire you somehow 😅

Or - maybe for condition add an argument grid_type =?

vincentarelbundock commented 1 year ago

Yeah, that's kind of the problem. Either we add a bunch of new arguments -- which introduces inconsistency -- or we end up replicating all the original predictions() function, in which case the user might as well pipe ggplot2.

My personal view is that in the vast majority of cases, if your data is not completely broken, these approaches will all produce about the same results. So analysts can use the easy plotting functions at the development stage, and then build a custom plot for the final report/paper using ggplot2 and avg_predictions().

mattansb commented 1 year ago

I'm now realizing that none of the plotting functions support counterfactual conditional values.

library(marginaleffects)

mod <- glm(gear ~ hp * factor(cyl) + am, 
           family = poisson("log"),
           data = mtcars)

avg_slopes(mod, variables = "hp", 
           by = "cyl")
#> 
#>  Term    Contrast cyl Estimate Std. Error     z Pr(>|z|)   2.5 % 97.5 %
#>    hp mean(dY/dX)   4  0.00655     0.0303 0.216    0.829 -0.0529 0.0660
#>    hp mean(dY/dX)   6  0.01446     0.0307 0.471    0.637 -0.0456 0.0746
#>    hp mean(dY/dX)   8  0.00598     0.0103 0.582    0.561 -0.0141 0.0261
#> 
#> Prediction type:  response 
#> Columns: type, term, contrast, cyl, estimate, std.error, statistic, p.value, conf.low, conf.high, predicted, predicted_hi, predicted_lo 
#> 

avg_slopes(mod, variables = "hp", 
           newdata = datagrid(cyl = unique),
           by = "cyl")
#> 
#>  Term    Contrast cyl Estimate Std. Error     z Pr(>|z|)   2.5 % 97.5 %
#>    hp mean(dY/dX)   6  0.01557    0.03574 0.436    0.663 -0.0545 0.0856
#>    hp mean(dY/dX)   4  0.00677    0.03446 0.196    0.844 -0.0608 0.0743
#>    hp mean(dY/dX)   8  0.00556    0.00822 0.677    0.498 -0.0105 0.0217
#> 
#> Prediction type:  response 
#> Columns: rowid, type, term, contrast, cyl, estimate, std.error, statistic, p.value, conf.low, conf.high, predicted, predicted_hi, predicted_lo 
#> 

avg_slopes(mod, variables = "hp", 
           newdata = datagridcf(cyl = unique),
           by = "cyl")
#> 
#>  Term    Contrast cyl Estimate Std. Error     z Pr(>|z|)   2.5 % 97.5 %
#>    hp mean(dY/dX)   4  0.00683    0.03508 0.195    0.846 -0.0619 0.0756
#>    hp mean(dY/dX)   6  0.01612    0.03926 0.411    0.681 -0.0608 0.0931
#>    hp mean(dY/dX)   8  0.00562    0.00841 0.668    0.504 -0.0109 0.0221
#> 
#> Prediction type:  response 
#> Columns: type, term, contrast, cyl, estimate, std.error, statistic, p.value, conf.low, conf.high, predicted, predicted_hi, predicted_lo 
#> 

Only the first two can be reproduced by plot_slopes():

plot_slopes(mod, variables = "hp", by = "cyl", draw = FALSE)
#>       type term    contrast cyl    estimate  std.error statistic   p.value    conf.low  conf.high predicted predicted_hi predicted_lo
#> 1 response   hp mean(dY/dX)   4 0.006553553 0.03030894 0.2162251 0.8288123 -0.05285088 0.06595798  4.380574     4.380773     4.380574
#> 2 response   hp mean(dY/dX)   6 0.014456390 0.03066238 0.4714700 0.6373052 -0.04564076 0.07455354  4.093008     4.093442     4.093008
#> 3 response   hp mean(dY/dX)   8 0.005976243 0.01026863 0.5819901 0.5605733 -0.01414991 0.02610239  2.960878     2.961031     2.960878

plot_slopes(mod, variables = "hp", condition = "cyl", draw = FALSE)
#>   rowid     type term contrast    estimate   std.error statistic   p.value    conf.low  conf.high predicted predicted_hi predicted_lo   gear       hp      am cyl
#> 1     1 response   hp    dY/dX 0.015573112 0.035739301 0.4357419 0.6630240 -0.05447463 0.08562085  4.155098     4.155539     4.155098 3.6875 146.6875 0.40625   6
#> 2     2 response   hp    dY/dX 0.006766503 0.034457542 0.1963722 0.8443189 -0.06076904 0.07430204  4.223839     4.224030     4.223839 3.6875 146.6875 0.40625   4
#> 3     3 response   hp    dY/dX 0.005562088 0.008215769 0.6770015 0.4984050 -0.01054052 0.02166470  3.058014     3.058171     3.058014 3.6875 146.6875 0.40625   8

My mental model of plot_*() was that plot_*(..., draw = FALSE) was equal to avg_*(...)... I now understand this is not the case.

However, this also means that the error produced by avg_*() |> plot() is somewhat misleading:

avg_slopes(mod, variables = "hp", 
           newdata = datagridcf(cyl = unique),
           by = "cyl") |> 
  plot()
#> Error: Please use the `plot_slopes()` function.

mattansb commented 1 year ago

(As you can understand, I've started using marginaleffects in my consultation work, and I'm still building my mental model of the pipeline!)

vincentarelbundock commented 1 year ago

OK, I thought about this some more. I don’t want to implement a variables argument in plot_predictions() because it would make the interface confusing. However, I just pushed a commit to GitHub which adds a newdata argument to all three of the plotting functions. This makes it easy to achieve what you wanted:

library(marginaleffects)

mod <- glm(gear ~ hp * factor(cyl) + am, family = poisson("log"), data = mtcars)

avg_predictions(mod, newdata = datagridcf(cyl = unique), by = "cyl")
# 
#  cyl Estimate Pr(>|z|) 2.5 % 97.5 %
#    4     4.22  0.00383  1.59  11.21
#    6     4.16  < 0.001  2.49   6.93
#    8     3.06  < 0.001  1.74   5.37
# 
# Columns: cyl, estimate, p.value, conf.low, conf.high

plot_predictions(mod, newdata = datagridcf(cyl = unique), by = "cyl", draw = FALSE)
#   cyl estimate      p.value conf.low conf.high
# 1   4 4.223839 3.825857e-03 1.591032 11.213357
# 2   6 4.155098 4.648110e-08 2.492896  6.925616
# 3   8 3.058014 1.021792e-04 1.740043  5.374260

plot_predictions(mod, newdata = datagridcf(cyl = unique), by = "cyl")

mattansb commented 1 year ago

Amazing - thanks for the explanations and the changes in docs and code. Very much appreciated!