vincentarelbundock / marginaleffects

R package to compute and plot predictions, slopes, marginal means, and comparisons (contrasts, risk ratios, odds, etc.) for over 100 classes of statistical and ML models. Conduct linear and non-linear hypothesis tests, or equivalence tests. Calculate uncertainty estimates using the delta method, bootstrapping, or simulation-based inference
https://marginaleffects.com

Multiple imputation in `inferences()` #669

Closed: vincentarelbundock closed this issue 1 year ago

vincentarelbundock commented 1 year ago

@ngreifer, I was reading your excellent new blog post on missing data, and it made me realize how clunky the workflow is for multiple imputation. I also realized that, by happy accident, it is almost trivial to use the existing `inferences()` infrastructure to make it much easier to pool estimates.

I updated the development version of the vignette: https://vincentarelbundock.github.io/marginaleffects/dev/articles/multiple_imputation.html

Here’s a quick example using version 0.9.0.9037:

```r
library(marginaleffects)

# insert missing data
dat <- iris
dat$Sepal.Length[sample(seq_len(nrow(iris)), 40)] <- NA
dat$Sepal.Width[sample(seq_len(nrow(iris)), 40)] <- NA
dat$Species[sample(seq_len(nrow(iris)), 40)] <- NA

# impute (note: the argument is `seed`, not `.Random.seed`)
dat_mice <- mice::mice(dat, m = 20, printFlag = FALSE, seed = 1024)

# fit
mod <- lm(Petal.Width ~ Sepal.Length * Sepal.Width * Species, data = dat)

# listwise deletion
avg_slopes(mod, by = "Species")
# 
#          Term                        Contrast    Species Estimate Std. Error       z   Pr(>|z|)    2.5 % 97.5 %
#  Sepal.Length                     mean(dY/dX)     setosa  0.14567    0.19046  0.7648 0.44437148 -0.22763 0.5190
#  Sepal.Length                     mean(dY/dX) versicolor  0.12111    0.09525  1.2716 0.20353024 -0.06557 0.3078
#  Sepal.Length                     mean(dY/dX)  virginica  0.02187    0.09787  0.2235 0.82315181 -0.16995 0.2137
#   Sepal.Width                     mean(dY/dX)     setosa -0.04798    0.20792 -0.2308 0.81750137 -0.45550 0.3595
#   Sepal.Width                     mean(dY/dX) versicolor  0.23347    0.17512  1.3332 0.18248162 -0.10977 0.5767
#   Sepal.Width                     mean(dY/dX)  virginica  0.48254    0.18387  2.6243 0.00868274  0.12215 0.8429
#       Species mean(versicolor) - mean(setosa)     setosa  1.53486    0.31466  4.8779 1.0723e-06  0.91814 2.1516
#       Species mean(versicolor) - mean(setosa) versicolor  0.90883    0.42689  2.1289 0.03325849  0.07214 1.7455
#       Species mean(versicolor) - mean(setosa)  virginica  0.84863    0.50299  1.6872 0.09157188 -0.13722 1.8345
#       Species  mean(virginica) - mean(setosa)     setosa  2.18525    0.30076  7.2658 3.7094e-13  1.59577 2.7747
#       Species  mean(virginica) - mean(setosa) versicolor  1.43879    0.42865  3.3566 0.00078907  0.59866 2.2789
#       Species  mean(virginica) - mean(setosa)  virginica  1.44311    0.49711  2.9030 0.00369601  0.46879 2.4174
# 
# Prediction type:  response 
# Columns: type, term, contrast, Species, estimate, std.error, statistic, p.value, conf.low, conf.high, predicted, predicted_hi, predicted_lo

# multiple imputation
avg_slopes(mod, by = "Species") |> inferences(method = "mi", midata = dat_mice)
# 
#          Term                        Contrast    Species Estimate Std. Error       t   Pr(>|t|)    2.5 % 97.5 %
#  Sepal.Length                     mean(dY/dX)     setosa  0.16202    0.13790  1.1749 0.24592788 -0.11536 0.4394
#  Sepal.Length                     mean(dY/dX) versicolor  0.14456    0.06439  2.2450 0.02536017  0.01794 0.2712
#  Sepal.Length                     mean(dY/dX)  virginica  0.02698    0.06027  0.4476 0.65515420 -0.09220 0.1462
#   Sepal.Width                     mean(dY/dX)     setosa -0.01850    0.14438 -0.1282 0.89850036 -0.30810 0.2711
#   Sepal.Width                     mean(dY/dX) versicolor  0.33822    0.12109  2.7933 0.00558605  0.09985 0.5766
#   Sepal.Width                     mean(dY/dX)  virginica  0.48448    0.11569  4.1877 4.7836e-05  0.25589 0.7131
#       Species mean(versicolor) - mean(setosa)     setosa  1.31352    0.20033  6.5566 8.1499e-10  0.91771 1.7093
#       Species mean(versicolor) - mean(setosa) versicolor  0.84675    0.34144  2.4799 0.01817865  0.15321 1.5403
#       Species mean(versicolor) - mean(setosa)  virginica  0.83746    0.45829  1.8273 0.07649395 -0.09413 1.7690
#       Species  mean(virginica) - mean(setosa)     setosa  2.13301    0.20706 10.3015 < 2.22e-16  1.72158 2.5444
#       Species  mean(virginica) - mean(setosa) versicolor  1.39801    0.34483  4.0542 0.00026866  0.69783 2.0982
#       Species  mean(virginica) - mean(setosa)  virginica  1.39592    0.44737  3.1203 0.00367125  0.48676 2.3051
# 
# Prediction type:  response 
# Columns: type, term, contrast, Species, estimate, std.error, predicted, predicted_hi, predicted_lo, df, statistic, p.value, conf.low, conf.high
```
ngreifer commented 1 year ago

Thank you for the kind words :)

I don't know if I love that interface. Basically, it looks like `mod` is just a placeholder containing the original function call, and `dat_mice` is inserted into the `data` argument of that call. I can see how the bootstrapping infrastructure would make that easy, but I don't like that the user has to fit a fake model just so that `inferences()` can extract its call. Ideally, the user only fits the models that are actually used to estimate quantities; for example, there may be no complete cases on which to fit the initial model.

I also don't like `method = "mi"`, since pooling with Rubin's rules is not a mode of inference; it is a method applied on top of other modes (e.g., in theory, you could bootstrap in each imputed dataset and then combine, or do simulation-based inference in each model and combine).

I think an alternative would be to let the functions accept a list of model fits or a `mira` object as an argument. That would trigger a multiple imputation process that applies the estimation (i.e., `slopes()`, etc.) to each model/dataset and either automatically combines the results or requires a `pool()` call at the end that is smarter about how the estimates are combined than it currently is. Really, it is only that last step that is holding things back; letting users supply a list of model fits would just save them from having to use `lapply()`.
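The manual workflow described above might look roughly like this (an untested sketch, reusing the `dat` object with injected missingness from the example earlier in the thread):

```r
library(mice)
library(marginaleffects)

# impute, then fit one model per completed dataset
imp <- mice::mice(dat, m = 20, printFlag = FALSE, seed = 1024)
fits <- lapply(seq_len(imp$m), function(i) {
  lm(Petal.Width ~ Sepal.Length * Sepal.Width * Species,
     data = mice::complete(imp, action = i))
})

# apply the estimation to each fit ...
est <- lapply(fits, avg_slopes, by = "Species")

# ... then combine with Rubin's rules. This last step is the
# one that needs a pool() smarter than the current one, since
# marginaleffects output can repeat term names across contrasts
# and groups.
```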

vincentarelbundock commented 1 year ago

Thanks for these thoughts. I agree.

marginaleffects now supports `mira` objects, and the dev version of the vignette shows how to convert Amelia and missRanger datasets to `mira` format:

https://vincentarelbundock.github.io/marginaleffects/dev/articles/multiple_imputation.html
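With `mira` support in place, the workflow presumably reduces to something like the following (a sketch based on the vignette's description, not run here; `with()` on a `mids` object returns a `mira` object):

```r
library(mice)
library(marginaleffects)

# impute, then fit one model per completed dataset;
# with() returns a mira object
imp <- mice::mice(dat, m = 20, printFlag = FALSE, seed = 1024)
mod <- with(imp, lm(Petal.Width ~ Sepal.Length * Sepal.Width * Species))

# the estimation is applied to each fit and the results are
# pooled with Rubin's rules
avg_slopes(mod, by = "Species")

# models fit outside of mice (e.g., on Amelia or missRanger
# imputations) can be coerced to mira with mice::as.mira(fits)
```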