vdorie / stan4bart

Uses Stan sampler and math library to semiparametrically fit linear and multilevel models with additive Bayesian Additive Regression Tree (BART) components.

This package is an implementation of a C++ sampler that uses BART for non-parametric mean components and Stan for multilevel/parametric ones.

Installation

  1. Install the developer tools for your platform (Mac OS, Windows). Mac OS users will also need the gfortran build that matches their processor:
    • If on an ARM processor, install gfortran from here
    • If on an Intel processor, install gfortran from here
  2. Execute:
if (length(find.package("remotes", quiet = TRUE)) == 0L)
  install.packages("remotes")
remotes::install_github("vdorie/dbarts")
remotes::install_github("vdorie/stan4bart")
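
To confirm that the installation succeeded, a minimal check along the following lines can be run in a fresh R session. It uses only base R and utils functions; the messages are illustrative and not part of the package.

```r
# TRUE if the package namespace can be loaded; this does not attach it
installed <- requireNamespace("stan4bart", quietly = TRUE)

if (installed) {
  message("stan4bart version: ",
          as.character(utils::packageVersion("stan4bart")))
} else {
  message("stan4bart is not installed; rerun the installation commands")
}
```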

Use

The package uses the flexible, expressive lme4 syntax for specifying group-level structures. See the package documentation ?stan4bart and ?"stan4bart-generics" for more information.

Formulas
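
As a rough illustration of how such a formula is laid out, the sketch below uses only base R to split a formula of the kind used in the example code into its terms. The variable names (x1, x2, x3, g) are hypothetical. Terms wrapped in bart() go to the nonparametric component, bare terms are fit parametrically, and terms containing | follow lme4's grouping notation. The code only inspects the formula; it does not call stan4bart and does not reflect how the package parses formulas internally.

```r
# Hypothetical formula: BART on x1 and x2, a linear term in x3, and a
# random intercept and random slope on x3 for each level of g
f <- y ~ bart(x1 + x2) + x3 + (1 + x3 | g)

# Base R keeps bart(...) and the grouping term as opaque term labels
term_labels <- attr(stats::terms(f), "term.labels")

is_bart  <- grepl("^bart\\(", term_labels)
is_group <- grepl("|", term_labels, fixed = TRUE)

components <- list(
  bart       = term_labels[is_bart],
  group      = term_labels[is_group],
  parametric = term_labels[!is_bart & !is_group]
)
str(components)
```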

Main Event

The main function is stan4bart.

Results

Results are retrieved using the extract, fitted, and predict generics. See ?"stan4bart-generics" for more information.

Known Issues

The name and definition of extract conflict with those in rstan. The rstan package is not needed to use stan4bart and does not need to be loaded. If a name collision occurs, the stan4bart version of extract can be called directly:

stan4bart:::extract.stan4bartFit(stan4bart_fit)
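
To see whether a collision is present in the current session, one option is to list which attached packages define extract. The sketch below assumes nothing beyond base R and utils. The object fit is hypothetical, so the call that uses it is guarded and is skipped when no such object exists.

```r
# Attached packages that define a function named `extract`, in search-path
# order; a bare extract() call at the prompt resolves to the first entry
providers <- utils::find("extract", mode = "function")
print(providers)

# `fit` is a hypothetical fitted model object; qualifying the generic with
# the namespace avoids whichever package happens to mask it
if (exists("fit") && inherits(fit, "stan4bartFit")) {
  samples <- stan4bart::extract(fit)
}
```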

Example Code

library(stan4bart)

# Load a test-data function
source(system.file("common", "friedmanData.R", package = "stan4bart"), local = TRUE)

# Relatively low n for illustrative purposes
testData <- generateFriedmanData(n = 100, ranef = TRUE, causal = TRUE, binary = FALSE)

# First level model is:
#   y ~ f(x_1, x_2) + a * x_3^2 + b * x_4 + c * x_5 + z
# Random intercepts are added for g.1 and g.2, and a random slope is placed on x_4
# x_6 through x_10 are pure noise
df <- with(testData, data.frame(x, g.1, g.2, y, z))

# Causal inference example
fit <- stan4bart(y ~ bart(. - g.1 - g.2 - X4 - z) + X4 + z + (1 + X4 | g.1) + (1 | g.2), df,
                 treatment = z,
                 cores = 1, seed = 0,
                 verbose = 1)

samples.mu.train <- extract(fit)
samples.mu.test  <- extract(fit, sample = "test")

# Individual conditional treatment effects
samples.icate <- (samples.mu.train - samples.mu.test) * (2 * testData$z - 1)
# Conditional average treatment effect
samples.cate <- apply(samples.icate, 2, mean)
cate <- mean(samples.cate)
cate.int <- c(cate - 1.96 * sd(samples.cate), cate + 1.96 * sd(samples.cate))

# Samples of the posterior predictive distribution are used in calculating
# the counterfactuals for SATE and for calculating the response under
# the observed treatment condition when estimating PATE.
samples.ppd.test <- extract(fit, type = "ppd", sample = "test")

# Individual sample treatment effects
samples.ite <- (testData$y - samples.ppd.test) * (2 * testData$z - 1)
# Sample average treatment effect
samples.sate <- apply(samples.ite, 2, mean)
sate <- mean(samples.sate)
sate.int <- c(sate - 1.96 * sd(samples.sate), sate + 1.96 * sd(samples.sate))

# Population average treatment effect
samples.ppd.train <- extract(fit, type = "ppd", sample = "train")
samples.pate <- apply((samples.ppd.train - samples.ppd.test) * (2 * testData$z - 1), 2, mean)
pate <- mean(samples.pate)
pate.int <- c(pate - 1.96 * sd(samples.pate), pate + 1.96 * sd(samples.pate))

fitted.mu.train <- fitted(fit)
# equal to: apply(samples.mu.train, 1, mean)
fitted.mu.test  <- fitted(fit, sample = "test")
# equal to: apply(samples.mu.test,  1, mean)

# Observed and counterfactual MSE
mse.train <- with(testData, mean((fitted.mu.train - mu.1 * z - mu.0 * (1 - z))^2))
mse.test  <- with(testData, mean((fitted.mu.test  - mu.1 * (1 - z) - mu.0 * z)^2))
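
The sign factor (2 * z - 1) used in the treatment-effect calculations above can be checked on a toy example. The sketch below replaces the extract() output with small hand-written matrices, so all values and the toy.* names are invented for illustration. It assumes the same layout as the example: rows are observations and columns are posterior samples.

```r
# Treatment indicator for four hypothetical observations
z <- c(1, 0, 1, 0)

# Stand-ins for samples under the observed and the counterfactual treatment
toy.mu.obs <- matrix(c(3.0, 1.0, 4.0, 2.0,
                       3.2, 0.8, 4.1, 2.1), nrow = 4)
toy.mu.cf  <- matrix(c(1.0, 2.0, 2.0, 3.0,
                       1.2, 2.0, 2.1, 3.1), nrow = 4)

# obs - cf is mu.1 - mu.0 for treated rows and mu.0 - mu.1 for controls;
# multiplying by +1 / -1 turns every row into mu.1 - mu.0. The length-n
# vector recycles down each column, so it lines up with the rows.
toy.icate <- (toy.mu.obs - toy.mu.cf) * (2 * z - 1)

# One average effect per posterior sample, then a normal-approximation interval
toy.cate.samples <- colMeans(toy.icate)  # same as apply(toy.icate, 2, mean)
toy.cate <- mean(toy.cate.samples)
toy.cate.int <- toy.cate + c(-1.96, 1.96) * sd(toy.cate.samples)

# Averaging over samples gives a posterior mean per observation
toy.fitted <- rowMeans(toy.mu.obs)       # same as apply(toy.mu.obs, 1, mean)
```

With only two samples the interval is not meaningful; the point is the orientation of the matrices and the sign flip for control rows.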