vdorie / dbarts

Discrete Bayesian Additive Regression Trees Sampler

Get draws from the prior? #31

Open ignacio82 opened 4 years ago

ignacio82 commented 4 years ago

Is there a way of getting draws from the prior distribution using dbarts? I basically would like to do something like this:

pred0_prior <- predict(fit, as.matrix(x0)) # draws for y(0) from the prior distribution
pred1_prior <- predict(fit, as.matrix(x1)) # draws for y(1) from the prior distribution
tau_prior_draws <- pred1_prior - pred0_prior # prior distribution of the treatment effect

In Stan, for example, you can add a switch to your code so that it does not evaluate the likelihood and returns draws from the prior.

Thanks!

vdorie commented 4 years ago

Sort of. This will get you draws from the prior of the tree parameters, but if you want the full data distribution you would have to add noise based on sigma, drawn from its prior.

f <- function(x) {
    10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
      10 * x[,4] + 5 * x[,5]
}

set.seed(99)
sigma <- 1.0
n     <- 100

x  <- matrix(runif(n * 10), n, 10)
Ey <- f(x)
y  <- rnorm(n, Ey, sigma)

# n.chains = 1 just for simplicity
sampler <- dbarts(x, y, control = dbartsControl(n.chains = 1))

n.samples <- 50
prior_samples <- matrix(NA, n, n.samples)

for (i in seq_len(n.samples)) {
  sampler$sampleTreesFromPrior()

  prior_samples[,i] <- sampler$predict(x)
}

ignacio82 commented 4 years ago

Thanks a lot for helping, Vincent. I tried to follow your example and I'm clearly not understanding something. I end up with a vector of zeros :-(

This is the code that I wrote:

library(dplyr)
library(dbarts)

# create some fake data ---------------------------------------------------
set.seed(123)
N <- 120

my_data <- tibble(Z = sample(x = c("treatment", "control"), size=N, replace = TRUE, p = c(0.5,0.5))) %>%
  rowwise()  %>%
  mutate(treatment = case_when(Z == "treatment" ~ 1,
                               Z == "control" ~ 0),
         X = case_when(Z == "treatment" ~ rnorm(1, mean = 40, sd = 10),
                       Z == "control" ~ rnorm(1, mean = 20, sd = 10)),
         Y1 = rnorm(1, mean = 90+exp(0.06*X), sd = 1),
         Y0 = rnorm(1, mean = 72+3*sqrt(X), sd = 1),               
         Y = case_when(Z == "treatment" ~ Y1,
                       Z == "control" ~ Y0),
         tau = Y1 - Y0) %>%
  tidyr::drop_na() #  remove NAs from negative X's

message(glue::glue("The true Sample Average Treatment Effect is {round(mean(my_data$tau),2)}"))

# Fit the model -----------------------------------------------------------

fit <- bart(
  y.train         = my_data$Y,
  x.train         = my_data %>% select(X, treatment),
  keeptrees       = T,
  verbose         = F
)
## Get the posterior distribution for the treatment effect

x0 <- my_data %>% mutate(treatment=0) %>% select(X,treatment)
x1 <- my_data %>% mutate(treatment=1) %>% select(X,treatment)

pred0 <- predict(fit, as.matrix(x0))

pred1 <- predict(fit, as.matrix(x1))

tau_posterior_draws <- pred1 - pred0  

SATE_posterior <- rowMeans(tau_posterior_draws)

message(glue::glue("BART finds an Average Treatment Effect equal to {round(mean(SATE_posterior),2)}"))

## Get the prior distribution for the treatment effect

x <- my_data %>% select(X, treatment)
y <- my_data$Y

# n.chains = 1 just for simplicity
sampler_x0 <- dbarts(x0, y, control = dbartsControl(n.chains = 1))
sampler_x1 <- dbarts(x1, y, control = dbartsControl(n.chains = 1))

n.samples <- 1000
prior_samples_x0 <- matrix(NA, nrow(my_data), n.samples)
prior_samples_x1 <- matrix(NA, nrow(my_data), n.samples)
prior_samples_tau <- matrix(NA, nrow(my_data), n.samples)

for (i in seq_len(n.samples)) {
  sampler_x0$sampleTreesFromPrior()
  prior_samples_x0[,i] <- sampler_x0$predict(x0)
  sampler_x1$sampleTreesFromPrior()
  prior_samples_x1[,i] <- sampler_x1$predict(x1)
  prior_samples_tau[,i] <- prior_samples_x1[,i] - prior_samples_x0[,i]
}

SATE_prior <- colMeans(prior_samples_tau) # This is clearly wrong

Would you mind explaining what I am missing?

Thanks!

vdorie commented 4 years ago

Apologies, my mistake. I forgot about the node parameters themselves. I've checked in some code that will allow for sampling from their prior as well.

require(dbarts)

f <- function(x) {
    10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
      10 * x[,4] + 5 * x[,5]
}

set.seed(99)
sigma <- 1.0
n     <- 100

x  <- matrix(runif(n * 10), n, 10)
Ey <- f(x)
y  <- rnorm(n, Ey, sigma)

sampler <- dbarts(x, y, control = dbartsControl(n.chains = 1))

n.samples <- 50
prior_samples <- matrix(NA, n, n.samples)

for (i in seq_len(n.samples)) {
  sampler$sampleTreesFromPrior()
  sampler$sampleNodeParametersFromPrior()

  prior_samples[,i] <- sampler$predict(x)
}
mean(prior_samples)
# [1] 15.83145
sd(prior_samples)
# [1] 7.195763
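
For the full data distribution mentioned earlier, noise still has to be added using a sigma drawn from its prior. A minimal sketch in base R, assuming the standard BART residual prior sigma^2 ~ nu * lambda / chisq_nu with nu = 3 and lambda calibrated so that P(sigma < sigma_hat) = 0.9. `add_prior_noise` and `sigma_hat` are hypothetical names, not part of dbarts, and the scale that dbarts uses internally may differ:

```r
# Hypothetical helper (not part of dbarts): adds observation noise to prior
# draws of the mean function, using one sigma per column (per prior draw).
add_prior_noise <- function(mean_draws, sigma_hat, nu = 3, q = 0.9) {
  n_obs   <- nrow(mean_draws)
  n_draws <- ncol(mean_draws)

  # Choose lambda so that P(sigma < sigma_hat) = q under the prior
  lambda <- sigma_hat^2 * qchisq(1 - q, df = nu) / nu

  # sigma^2 ~ nu * lambda / chisq_nu (scaled inverse chi-squared)
  sigma <- sqrt(nu * lambda / rchisq(n_draws, df = nu))

  # Column i receives noise with standard deviation sigma[i]
  noise <- matrix(rnorm(n_obs * n_draws), n_obs, n_draws) *
    rep(sigma, each = n_obs)
  mean_draws + noise
}

# Stand-in for the prior_samples matrix filled by the loop above
set.seed(1)
mean_draws    <- matrix(rnorm(100 * 50, mean = 15, sd = 7), 100, 50)
prior_y_draws <- add_prior_noise(mean_draws, sigma_hat = 1)
dim(prior_y_draws)
```

Here `sigma_hat` is assumed to be a rough estimate of the residual standard deviation, for example from a linear fit to the data.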

ignacio82 commented 4 years ago

Thanks a lot for the quick response. Could you confirm that the following code is right? In particular, you mentioned having to add noise based on draws of sigma, which I did not do. I'm not sure whether I need to, or how to do it.

library(dplyr)
library(dbarts)

# create some fake data ---------------------------------------------------
set.seed(123)
N <- 120

my_data <- tibble(Z = sample(x = c("treatment", "control"), size=N, replace = TRUE, p = c(0.5,0.5))) %>%
  rowwise()  %>%
  mutate(treatment = case_when(Z == "treatment" ~ 1,
                               Z == "control" ~ 0),
         X = case_when(Z == "treatment" ~ rnorm(1, mean = 40, sd = 10),
                       Z == "control" ~ rnorm(1, mean = 20, sd = 10)),
         Y1 = rnorm(1, mean = 90+exp(0.06*X), sd = 1),
         Y0 = rnorm(1, mean = 72+3*sqrt(X), sd = 1),               
         Y = case_when(Z == "treatment" ~ Y1,
                       Z == "control" ~ Y0),
         tau = Y1 - Y0) %>%
  tidyr::drop_na() #  remove NAs from negative X's
#> Warning in sqrt(X): NaNs produced
#> Warning in rnorm(1, mean = 72 + 3 * sqrt(X), sd = 1): NAs produced

message(glue::glue("The true Sample Average Treatment Effect is {round(mean(my_data$tau),2)}"))
#> The true Sample Average Treatment Effect is 10.43

# Fit the model -----------------------------------------------------------

fit <- bart(
  y.train         = my_data$Y,
  x.train         = my_data %>% select(X, treatment),
  keeptrees       = T,
  verbose         = F
)
## Get the posterior distribution for the treatment effect

x0 <- my_data %>% mutate(treatment=0) %>% select(X,treatment)
x1 <- my_data %>% mutate(treatment=1) %>% select(X,treatment)

pred0 <- predict(fit, as.matrix(x0))

pred1 <- predict(fit, as.matrix(x1))

tau_posterior_draws <- pred1 - pred0  

SATE_posterior <- rowMeans(tau_posterior_draws)

message(glue::glue("BART finds an Average Treatment Effect equal to {round(mean(SATE_posterior),2)}"))
#> BART finds an Average Treatment Effect equal to 9.21

## Get the prior distribution for the treatment effect

x <- my_data %>% select(X, treatment)
y <- my_data$Y

# n.chains = 1 just for simplicity
sampler_x0 <- dbarts(x0, y, control = dbartsControl(n.chains = 1))
sampler_x1 <- dbarts(x1, y, control = dbartsControl(n.chains = 1))

n.samples <- 1000
prior_samples_x0 <- matrix(NA, nrow(my_data), n.samples)
prior_samples_x1 <- matrix(NA, nrow(my_data), n.samples)
prior_samples_tau <- matrix(NA, nrow(my_data), n.samples)

for (i in seq_len(n.samples)) {
  sampler_x0$sampleTreesFromPrior()
  sampler_x0$sampleNodeParametersFromPrior()
  prior_samples_x0[,i] <- sampler_x0$predict(x0)
  sampler_x1$sampleTreesFromPrior()
  sampler_x1$sampleNodeParametersFromPrior()
  prior_samples_x1[,i] <- sampler_x1$predict(x1)
  prior_samples_tau[,i] <- prior_samples_x1[,i] - prior_samples_x0[,i]
}

SATE_prior <- colMeans(prior_samples_tau) # is this right?

message(glue::glue("Our prior for the Average Treatment Effect is {round(mean(SATE_prior),2)}"))
#> Our prior for the Average Treatment Effect is -0.12

# vizdraws::vizdraws(prior = SATE_prior, posterior = SATE_posterior, MME = 9)

Created on 2020-07-26 by the reprex package (v0.3.0)

vdorie commented 4 years ago

Sorry, but I've been really falling behind on email these days. The only issue I see is that when you create the sampler with a treatment variable that is all 0s or all 1s, that variable gets dropped. Currently, the most elegant way of getting the sampler to keep cut points (and hence, tree structures) for the treatment variable is to set it after the sampler has been created and initialized. Something like:

train_data <- dbartsData(my_data %>% select(X,treatment), y)

sampler_x0 <- dbarts(train_data, control = dbartsControl(n.chains = 1))
sampler_x0$setPredictor(rep_len(0, length(y)), "treatment")

sampler_x1 <- dbarts(train_data, control = dbartsControl(n.chains = 1))
sampler_x1$setPredictor(rep_len(1, length(y)), "treatment")

Everything else looks good.

ignacio82 commented 4 years ago

Thanks a lot Vincent. I'm leaving the code with the change you suggested in case it is useful for someone else in the future. I think it would make an interesting and useful vignette for the package.

Final question: I noticed that in this example the prior is not centered at zero. Is this just because n.samples <- 1000?

library(dplyr)
library(dbarts)

# create some fake data ---------------------------------------------------
set.seed(123)
N <- 120

my_data <- tibble(Z = sample(x = c("treatment", "control"), size=N, replace = TRUE, p = c(0.5,0.5))) %>%
  rowwise()  %>%
  mutate(treatment = case_when(Z == "treatment" ~ 1,
                               Z == "control" ~ 0),
         X = case_when(Z == "treatment" ~ rnorm(1, mean = 40, sd = 10),
                       Z == "control" ~ rnorm(1, mean = 20, sd = 10)),
         Y1 = rnorm(1, mean = 90+exp(0.06*X), sd = 1),
         Y0 = rnorm(1, mean = 72+3*sqrt(X), sd = 1),               
         Y = case_when(Z == "treatment" ~ Y1,
                       Z == "control" ~ Y0),
         tau = Y1 - Y0) %>%
  tidyr::drop_na() #  remove NAs from negative X's
#> Warning in sqrt(X): NaNs produced
#> Warning in rnorm(1, mean = 72 + 3 * sqrt(X), sd = 1): NAs produced

message(glue::glue("The true Sample Average Treatment Effect is {round(mean(my_data$tau),2)}"))
#> The true Sample Average Treatment Effect is 10.43

# Fit the model -----------------------------------------------------------

fit <- bart(
  y.train         = my_data$Y,
  x.train         = my_data %>% select(X, treatment),
  keeptrees       = T,
  verbose         = F
)
## Get the posterior distribution for the treatment effect

x0 <- my_data %>% mutate(treatment=0) %>% select(X,treatment)
x1 <- my_data %>% mutate(treatment=1) %>% select(X,treatment)

pred0 <- predict(fit, as.matrix(x0))

pred1 <- predict(fit, as.matrix(x1))

tau_posterior_draws <- pred1 - pred0  

SATE_posterior <- rowMeans(tau_posterior_draws)

message(glue::glue("BART finds an Average Treatment Effect equal to {round(mean(SATE_posterior),2)}"))
#> BART finds an Average Treatment Effect equal to 9.21

## Get the prior distribution for the treatment effect

train_data <- dbartsData(my_data %>% select(X,treatment), my_data$Y)

# n.chains = 1 just for simplicity
sampler_x0 <- dbarts(train_data, control = dbartsControl(n.chains = 1))
sampler_x0$setPredictor(rep_len(0, nrow(my_data)), "treatment")
#> [1] TRUE

sampler_x1 <- dbarts(train_data, control = dbartsControl(n.chains = 1))
sampler_x1$setPredictor(rep_len(1, nrow(my_data)), "treatment")
#> [1] TRUE

n.samples <- 1000
prior_samples_x0 <- matrix(NA, nrow(my_data), n.samples)
prior_samples_x1 <- matrix(NA, nrow(my_data), n.samples)
prior_samples_tau <- matrix(NA, nrow(my_data), n.samples)

for (i in seq_len(n.samples)) {
  sampler_x0$sampleTreesFromPrior()
  sampler_x0$sampleNodeParametersFromPrior()
  prior_samples_x0[,i] <- sampler_x0$predict(x0) 
  sampler_x1$sampleTreesFromPrior()
  sampler_x1$sampleNodeParametersFromPrior()
  prior_samples_x1[,i] <- sampler_x1$predict(x1) 
  prior_samples_tau[,i] <- prior_samples_x1[,i] - prior_samples_x0[,i]
}

SATE_prior <- colMeans(prior_samples_tau) 

message(glue::glue("Our prior for the Average Treatment Effect is {round(mean(SATE_prior),2)}"))
#> Our prior for the Average Treatment Effect is 0.43

# vizdraws::vizdraws(prior = SATE_prior, posterior = SATE_posterior, MME = 9)

Created on 2020-08-07 by the reprex package (v0.3.0)

vdorie commented 4 years ago

Vignettes sound great, and I would love to be able to have some if I had a bit more time. And yes, if you look at the histogram of SATE_prior it's centered around 0. Increase the sample size to 10k and you can start to see convergence in mean.
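
The convergence in mean can be illustrated without dbarts. A rough sketch, assuming each prior SATE draw behaves like the difference of two independent, identically distributed quantities. The normal stand-in draws and their standard deviation of 5 are arbitrary choices for illustration, not values taken from the model above. Under that assumption, the Monte Carlo standard error of the mean shrinks like 1 / sqrt(n.samples):

```r
set.seed(42)

# Stand-in for SATE_prior: difference of two independent draws per iteration
simulate_sate_prior <- function(n.samples, sd = 5) {
  rnorm(n.samples, sd = sd) - rnorm(n.samples, sd = sd)
}

for (n.samples in c(1000, 10000)) {
  draws <- simulate_sate_prior(n.samples)
  mc_se <- sd(draws) / sqrt(n.samples)  # Monte Carlo standard error of the mean
  cat(sprintf("n.samples = %5d  mean = %6.3f  MC s.e. = %.3f\n",
              n.samples, mean(draws), mc_se))
}
```

Going from 1,000 to 10,000 samples cuts the standard error by a factor of about sqrt(10), so a sample mean a few tenths away from zero at 1,000 draws is within what one would expect.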

ignacio82 commented 4 years ago

@vdorie I just sent you a pull request with a vignette showing how to do this. I imagine others might find it useful.