vdorie / dbarts

Discrete Bayesian Additive Regression Trees Sampler

Feature request: more convenient updates of splitting probabilities #67

Open EoghanONeill opened 8 months ago

EoghanONeill commented 8 months ago

To implement BART models with hyperpriors on the splitting probabilities, for example as in Linero (2018), the splitting probabilities must be updated in each MCMC iteration. It would be more convenient to do this with a function such as setSplitProbabilities() for an object of class dbartsSampler.

It is also possible to update the splitting probabilities with setModel(). I include an example below for anyone who would like to know how to do this. Perhaps an example with the Linero (2018) Dirichlet hyperprior could be added to a vignette.
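A convenience function along these lines could be a thin wrapper around the setModel() pattern used in the example; a minimal sketch follows (the wrapper and its name are hypothetical, not part of dbarts):

setSplitProbabilities <- function(sampler, split.probs) {
  # hypothetical wrapper: reuses the setModel() mechanism from the example below
  newModel <- sampler$model
  newModel@tree.prior@splitProbabilities <- split.probs
  sampler$setModel(newModel = newModel)
  invisible(sampler)
}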

# install.packages("dbarts")
library(dbarts)

f <- function(x) {
  10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
    10 * x[,4] + 5 * x[,5]
}

set.seed(99)
sigma <- 1.0
n     <- 100

x  <- matrix(runif(n * 10), n, 10)
Ey <- f(x)
y  <- rnorm(n, Ey, sigma)

data <- data.frame(y, x)

control1 <- dbartsControl(n.samples = 1L,
                          n.chains = 1L,
                          n.threads = 1L,
                          keepTrees = TRUE)

## half of the splitting probability on the last predictor,
## the rest spread evenly over the first nine
tempsplitprobs <- c(rep(0.5/9, 9), 0.5)

sampler1 <- dbarts(y ~ ., data = data, test = NULL, resid.prior = fixed(1),
                   control = control1,
                   tree.prior = dbarts:::cgm(power = 2, base = 0.95,
                                             split.probs = tempsplitprobs))
niter <- 5

for (i in 1:niter) {
  samplestemp <- sampler1$run()
  cat("sigma =", samplestemp$sigma, "\n")
}

sampler1$getTrees()$var

## count how often each variable appears in a splitting rule;
## the first table entry (var == -1) marks leaf nodes, hence the [-1]
tempcounts <- table(sampler1$getTrees()$var)[-1]
tempcounts

###### UPDATE SPLIT PROBABILITIES ##################
## put all of the splitting probability on the first variable
tempmodel <- sampler1$model
tempmodel@tree.prior@splitProbabilities <- c(1, rep(0, 9))
sampler1$setModel(newModel = tempmodel)

niter <- 50
for (i in 1:niter) {
  samplestemp <- sampler1$run()
  cat("sigma =", samplestemp$sigma, "\n")
}
sampler1$model@tree.prior@splitProbabilities

tempcounts <- table(sampler1$getTrees()$var)[-1]
tempcounts

Linero, A. R. (2018). Bayesian regression trees for high-dimensional prediction and variable selection. Journal of the American Statistical Association, 113(522), 626-636.

yifei-philip commented 6 days ago

I used your method to update the splitting probabilities and alpha. However, my algorithm always gets stuck at a very small value of alpha, which leads to splits on only a single variable (see the illustrative check after draw_rho below).

Do you have any ideas about it?

##### auxiliary functions #####

## calculate log(sum(exp(v))), avoiding overflow
log.sum.exp = function(v) {
  max.v = max(v)
  sum.exp.v.modify = sum(exp(v-max.v))
  return (max.v + log(sum.exp.v.modify))
}
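## quick check (illustrative): agrees with the naive formula on benign input
## and stays finite where exp() would overflow
# log.sum.exp(c(-1, 0, 1)) - log(sum(exp(c(-1, 0, 1))))  # approximately 0
# log.sum.exp(c(1000, 1001))                             # finite; the naive version returns Inf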

## draw rho
draw_rho = function(varcount, p, alpha) {
  # rho = rdirichlet(1, as.vector((alpha / p ) + varcount))
  rho.par = as.vector((alpha / p ) + varcount)

  # Sample unnormalized rho on the log scale
  templog.rho = rep(NA, p)
  for(i in 1:p) {
    templog.rho[i] = SoftBart:::rlgam(shape = rho.par[i]) # log-gamma draw; stabler than log(rgamma())
  }

  # Normalize rho on the log scale, then exponentiate
  log.rho = templog.rho - log.sum.exp(templog.rho)
  rho = exp(log.rho)

  return(rho)
}
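## illustrative check of the behaviour reported above: with a very small alpha,
## a Dirichlet(alpha/p, ..., alpha/p) draw is nearly degenerate and puts almost
## all of the split probability on a single variable
# set.seed(1)
# round(draw_rho(varcount = rep(0, 10), p = 10, alpha = 0.01), 3)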

## draw alpha
draw_alpha = function(rho, alpha.a0, alpha.b0, p) {
  x = 1:1000 / 1001 # grid over x = alpha / (alpha + p), which has a Beta(a0, b0) prior
  alpha.grid = x * p / (1 - x)

  alpha.log.lik = lgamma(alpha.grid) - p*lgamma(alpha.grid/p) + alpha.grid * mean(log(rho))
  beta.log.prior = (alpha.a0-1)*log(x) + (alpha.b0-1)*log(1-x)
  alpha.log.post = alpha.log.lik + beta.log.prior
  alpha.log.post.sum = log.sum.exp(alpha.log.post)

  # calculate the weights
  alpha.weight = exp(alpha.log.post - alpha.log.post.sum)
  # sample alpha
  alpha = sample(alpha.grid, 1, prob = alpha.weight)

  return(alpha) 
}
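## illustrative draws: a near-uniform rho pushes the posterior toward larger
## alpha, while a rho concentrated on one variable pushes alpha toward zero,
## consistent with the stuck behaviour described above
# set.seed(1)
# draw_alpha(rho = rep(1/10, 10), alpha.a0 = 0.5, alpha.b0 = 1.0, p = 10)
# draw_alpha(rho = c(0.99, rep(0.01/9, 9)), alpha.a0 = 0.5, alpha.b0 = 1.0, p = 10)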

##### DART simulation #####
library(dbarts)
library(SoftBart)

## data generation
f = function(x) {
  10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
    10 * x[,4] + 5 * x[,5]
}

set.seed(99)
sigma = 1.0
n     = 100

x  = matrix(runif(n * 10), n, 10)
Ey = f(x)
y  = rnorm(n, Ey, sigma)

data = data.frame(y, x)

## bart sampler initialization
control = dbartsControl(n.samples = 1L,
                        n.chains = 1L,
                        n.threads = 1L,
                        keepTrees = TRUE)

rho = rep(1/10, 10)

sampler = dbarts(y ~ ., data = data, test = NULL, control = control,
                 tree.prior = dbarts:::cgm(power = 2, base = 0.95, split.probs = rho))

## MCMC process
alpha = 1.0
nburn = 1000L
nsave = 1000L
ntotal = nburn + nsave

# storage
sampler.store = list()
sigma.store = array(NA,nsave)
rho.store = array(NA,c(nsave,10))
alpha.store = array(NA,nsave)

for(nrep in 1:ntotal){
  # update the split probabilities after the first half of the burn-in
  if(nrep > floor(nburn/2)){
    tempmodel = sampler$model
    tempmodel@tree.prior@splitProbabilities = rho
    sampler$setModel(newModel = tempmodel)
  }

  # sample from bart/dart
  sample = sampler$run()

  # update rho and alpha after the first half of the burn-in
  if(nrep > floor(nburn/2)){
    # draw rho
    # tempcounts = table(sampler$getTrees()$var)[-1]
    tempcounts = as.vector(sample$varcount)
    rho = draw_rho(tempcounts, p = 10, alpha)
    # draw alpha
    alpha = draw_alpha(rho, alpha.a0 = 0.5, alpha.b0 = 1.0, p = 10)
  }

  # store samples
  if(nrep > nburn){
    isave = nrep - nburn  # do not reuse n, which already holds the sample size
    sampler.store[[isave]] = sampler  # note: each entry references the same sampler object
    sigma.store[isave] = sample$sigma
    rho.store[isave,] = rho
    alpha.store[isave] = alpha
  }
}

EoghanONeill commented 6 days ago

Thank you for testing the method for updating the splitting probabilities. I can confirm that when I test the method on your simulated data, one of the variables obtains a splitting probability of almost 1. If I increase the sample size to 500, the method appears to work as intended. However, the BART and SoftBART packages give good results when the sample size is 100, so perhaps there is still a bug in my code, or perhaps dbarts requires more observations for good variable selection. Do the BART and dbarts packages use different minimum numbers of terminal-node observations or different sets of potential splitting points?
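For what it's worth, the defaults governing the candidate splitting points can be inspected directly; a quick check, assuming current versions of both packages (each exposes numcut and usequants arguments):

formals(dbarts::bart)[c("numcut", "usequants")]
formals(BART::wbart)[c("numcut", "usequants")]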


##### DART simulation #####
library(dbarts)
library(SoftBart)

## data generation
f = function(x) {
  10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
    10 * x[,4] + 5 * x[,5]
}

set.seed(99)
sigma = 1.0
n     = 500

x  = matrix(runif(n * 10), n, 10)
Ey = f(x)
y  = rnorm(n, Ey, sigma)

data = data.frame(y, x)

## bart sampler initialization
control = dbartsControl(n.samples = 1L,
                        n.chains = 1L,  # n.trees = 20L,
                        n.threads = 1L,
                        keepTrees = TRUE)

rho = rep(1/10, 10)

sampler = dbarts(y ~ ., data = data, test = NULL, control = control,
                 proposal.probs = c(birth_death = 0.5, swap = 0, change = 0.5, birth = 0.5),
                 tree.prior = dbarts:::cgm(power = 2, base = 0.95, split.probs = rho))

## MCMC process
alpha = 1.0
nburn = 1000L
nsave = 1000L
ntotal = nburn + nsave

# storage
sampler.store = list()
sigma.store = array(NA,nsave)
rho.store = array(NA,c(nsave,10))
alpha.store = array(NA,nsave)

varcount.store = array(NA,c(nsave,10))

s_y <- rho
p_y <- length(rho)
alpha_scale_y <- 1 # p_y
alpha_a_y <- 0.5
alpha_b_y <- 1
print.opt <- 50

alpha <- 1 # p_y
var_count_y <- rep(0, p_y)

for(nrep in 1:ntotal){
  # update the split probabilities after the first half of the burn-in
  if(nrep > floor(nburn/2)){
    tempmodel = sampler$model
    tempmodel@tree.prior@splitProbabilities = s_y #rho
    sampler$setModel(newModel = tempmodel)
  }
  # sample from bart/dart
  sample = sampler$run()
  # recount how often each variable appears in a splitting rule
  # (rows with var == -1 are leaf nodes and are dropped)
  var_count_y <- rep(0, p_y)
  tempcounts <- collapse::fcount(sampler$getTrees()$var)
  tempcounts <- tempcounts[tempcounts$x != -1, ]
  var_count_y[tempcounts$x] <- tempcounts$N

  # update s and alpha after the first half of the burn-in
  if(nrep > floor(nburn/2)){
    # draw s (the split probabilities) from its Dirichlet full conditional,
    # on the log scale as in draw_rho() above
    shape_up = as.vector((alpha / p_y) + var_count_y)
    templogs = rep(NA, p_y)
    for(i in 1:p_y) {
      templogs[i] = SoftBart:::rlgam(shape = shape_up[i])
    }
    # normalize on the log scale, then exponentiate
    max_log = max(templogs)
    templogs2 = templogs - (max_log + log(sum(exp(templogs - max_log))))
    s_y <- exp(templogs2)

    # draw alpha on a grid over rho_grid = alpha / (alpha + alpha_scale_y);
    # the lgamma terms are deterministic (cf. draw_alpha() above), so no
    # random log-gamma draws are needed here
    rho_grid <- (1:1000)/1001
    alpha_grid <- alpha_scale_y * rho_grid / (1 - rho_grid)

    logliks <- alpha_grid * mean(log(s_y)) +
      lgamma(alpha_grid) -
      p_y * lgamma(alpha_grid / p_y) +
      dbeta(x = rho_grid, shape1 = alpha_a_y, shape2 = alpha_b_y, log = TRUE)

    max_ll <- max(logliks)
    alpha_weights <- exp(logliks - (max_ll + log(sum(exp(logliks - max_ll)))))
    rho_ind <- sample.int(1000, size = 1, prob = alpha_weights)
    alpha <- alpha_grid[rho_ind]
  }
  # store samples
  if(nrep > nburn){
    iter_post = nrep - nburn  # compute the index before using it
    varcount.store[iter_post,] = var_count_y
    sigma.store[iter_post] = sample$sigma
    rho.store[iter_post,] = s_y
    alpha.store[iter_post] = alpha
  }
  if(nrep %% print.opt == 0){
    print(paste("Gibbs Iteration", nrep))
  }
}

## posterior means of the split probabilities and the variable counts
apply(rho.store, 2, mean)
apply(varcount.store, 2, mean)

## compare with SoftBart
fit <- SoftBart::softbart(X = data[,2:ncol(data)], Y = data$y, X_test = data[,2:ncol(data)],
                          opts = SoftBart::Opts(num_burn = nburn, num_save = nsave, update_tau = TRUE))
apply(fit$s, 2, mean)
apply(fit$var_counts, 2, mean)

## compare with BART (DART prior via sparse = TRUE)
fit2 <- BART::wbart(x.train = data[,2:ncol(data)], y.train = data$y, nskip = nburn,
                    ndpost = nsave, sparse = TRUE, ntree = 75L)
apply(fit2$varprob, 2, mean)
apply(fit2$varcount, 2, mean)

## with usequants = FALSE made explicit
fit3 <- BART::wbart(x.train = data[,2:ncol(data)], y.train = data$y, nskip = nburn,
                    ndpost = nsave, sparse = TRUE, ntree = 75L,
                    usequants = FALSE #, rho = 1, augment = TRUE
                    )
apply(fit3$varprob, 2, mean)
apply(fit3$varcount, 2, mean)