Open ignacio82 opened 4 years ago
Sort of. This will get you draws from the prior of the tree parameters, but if you want the full data distribution you would have to add noise based on sigma, drawn from its prior.
# Mean function for the simulated example.
#
# x: numeric matrix with at least five columns; only columns 1-5 are used.
# Returns a numeric vector with one value per row of x.
f <- function(x) {
  interaction_term <- 10 * sin(pi * x[, 1] * x[, 2])
  quadratic_term <- 20 * (x[, 3] - 0.5)^2
  # Summed left to right, in the same order as a single chained expression.
  interaction_term + quadratic_term + 10 * x[, 4] + 5 * x[, 5]
}
# Simulate a small regression data set: 10 uniform predictors (f() uses the
# first five) plus Gaussian noise with standard deviation sigma.
set.seed(99)
sigma <- 1.0
n <- 100
x <- matrix(runif(n * 10), nrow = n, ncol = 10)
Ey <- f(x)
y <- rnorm(n, mean = Ey, sd = sigma)

# n.chains = 1 just for simplicity
sampler <- dbarts(x, y, control = dbartsControl(n.chains = 1))

# One column of predictions at x for each draw.
n.samples <- 50
prior_samples <- matrix(NA, nrow = n, ncol = n.samples)
for (i in seq_len(n.samples)) {
  # Only sampleTreesFromPrior() is called between predictions.
  sampler$sampleTreesFromPrior()
  prior_samples[, i] <- sampler$predict(x)
}
Thanks a lot for helping, Vincent. I tried to follow your example, but I'm clearly not understanding something: I end up with a vector of zeros :-(
This is the code that I wrote:
library(dplyr)
library(dbarts)
# create some fake data ---------------------------------------------------
set.seed(123)
N <- 120
# rowwise() makes every rnorm(1, ...) call draw a separate value for each row.
my_data <- tibble(
  Z = sample(
    x = c("treatment", "control"),
    size = N,
    replace = TRUE,
    prob = c(0.5, 0.5)
  )
) %>%
  rowwise() %>%
  mutate(
    treatment = case_when(
      Z == "treatment" ~ 1,
      Z == "control" ~ 0
    ),
    X = case_when(
      Z == "treatment" ~ rnorm(1, mean = 40, sd = 10),
      Z == "control" ~ rnorm(1, mean = 20, sd = 10)
    ),
    Y1 = rnorm(1, mean = 90 + exp(0.06 * X), sd = 1),
    Y0 = rnorm(1, mean = 72 + 3 * sqrt(X), sd = 1),
    Y = case_when(
      Z == "treatment" ~ Y1,
      Z == "control" ~ Y0
    ),
    tau = Y1 - Y0
  ) %>%
  ungroup() %>%
  tidyr::drop_na() # remove NAs from negative X's
message(glue::glue("The true Sample Average Treatment Effect is {round(mean(my_data$tau),2)}"))

# Fit the model -----------------------------------------------------------
fit <- bart(
  y.train = my_data$Y,
  x.train = my_data %>% select(X, treatment),
  keeptrees = TRUE,
  verbose = FALSE
)

## Get the posterior distribution for the treatment effect
# Counterfactual design matrices: everyone untreated (x0) / everyone treated (x1).
x0 <- my_data %>% mutate(treatment = 0) %>% select(X, treatment)
x1 <- my_data %>% mutate(treatment = 1) %>% select(X, treatment)
pred0 <- predict(fit, as.matrix(x0))
pred1 <- predict(fit, as.matrix(x1))
tau_posterior_draws <- pred1 - pred0
# Presumably one row per posterior draw, so rowMeans() averages over
# observations within each draw.
SATE_posterior <- rowMeans(tau_posterior_draws)
message(glue::glue("BART finds an Average Treatment Effect equal to {round(mean(SATE_posterior),2)}"))

## Get the prior distribution for the treatment effect
x <- my_data %>% select(X, treatment)
y <- my_data$Y
# n.chains = 1 just for simplicity
sampler_x0 <- dbarts(x0, y, control = dbartsControl(n.chains = 1))
sampler_x1 <- dbarts(x1, y, control = dbartsControl(n.chains = 1))
n.samples <- 1000
# Rows index observations, columns index prior draws.
prior_samples_x0 <- matrix(NA, nrow(my_data), n.samples)
prior_samples_x1 <- matrix(NA, nrow(my_data), n.samples)
prior_samples_tau <- matrix(NA, nrow(my_data), n.samples)
for (i in seq_len(n.samples)) {
  sampler_x0$sampleTreesFromPrior()
  prior_samples_x0[, i] <- sampler_x0$predict(x0)
  sampler_x1$sampleTreesFromPrior()
  prior_samples_x1[, i] <- sampler_x1$predict(x1)
  prior_samples_tau[, i] <- prior_samples_x1[, i] - prior_samples_x0[, i]
}
# colMeans() gives one average treatment effect per prior draw.
SATE_prior <- colMeans(prior_samples_tau) # This is clearly wrong
Would you mind explaining what I am missing?
Thanks!
Apologies, my mistake. I forgot about the node parameters themselves. I've checked in some code that will allow for sampling from their prior as well.
require(dbarts)
# Mean function for the simulated example.
#
# x: numeric matrix with at least five columns; only columns 1-5 are used.
# Returns a numeric vector with one value per row of x.
f <- function(x) {
  interaction_term <- 10 * sin(pi * x[, 1] * x[, 2])
  quadratic_term <- 20 * (x[, 3] - 0.5)^2
  # Summed left to right, in the same order as a single chained expression.
  interaction_term + quadratic_term + 10 * x[, 4] + 5 * x[, 5]
}
# Simulate a small regression data set: 10 uniform predictors (f() uses the
# first five) plus Gaussian noise with standard deviation sigma.
set.seed(99)
sigma <- 1.0
n <- 100
x <- matrix(runif(n * 10), nrow = n, ncol = 10)
Ey <- f(x)
y <- rnorm(n, mean = Ey, sd = sigma)

sampler <- dbarts(x, y, control = dbartsControl(n.chains = 1))

# One column of predictions at x for each prior draw.
n.samples <- 50
prior_samples <- matrix(NA, nrow = n, ncol = n.samples)
for (i in seq_len(n.samples)) {
  # Redraw the tree structures, then the node parameters, before predicting.
  sampler$sampleTreesFromPrior()
  sampler$sampleNodeParametersFromPrior()
  prior_samples[, i] <- sampler$predict(x)
}

# Summary over all observations and draws.
mean(prior_samples)
# [1] 15.83145
sd(prior_samples)
# [1] 7.195763
Thanks a lot for the quick response. Could you confirm that the following code is right? In particular, you mentioned something about having to add draws from sigma, which I did not do; I'm not sure whether I have to, or how to do it.
library(dplyr)
library(dbarts)
# create some fake data ---------------------------------------------------
set.seed(123)
N <- 120
# rowwise() makes every rnorm(1, ...) call draw a separate value for each row.
my_data <- tibble(
  Z = sample(
    x = c("treatment", "control"),
    size = N,
    replace = TRUE,
    prob = c(0.5, 0.5)
  )
) %>%
  rowwise() %>%
  mutate(
    treatment = case_when(
      Z == "treatment" ~ 1,
      Z == "control" ~ 0
    ),
    X = case_when(
      Z == "treatment" ~ rnorm(1, mean = 40, sd = 10),
      Z == "control" ~ rnorm(1, mean = 20, sd = 10)
    ),
    Y1 = rnorm(1, mean = 90 + exp(0.06 * X), sd = 1),
    Y0 = rnorm(1, mean = 72 + 3 * sqrt(X), sd = 1),
    Y = case_when(
      Z == "treatment" ~ Y1,
      Z == "control" ~ Y0
    ),
    tau = Y1 - Y0
  ) %>%
  ungroup() %>%
  tidyr::drop_na() # remove NAs from negative X's
#> Warning in sqrt(X): NaNs produced
#> Warning in rnorm(1, mean = 72 + 3 * sqrt(X), sd = 1): NAs produced
message(glue::glue("The true Sample Average Treatment Effect is {round(mean(my_data$tau),2)}"))
#> The true Sample Average Treatment Effect is 10.43

# Fit the model -----------------------------------------------------------
fit <- bart(
  y.train = my_data$Y,
  x.train = my_data %>% select(X, treatment),
  keeptrees = TRUE,
  verbose = FALSE
)

## Get the posterior distribution for the treatment effect
# Counterfactual design matrices: everyone untreated (x0) / everyone treated (x1).
x0 <- my_data %>% mutate(treatment = 0) %>% select(X, treatment)
x1 <- my_data %>% mutate(treatment = 1) %>% select(X, treatment)
pred0 <- predict(fit, as.matrix(x0))
pred1 <- predict(fit, as.matrix(x1))
tau_posterior_draws <- pred1 - pred0
# Presumably one row per posterior draw, so rowMeans() averages over
# observations within each draw.
SATE_posterior <- rowMeans(tau_posterior_draws)
message(glue::glue("BART finds an Average Treatment Effect equal to {round(mean(SATE_posterior),2)}"))
#> BART finds an Average Treatment Effect equal to 9.21

## Get the prior distribution for the treatment effect
x <- my_data %>% select(X, treatment)
y <- my_data$Y
# n.chains = 1 just for simplicity
# The treatment column is constant within x0 and within x1.
sampler_x0 <- dbarts(x0, y, control = dbartsControl(n.chains = 1))
sampler_x1 <- dbarts(x1, y, control = dbartsControl(n.chains = 1))
n.samples <- 1000
# Rows index observations, columns index prior draws.
prior_samples_x0 <- matrix(NA, nrow(my_data), n.samples)
prior_samples_x1 <- matrix(NA, nrow(my_data), n.samples)
prior_samples_tau <- matrix(NA, nrow(my_data), n.samples)
for (i in seq_len(n.samples)) {
  sampler_x0$sampleTreesFromPrior()
  sampler_x0$sampleNodeParametersFromPrior()
  prior_samples_x0[, i] <- sampler_x0$predict(x0)
  sampler_x1$sampleTreesFromPrior()
  sampler_x1$sampleNodeParametersFromPrior()
  prior_samples_x1[, i] <- sampler_x1$predict(x1)
  prior_samples_tau[, i] <- prior_samples_x1[, i] - prior_samples_x0[, i]
}
# colMeans() gives one average treatment effect per prior draw.
SATE_prior <- colMeans(prior_samples_tau) # is this right?
message(glue::glue("Our prior for the Average Treatment Effect is {round(mean(SATE_prior),2)}"))
#> Our prior for the Average Treatment Effect is -0.12
# vizdraws::vizdraws(prior = SATE_prior, posterior = SATE_posterior, MME = 9)
Created on 2020-07-26 by the reprex package (v0.3.0)
Sorry, but I've been really falling behind on email these days. The only issue I see is that when you create the sampler with treatment variables that are full 0s or 1s, they get dropped. Currently, the most elegant way of getting the sampler to keep cut points (and hence, tree structures) for the treatment variable is to set it after that sampler has been created and initialized. Something like:
# Build both samplers from the observed data, where "treatment" still varies,
# and only afterwards overwrite that column with the counterfactual value.
train_data <- dbartsData(my_data %>% select(X, treatment), y)
sampler_x0 <- dbarts(train_data, control = dbartsControl(n.chains = 1))
sampler_x0$setPredictor(rep_len(0, length(y)), "treatment")
# The response is already inside train_data, so it is not passed again.
sampler_x1 <- dbarts(train_data, control = dbartsControl(n.chains = 1))
# Set treatment = 1 on sampler_x1 (not on sampler_x0 a second time).
sampler_x1$setPredictor(rep_len(1, length(y)), "treatment")
Everything else looks good.
Thanks a lot Vincent. I'm leaving the code with the change you suggested in case it is useful for someone else in the future. I think it would make an interesting and useful vignette for the package.
Final question: I noticed that in this example the prior is not centered at zero. Is this just because `n.samples <- 1000`?
library(dplyr)
library(dbarts)
# create some fake data ---------------------------------------------------
set.seed(123)
N <- 120
# rowwise() makes every rnorm(1, ...) call draw a separate value for each row.
my_data <- tibble(
  Z = sample(
    x = c("treatment", "control"),
    size = N,
    replace = TRUE,
    prob = c(0.5, 0.5)
  )
) %>%
  rowwise() %>%
  mutate(
    treatment = case_when(
      Z == "treatment" ~ 1,
      Z == "control" ~ 0
    ),
    X = case_when(
      Z == "treatment" ~ rnorm(1, mean = 40, sd = 10),
      Z == "control" ~ rnorm(1, mean = 20, sd = 10)
    ),
    Y1 = rnorm(1, mean = 90 + exp(0.06 * X), sd = 1),
    Y0 = rnorm(1, mean = 72 + 3 * sqrt(X), sd = 1),
    Y = case_when(
      Z == "treatment" ~ Y1,
      Z == "control" ~ Y0
    ),
    tau = Y1 - Y0
  ) %>%
  ungroup() %>%
  tidyr::drop_na() # remove NAs from negative X's
#> Warning in sqrt(X): NaNs produced
#> Warning in rnorm(1, mean = 72 + 3 * sqrt(X), sd = 1): NAs produced
message(glue::glue("The true Sample Average Treatment Effect is {round(mean(my_data$tau),2)}"))
#> The true Sample Average Treatment Effect is 10.43

# Fit the model -----------------------------------------------------------
fit <- bart(
  y.train = my_data$Y,
  x.train = my_data %>% select(X, treatment),
  keeptrees = TRUE,
  verbose = FALSE
)

## Get the posterior distribution for the treatment effect
# Counterfactual design matrices: everyone untreated (x0) / everyone treated (x1).
x0 <- my_data %>% mutate(treatment = 0) %>% select(X, treatment)
x1 <- my_data %>% mutate(treatment = 1) %>% select(X, treatment)
pred0 <- predict(fit, as.matrix(x0))
pred1 <- predict(fit, as.matrix(x1))
tau_posterior_draws <- pred1 - pred0
# Presumably one row per posterior draw, so rowMeans() averages over
# observations within each draw.
SATE_posterior <- rowMeans(tau_posterior_draws)
message(glue::glue("BART finds an Average Treatment Effect equal to {round(mean(SATE_posterior),2)}"))
#> BART finds an Average Treatment Effect equal to 9.21

## Get the prior distribution for the treatment effect
# Build both samplers from the observed data, where "treatment" still varies,
# and only afterwards overwrite that column with the counterfactual value.
train_data <- dbartsData(my_data %>% select(X, treatment), my_data$Y)
# n.chains = 1 just for simplicity
sampler_x0 <- dbarts(train_data, control = dbartsControl(n.chains = 1))
sampler_x0$setPredictor(rep_len(0, nrow(my_data)), "treatment")
#> [1] TRUE
sampler_x1 <- dbarts(train_data, control = dbartsControl(n.chains = 1))
# Set treatment = 1 on sampler_x1 (not on sampler_x0 a second time).
sampler_x1$setPredictor(rep_len(1, nrow(my_data)), "treatment")
#> [1] TRUE
n.samples <- 1000
# Rows index observations, columns index prior draws.
prior_samples_x0 <- matrix(NA, nrow(my_data), n.samples)
prior_samples_x1 <- matrix(NA, nrow(my_data), n.samples)
prior_samples_tau <- matrix(NA, nrow(my_data), n.samples)
for (i in seq_len(n.samples)) {
  sampler_x0$sampleTreesFromPrior()
  sampler_x0$sampleNodeParametersFromPrior()
  prior_samples_x0[, i] <- sampler_x0$predict(x0)
  sampler_x1$sampleTreesFromPrior()
  sampler_x1$sampleNodeParametersFromPrior()
  prior_samples_x1[, i] <- sampler_x1$predict(x1)
  prior_samples_tau[, i] <- prior_samples_x1[, i] - prior_samples_x0[, i]
}
# colMeans() gives one average treatment effect per prior draw.
SATE_prior <- colMeans(prior_samples_tau)
message(glue::glue("Our prior for the Average Treatment Effect is {round(mean(SATE_prior),2)}"))
# NOTE(review): the #> lines are output recorded from an earlier run of this
# reprex; re-run to refresh the numbers.
#> Our prior for the Average Treatment Effect is 0.43
# vizdraws::vizdraws(prior = SATE_prior, posterior = SATE_posterior, MME = 9)
Created on 2020-08-07 by the reprex package (v0.3.0)
Vignettes sound great, and I would love to be able to have some if I had a bit more time. And yes, if you look at the histogram of SATE_prior it's centered around 0. Increase the sample size to 10k and you can start to see convergence in mean.
@vdorie I just sent you a pull request with a vignette showing how to do this. I imagine others might find it useful.
Is there a way of getting draws from the prior distribution using `dbarts`? I basically would like to do something like this: in Stan, for example, you can add a switch to your code so that it does not evaluate the likelihood and returns draws from the prior.
Thanks!