StochasticTree / stochtree

Stochastic tree ensembles (BART / XBART) for supervised learning and causal inference

Credible intervals with Bayesian Causal Forest in Python #72

d-vct closed this issue 3 months ago

d-vct commented 3 months ago

Hello! I am using the Bayesian Causal Forest method to perform causal inference. I was wondering whether it is possible to retrieve credible intervals for the CATE and the potential outcomes. It seems that only point estimates are available. Thank you!

andrewherren commented 3 months ago

Hi there, thanks for reaching out!

The sample() method of the BCFModel object returns posterior draws of the CATE function ($\tau(X)$ in the BCF paper notation), accessible as tau_hat_train (and tau_hat_test if you provided a test set when calling sample). These posterior draws can be used to compute point estimates such as a posterior mean, but also a credible interval.

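As a quick demo, suppose you've fit a BCFModel as in the Python causal vignette:

from stochtree import BCFModel
bcf_model = BCFModel()
# Passing the test-set arguments is what makes tau_hat_test available afterwards
bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test)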

You can compute the posterior mean of the CATE for each observation in the train and test sets by averaging over the posterior draws:

import numpy as np
tau_hat_train_mean = np.mean(bcf_model.tau_hat_train, axis=1)
tau_hat_test_mean = np.mean(bcf_model.tau_hat_test, axis=1)

Similarly, you can compute the 2.5th and 97.5th percentiles, which define an equal-tailed 95% credible interval, as follows:

tau_hat_train_ci_lb = np.percentile(bcf_model.tau_hat_train, 2.5, axis=1)
tau_hat_test_ci_lb = np.percentile(bcf_model.tau_hat_test, 2.5, axis=1)
tau_hat_train_ci_ub = np.percentile(bcf_model.tau_hat_train, 97.5, axis=1)
tau_hat_test_ci_ub = np.percentile(bcf_model.tau_hat_test, 97.5, axis=1)
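
The same pattern can be wrapped in a small helper. This is only a minimal sketch: `summarize_draws` is a hypothetical function, not part of stochtree, and it assumes the draws are stored with one row per observation and one column per posterior draw, as the `axis=1` calls above imply. The array below is synthetic and simply stands in for `bcf_model.tau_hat_train`.

```python
import numpy as np


def summarize_draws(draws, level=0.95):
    """Posterior mean and equal-tailed credible interval for each row.

    Assumes `draws` has shape (n_obs, n_draws). Hypothetical helper,
    not a stochtree function.
    """
    draws = np.asarray(draws, dtype=float)
    if draws.ndim != 2:
        raise ValueError("expected a 2-D array of shape (n_obs, n_draws)")
    alpha = 1.0 - level
    mean = draws.mean(axis=1)
    lower = np.percentile(draws, 100 * alpha / 2, axis=1)
    upper = np.percentile(draws, 100 * (1 - alpha / 2), axis=1)
    return mean, lower, upper


# Synthetic stand-in for posterior draws: 5 observations, 1000 draws each
rng = np.random.default_rng(0)
fake_tau_draws = rng.normal(loc=1.0, scale=0.5, size=(5, 1000))

tau_mean, tau_lb, tau_ub = summarize_draws(fake_tau_draws)
```

With a fitted model, one would pass `bcf_model.tau_hat_train` or `bcf_model.tau_hat_test` in place of `fake_tau_draws`; changing `level` gives intervals other than 95%.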

d-vct commented 3 months ago

Thank you for your quick response! And thank you for the Python implementation, it's very useful.