vincentarelbundock / marginaleffects

R package to compute and plot predictions, slopes, marginal means, and comparisons (contrasts, risk ratios, odds, etc.) for over 100 classes of statistical and ML models. Conduct linear and non-linear hypothesis tests, or equivalence tests. Calculate uncertainty estimates using the delta method, bootstrapping, or simulation-based inference.
https://marginaleffects.com

Support BART (`dbarts`, `BART`) #940

Closed: ngreifer closed this issue 1 year ago

ngreifer commented 1 year ago

BART is a Bayesian machine learning method that produces a posterior distribution for each prediction. It is widely considered a gold-standard method for causal inference. It is implemented in the `dbarts` and `BART` packages. I am more familiar with `dbarts`, which has a `predict()` method for `bart` objects. I don't really understand how Bayesian models are processed in marginaleffects, or I'd submit this as a pull request. It should rely on the usual machinery for Bayesian models, but I think you will need to tell the functions explicitly that this model is Bayesian. This would be a major boon to the causal inference world, effectively eliminating the need for the bartCause package, which provides a walled garden for similar analyses.
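
Roughly, the "usual machinery for Bayesian models" means computing the quantity of interest once per posterior draw and then summarizing the resulting distribution. Below is a minimal base-R sketch of that idea, assuming a simulated observations × draws matrix stands in for the output of a model's `predict()` method. The names `pred_lo`, `pred_hi`, and `avg_draws` are hypothetical, and this is an illustration of the general approach, not marginaleffects' internal code. As far as I know, marginaleffects uses the posterior median as the point estimate by default (see the `marginaleffects_posterior_center` option used later in this thread).

```r
# Simulated stand-in for posterior predictions: rows = observations,
# columns = posterior draws. In practice these would come from the
# model's predict() method.
set.seed(1)
n_obs <- 50
n_draws <- 1000
pred_lo <- matrix(rnorm(n_obs * n_draws, mean = 10), nrow = n_obs)  # treat = 0
pred_hi <- matrix(rnorm(n_obs * n_draws, mean = 12), nrow = n_obs)  # treat = 1

# 1. Unit-level contrast, computed separately for every draw.
contrast_draws <- pred_hi - pred_lo

# 2. Average over observations within each draw: one value per draw,
#    i.e. draws from the posterior of the average contrast.
avg_draws <- colMeans(contrast_draws)

# 3. Summarize that posterior: median as point estimate,
#    2.5% and 97.5% quantiles as an equal-tailed 95% interval.
estimate <- median(avg_draws)
interval <- quantile(avg_draws, probs = c(0.025, 0.975))
c(estimate = estimate, interval)
```

The key design point is that the averaging over units happens inside each draw, so the uncertainty interval reflects the posterior of the average, not the spread of individual predictions.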

vincentarelbundock commented 1 year ago

I merged support for dbarts.

I'm not 100% sure, but it looks like `BART` only accepts `X` and `y` matrices as input and doesn't have a formula interface. That's a big problem for us, because we always build the contrast datasets as data frames and then rely on `predict()` to build the model matrices internally. I think it would be very hard to support models with this type of interface: `fit(X, y)`.
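
To see why the formula interface matters: marginaleffects builds counterfactual data frames and needs `predict()` to turn them into model matrices. A matrix-only model can in principle be wrapped so that it stores the `terms` object and rebuilds `X` at prediction time. The base-R sketch below assumes `lm.fit()` as a stand-in for a matrix-only fitter such as the one in `BART`. The functions `fit_xy()`, `fit_formula()` and the class `"xyfit"` are hypothetical illustrations, not part of marginaleffects, `BART`, or `dbarts`.

```r
# Hypothetical matrix-only fitter: fit_xy(X, y). lm.fit() stands in for a
# model that only accepts a design matrix and an outcome vector.
fit_xy <- function(X, y) lm.fit(X, y)$coefficients

# Hypothetical formula front end that keeps what predict() needs to rebuild X.
fit_formula <- function(formula, data) {
  mf <- model.frame(formula, data)
  tt <- terms(mf)
  X <- model.matrix(tt, mf)
  structure(
    list(terms = tt,
         xlevels = .getXlevels(tt, mf),
         coef = fit_xy(X, model.response(mf))),
    class = "xyfit"
  )
}

# predict() rebuilds the model matrix from a data frame, so counterfactual
# datasets (e.g. everyone treated vs. everyone untreated) work as usual.
predict.xyfit <- function(object, newdata, ...) {
  X <- model.matrix(delete.response(object$terms), newdata,
                    xlev = object$xlevels)
  drop(X %*% object$coef)
}

fit <- fit_formula(mpg ~ am + wt, data = mtcars)
p_hi <- predict(fit, newdata = transform(mtcars, am = 1))
p_lo <- predict(fit, newdata = transform(mtcars, am = 0))
mean(p_hi - p_lo)  # equals the `am` coefficient in this additive model
```

The catch for a Bayesian model is that such a wrapper would also have to return all posterior draws from `predict()`, not just a point prediction.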

A workaround is to go through tidymodels. I believe they support BART, and they also offer a formula front end to fit models, which allows us to operate as usual. The main downside is that I'm not sure how to get all the draws from the predictions, so we can't report uncertainty.

ngreifer commented 1 year ago

Thanks, I think supporting `dbarts` is enough. However, I am seeing something I don't understand:

```r
data("lalonde", package = "MatchIt")
library(dbarts); library(marginaleffects)

fit <- bart2(re78 ~ treat + age + educ + race + married + nodegree + re74 + re75,
             data = lalonde, keepTrees = TRUE, verbose = FALSE)

avg_comparisons(fit, variables = "treat", by = "treat")
#> 
#>   Term          Contrast treat Estimate 2.5 % 97.5 %
#>  treat mean(1) - mean(0)     1     1002 -1756   3583
#>  treat mean(1) - mean(0)     0     1513  -109   3149
#> 
#> Columns: term, contrast, treat, estimate, conf.low, conf.high, predicted_lo, predicted_hi, predicted, tmp_idx 
#> Type:  ev
avg_comparisons(fit, variables = "treat", newdata = subset(lalonde, treat == 0))
#> Warning: The `treat` variable is treated as a categorical (factor) variable, but
#>   the original data is of class integer. It is safer and faster to convert
#>   such variables to factor before fitting the model and calling `slopes`
#>   functions.
#>   
#>   This warning appears once per session.
#> 
#>   Term Contrast Estimate 2.5 % 97.5 %
#>  treat    1 - 0     1300  -709   3412
#> 
#> Columns: term, contrast, estimate, conf.low, conf.high 
#> Type:  ev
avg_comparisons(fit, variables = "treat", newdata = subset(lalonde, treat == 1))
#> 
#>   Term Contrast Estimate 2.5 % 97.5 %
#>  treat    1 - 0     1393  -185   3003
#> 
#> Columns: term, contrast, estimate, conf.low, conf.high 
#> Type:  ev

fit <- lm(re78 ~ treat * (age + educ + race + married + nodegree + re74 + re75),
          data = lalonde)

avg_comparisons(fit, variables = "treat", by = "treat")
#> 
#>   Term          Contrast treat Estimate Std. Error     z Pr(>|z|)   S   2.5 %
#>  treat mean(1) - mean(0)     0      828       1317 0.629   0.5295 0.9 -1753.0
#>  treat mean(1) - mean(0)     1     1648        821 2.008   0.0447 4.5    39.3
#>  97.5 %
#>    3409
#>    3256
#> 
#> Columns: term, contrast, treat, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted 
#> Type:  response
avg_comparisons(fit, variables = "treat", newdata = subset(lalonde, treat == 0))
#> 
#>   Term Contrast Estimate Std. Error     z Pr(>|z|)   S 2.5 % 97.5 %
#>  treat    1 - 0      828       1317 0.629     0.53 0.9 -1753   3409
#> 
#> Columns: term, contrast, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high 
#> Type:  response
avg_comparisons(fit, variables = "treat", newdata = subset(lalonde, treat == 1))
#> 
#>   Term Contrast Estimate Std. Error    z Pr(>|z|)   S 2.5 % 97.5 %
#>  treat    1 - 0     1648        821 2.01   0.0447 4.5  39.3   3256
#> 
#> Columns: term, contrast, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high 
#> Type:  response
```

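Created on 2023-10-20 with reprex v2.0.2

Why is averaging with `by` not the same as manually supplying `newdata` subsetted to each level of the `by` variable? With `lm()`, both approaches give identical estimates (828 for `treat = 0` and 1648 for `treat = 1`), but with BART they differ (1513 vs. 1300 for `treat = 0`, and 1002 vs. 1393 for `treat = 1`).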
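
The expectation behind the question can be stated as an identity: computing unit-level contrasts on the full data and then averaging within each level of the `by` variable should give the same numbers as passing each subset as `newdata`. A minimal base-R check with `lm()` on the built-in `mtcars` data, where `am` plays the role of `treat`; the dataset, model, and the helper `unit_contrast()` are stand-ins chosen for illustration, not the lalonde example above.

```r
# Linear model with a treatment-by-covariate interaction, as in the lm() example.
fit <- lm(mpg ~ am * (wt + hp), data = mtcars)

# Unit-level contrasts for every row: predicted(am = 1) - predicted(am = 0).
unit_contrast <- function(model, data) {
  predict(model, newdata = transform(data, am = 1)) -
    predict(model, newdata = transform(data, am = 0))
}

# Approach 1 ("by"): compute on the full data, then average within groups.
by_est <- tapply(unit_contrast(fit, mtcars), mtcars$am, mean)

# Approach 2 ("newdata"): subset the data first, then average.
sub_est <- sapply(c(`0` = 0, `1` = 1), function(g) {
  mean(unit_contrast(fit, subset(mtcars, am == g)))
})

all.equal(as.vector(by_est), as.vector(sub_est))  # TRUE
```

The same identity should hold for a Bayesian model draw by draw, so a mismatch like the one reported above points to a bookkeeping problem rather than a modeling one.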

vincentarelbundock commented 1 year ago

Good question. This is weird. I'll look into it later, but I'm leaving some code here to show that the `newdata` subsetting seems to work properly. Note the `options()` call, which sets the posterior center to the mean, to make sure this is not an order-of-operations issue between mean and median.

```r
options(marginaleffects_posterior_center = mean)
data("lalonde", package = "MatchIt")
library(dbarts); library(marginaleffects)

fit <- bart2(re78 ~ treat + age + educ + race + married + nodegree + re74 + re75,
             data = lalonde, keepTrees = TRUE, verbose = FALSE)

# Treated units: manual contrast vs. avg_comparisons() with newdata
p0 <- predict(fit, newdata = transform(subset(lalonde, treat == 1), treat = 0))
p1 <- predict(fit, newdata = transform(subset(lalonde, treat == 1), treat = 1))
mean(p1 - p0)
avg_comparisons(fit, variables = "treat", newdata = subset(lalonde, treat == 1))

# Control units: same comparison
p0 <- predict(fit, newdata = transform(subset(lalonde, treat == 0), treat = 0))
p1 <- predict(fit, newdata = transform(subset(lalonde, treat == 0), treat = 1))
mean(p1 - p0)
avg_comparisons(fit, variables = "treat", newdata = subset(lalonde, treat == 0))

# Same quantities requested through `by`
avg_comparisons(fit, variables = "treat", by = "treat")
```

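
Setting the posterior center to the mean matters for the manual check: `mean(p1 - p0)` averages over all draws and units at once. With the mean as the center, averaging over units first or over draws first gives the same number; with the median, the two orders generally differ. A small base-R illustration with simulated draws (the skewed exponential draws and all variable names are arbitrary choices to make the gap visible):

```r
# Draws matrix: rows = observations, columns = posterior draws (simulated,
# skewed so that mean and median differ noticeably).
set.seed(42)
draws <- matrix(rexp(20 * 500), nrow = 20)

# With the mean, the order of averaging does not matter: both equal the grand mean.
a <- mean(colMeans(draws))  # average over units per draw, then over draws
b <- mean(rowMeans(draws))  # average over draws per unit, then over units
all.equal(a, b)  # TRUE

# With the median as posterior center, the order matters.
med_of_avg <- median(colMeans(draws))        # posterior median of the average
avg_of_med <- mean(apply(draws, 1, median))  # average of unit-level medians
c(med_of_avg, avg_of_med)
```
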
vincentarelbundock commented 1 year ago

Indexing hell.

Should be fixed on GitHub now. Here are the tests, in case you want to run them locally: https://github.com/vincentarelbundock/marginaleffects/blob/main/inst/tinytest/test-pkg-dbarts.R#L30
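
The thread does not say exactly what the indexing bug was. As a generic illustration of this class of problem, assume unit-level contrast draws are stored as an observations × draws matrix: grouped averages are only correct when the `by` labels index the same rows as the draws, which can be checked against explicit subsetting. All data and names below are hypothetical, and this is not the actual fix in marginaleffects.

```r
# Simulated setup: 6 observations in two groups, 4 posterior draws each.
set.seed(7)
group <- c(1, 0, 1, 0, 0, 1)
contrast_draws <- matrix(rnorm(6 * 4), nrow = 6)  # rows aligned with `group`

# Correct: group labels index the same rows as the draws matrix.
# rowsum() sums rows by group (sorted: 0, then 1); dividing by group sizes
# gives per-draw group means.
by_group <- rowsum(contrast_draws, group) / as.vector(table(group))

# Reference: subset the rows for each group, then average per draw.
by_subset <- rbind(
  `0` = colMeans(contrast_draws[group == 0, , drop = FALSE]),
  `1` = colMeans(contrast_draws[group == 1, , drop = FALSE])
)
all.equal(unname(by_group), unname(by_subset))  # TRUE

# A typical slip: labels sorted independently of the rows they describe.
misaligned <- rowsum(contrast_draws, sort(group)) / as.vector(table(group))
isTRUE(all.equal(unname(misaligned), unname(by_subset)))
```

A check of this shape (grouped result vs. subset result) mirrors the comparison ngreifer ran by hand.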