EoghanONeill opened 8 months ago
I used your method to update the splitting probabilities and alpha. However, my algorithm always gets stuck at a very small value of alpha, which causes only a single variable to be split on.
Do you have any ideas about what might cause this?
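For reference, this behaviour is consistent with how a symmetric Dirichlet behaves: for small alpha, a Dirichlet(alpha/p, ..., alpha/p) draw places nearly all of its mass on one component, so once alpha collapses the splitting probabilities concentrate on a single variable. A small base-R illustration (independent of the sampler code below):
## Dirichlet(alpha/p, ..., alpha/p) draw via normalized gamma variates
set.seed(1)
p <- 10
rdirich_sym <- function(alpha, p) {
  g <- rgamma(p, shape = alpha / p)
  g / sum(g)
}
round(rdirich_sym(alpha = 0.5, p = p), 3)  # nearly all mass on one or two components
round(rdirich_sym(alpha = 50,  p = p), 3)  # roughly uniform, close to 1/p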
##### auxiliary functions #####

## calculate log(sum(exp(v))), avoiding overflow
log.sum.exp = function(v) {
  max.v = max(v)
  sum.exp.v.modify = sum(exp(v - max.v))
  return(max.v + log(sum.exp.v.modify))
}

## draw rho (the vector of splitting probabilities)
draw_rho = function(varcount, p, alpha) {
  # equivalent to rho = rdirichlet(1, as.vector((alpha / p) + varcount))
  rho.par = as.vector((alpha / p) + varcount)
  # sample the unnormalized probabilities on the log scale
  templog.rho = rep(NA, p)
  for (i in 1:p) {
    # rlgam draws log(Gamma(shape, 1)), which is robust for very small shapes
    templog.rho[i] = SoftBart:::rlgam(shape = rho.par[i])
  }
  # normalize on the log scale, then exponentiate
  log.rho = templog.rho - log.sum.exp(templog.rho)
  rho = exp(log.rho)
  return(rho)
}

## draw alpha given rho
draw_alpha = function(rho, alpha.a0, alpha.b0, p) {
  x = (1:1000) / 1001   # grid of x = alpha / (alpha + p), which has a Beta(a0, b0) prior
  alpha.grid = x * p / (1 - x)
  alpha.log.lik = lgamma(alpha.grid) - p * lgamma(alpha.grid / p) + alpha.grid * mean(log(rho))
  beta.log.prior = (alpha.a0 - 1) * log(x) + (alpha.b0 - 1) * log(1 - x)
  alpha.log.post = alpha.log.lik + beta.log.prior
  alpha.log.post.sum = log.sum.exp(alpha.log.post)
  # normalize to obtain the discrete posterior weights
  alpha.weight = exp(alpha.log.post - alpha.log.post.sum)
  # sample alpha from the grid
  alpha = sample(alpha.grid, 1, prob = alpha.weight)
  return(alpha)
}
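A quick sanity check of the two updates in isolation, using hypothetical fixed split counts (not part of the simulation below): with counts spread over five variables, alpha should settle at a moderate value rather than collapse towards zero.
## sanity check with hypothetical, fixed split counts
varcount.fixed = c(50, 40, 30, 20, 10, 0, 0, 0, 0, 0)
alpha.check = 1
for (iter in 1:200) {
  rho.check   = draw_rho(varcount.fixed, p = 10, alpha.check)
  alpha.check = draw_alpha(rho.check, alpha.a0 = 0.5, alpha.b0 = 1.0, p = 10)
}
alpha.check            # should stay well away from 0
round(rho.check, 3)    # most weight on the first five variables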
##### DART simulation #####
library(dbarts)
library(SoftBart)
## data generation
f = function(x) {
  10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
    10 * x[,4] + 5 * x[,5]
}
set.seed(99)
sigma = 1.0
n = 100
x = matrix(runif(n * 10), n, 10)
Ey = f(x)
y = rnorm(n, Ey, sigma)
data = data.frame(y, x)
## bart sampler initialization
control = dbartsControl(n.samples = 1L,
                        n.chains = 1L,
                        n.threads = 1L,
                        keepTrees = TRUE)
rho = rep(1/10, 10)
sampler = dbarts(y ~ ., data = data, test = NULL, control = control,
                 tree.prior = dbarts:::cgm(power = 2, base = 0.95, split.probs = rho))
## MCMC process
alpha = 1.0
nburn = 1000L
nsave = 1000L
ntotal = nburn + nsave
# storage
sampler.store = list()
sigma.store = array(NA,nsave)
rho.store = array(NA,c(nsave,10))
alpha.store = array(NA,nsave)
for (nrep in 1:ntotal) {
  # push the current splitting probabilities into the sampler after the first half of the burn-in
  if (nrep > floor(nburn/2)) {
    tempmodel = sampler$model
    tempmodel@tree.prior@splitProbabilities = rho
    sampler$setModel(newModel = tempmodel)
  }
  # one sweep of the BART/DART sampler
  sample = sampler$run()
  # update rho and alpha after the first half of the burn-in
  if (nrep > floor(nburn/2)) {
    # draw rho from its Dirichlet full conditional
    # tempcounts = table(sampler$getTrees()$var)[-1]
    tempcounts = as.vector(sample$varcount)
    rho = draw_rho(tempcounts, p = 10, alpha)
    # draw alpha
    alpha = draw_alpha(rho, alpha.a0 = 0.5, alpha.b0 = 1.0, p = 10)
  }
  # store post-burn-in draws
  if (nrep > nburn) {
    idx = nrep - nburn   # separate index so the sample size n is not overwritten
    sampler.store[[idx]] = sampler
    sigma.store[idx] = sample$sigma
    rho.store[idx,] = rho
    alpha.store[idx] = alpha
  }
}
Thank you for testing the method for updating the splitting probabilities. I can confirm that when I run the method on your simulated data, one of the variables obtains a splitting probability of almost 1. If I increase the sample size to 500, the method appears to work as intended. However, the BART and SoftBart packages give good results even when the sample size is 100, so perhaps there is still a bug in my code, or perhaps dbarts requires more observations for good variable selection; I do not know what causes the difference. Are there different minimum numbers of observations per terminal node, or different numbers of potential splitting points, in the BART and dbarts packages?
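One setting that can be compared directly is the number of candidate cut points per variable and whether they are placed at quantiles; a sketch, assuming the n.cuts/useQuantiles arguments of dbartsControl() and the numcut/usequants arguments of BART::wbart():
## sketch: align the candidate split points between the two packages
control = dbartsControl(n.samples = 1L, n.chains = 1L, n.threads = 1L,
                        keepTrees = TRUE,
                        n.cuts = 100L,          # number of candidate cut points per variable
                        useQuantiles = FALSE)   # evenly spaced cut points
# the corresponding wbart settings (run once `data` is created below):
# fit2 = BART::wbart(x.train = data[, 2:ncol(data)], y.train = data$y,
#                    numcut = 100L, usequants = FALSE, ntree = 75L)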
##### DART simulation #####
library(dbarts)
library(SoftBart)
## data generation
f = function(x) {
  10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
    10 * x[,4] + 5 * x[,5]
}
set.seed(99)
sigma = 1.0
n = 500
x = matrix(runif(n * 10), n, 10)
Ey = f(x)
y = rnorm(n, Ey, sigma)
data = data.frame(y, x)
## bart sampler initialization
control = dbartsControl(n.samples = 1L,
                        n.chains = 1L, # n.trees = 20L,
                        n.threads = 1L,
                        keepTrees = TRUE)
rho = rep(1/10, 10)
sampler = dbarts(y ~ ., data = data, test = NULL, control = control,
                 proposal.probs = c(birth_death = 0.5, swap = 0, change = 0.5, birth = 0.5),
                 tree.prior = dbarts:::cgm(power = 2, base = 0.95, split.probs = rho))
## MCMC process
alpha = 1.0
nburn = 1000L
nsave = 1000L
ntotal = nburn + nsave
# storage
sampler.store = list()
sigma.store = array(NA,nsave)
rho.store = array(NA,c(nsave,10))
alpha.store = array(NA,nsave)
varcount.store = array(NA,c(nsave,10))
s_y <- rho                  # current splitting probabilities
p_y <- length(rho)          # number of predictors
alpha_scale_y <- 1          # scale in alpha / (alpha + alpha_scale_y) (alternatively p_y)
alpha_a_y <- 0.5            # Beta prior parameters for alpha / (alpha + alpha_scale_y)
alpha_b_y <- 1
print.opt <- 50             # print progress every print.opt iterations
alpha <- 1                  # initial Dirichlet concentration (alternatively p_y)
var_count_y <- rep(0, p_y)  # per-variable split counts
for (nrep in 1:ntotal) {
  # push the current splitting probabilities into the sampler after the first half of the burn-in
  if (nrep > floor(nburn/2)) {
    tempmodel = sampler$model
    tempmodel@tree.prior@splitProbabilities = s_y
    sampler$setModel(newModel = tempmodel)
  }
  # one sweep of the BART/DART sampler
  sample = sampler$run()
  # count the splitting rules that use each variable (leaf nodes have var == -1)
  var_count_y <- rep(0, p_y)
  tempcounts <- collapse::fcount(sampler$getTrees()$var)
  tempcounts <- tempcounts[tempcounts$x != -1, ]
  var_count_y[tempcounts$x] <- tempcounts$N
  # update s_y and alpha after the first half of the burn-in
  # (equivalently: s_y = draw_rho(var_count_y, p_y, alpha); alpha = draw_alpha(s_y, 0.5, 1, p_y))
  if (nrep > floor(nburn/2)) {
    ## draw s_y from Dirichlet(alpha / p_y + var_count_y), sampled on the log scale
    shape_up = as.vector((alpha / p_y) + var_count_y)
    templogs = rep(NA, p_y)
    for (i in 1:p_y) {
      templogs[i] = SoftBart:::rlgam(shape = shape_up[i])
    }
    # normalize on the log scale, then exponentiate
    max_log = max(templogs)
    templogs = templogs - (max_log + log(sum(exp(templogs - max_log))))
    s_y <- exp(templogs)
    ## draw alpha on a grid of rho_grid = alpha / (alpha + alpha_scale_y)
    rho_grid <- (1:1000) / 1001
    alpha_grid <- alpha_scale_y * rho_grid / (1 - rho_grid)
    # log posterior over the grid: Dirichlet likelihood of s_y plus the Beta prior on rho_grid;
    # mean(templogs) equals mean(log(s_y)) but avoids the exp/log round trip
    logliks <- lgamma(alpha_grid) - p_y * lgamma(alpha_grid / p_y) +
      alpha_grid * mean(templogs) +
      dbeta(x = rho_grid, shape1 = alpha_a_y, shape2 = alpha_b_y, ncp = 0, log = TRUE)
    max_ll <- max(logliks)
    logsumexps <- max_ll + log(sum(exp(logliks - max_ll)))
    alpha_weights <- exp(logliks - logsumexps)
    rho_ind <- sample.int(1000, size = 1, prob = alpha_weights)
    alpha <- alpha_grid[rho_ind]
  }
  # store post-burn-in draws
  if (nrep > nburn) {
    iter_post = nrep - nburn
    varcount.store[iter_post,] = var_count_y
    sigma.store[iter_post] = sample$sigma
    rho.store[iter_post,] = s_y
    alpha.store[iter_post] = alpha
  }
  if (nrep %% print.opt == 0) {
    print(paste("Gibbs Iteration", nrep))
  }
}
## posterior means of the splitting probabilities and variable counts from dbarts
apply(rho.store, 2, mean)
apply(varcount.store, 2, mean)

## comparison: SoftBart
fit <- SoftBart::softbart(X = data[, 2:ncol(data)], Y = data$y, X_test = data[, 2:ncol(data)],
                          opts = SoftBart::Opts(num_burn = nburn, num_save = nsave, update_tau = TRUE))
apply(fit$s, 2, mean)
apply(fit$var_counts, 2, mean)

## comparison: BART with the DART prior (sparse = TRUE)
fit2 <- BART::wbart(x.train = data[, 2:ncol(data)], y.train = data$y,
                    nskip = nburn, ndpost = nsave, sparse = TRUE, ntree = 75L)
apply(fit2$varprob, 2, mean)
apply(fit2$varcount, 2, mean)

fit2 <- BART::wbart(x.train = data[, 2:ncol(data)], y.train = data$y,
                    nskip = nburn, ndpost = nsave,
                    sparse = TRUE, ntree = 75L, usequants = FALSE # , rho = 1, augment = TRUE
                    )
apply(fit2$varprob, 2, mean)
apply(fit2$varcount, 2, mean)
To implement BART models with hyperpriors on the splitting probabilities, for example as in Linero (2018), it is necessary to update the splitting probabilities in each MCMC iteration. It would be more convenient to do this with a function like setSplitProbabilities() for an object of class dbartsSampler. It is also possible to update the splitting probabilities with setModel(). I include an example below for anyone who would like to know how to update splitting probabilities. Perhaps an example with the Linero (2018) Dirichlet hyperprior could be added to a vignette.

Linero, A. R. (2018). Bayesian regression trees for high-dimensional prediction and variable selection. Journal of the American Statistical Association, 113(522), 626-636.
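A minimal sketch of the setModel() pattern, assuming a sampler created with keepTrees = TRUE and a dbarts:::cgm tree prior with split.probs set (the full DART loop appears in the comments above):
## minimal pattern for updating the splitting probabilities of an existing dbartsSampler
rho <- rep(1/10, 10)                             # new splitting probabilities
tempmodel <- sampler$model                       # copy the current model object
tempmodel@tree.prior@splitProbabilities <- rho   # overwrite the slot in the tree prior
sampler$setModel(newModel = tempmodel)           # push the modified model back into the sampler
sample <- sampler$run()                          # subsequent sweeps use the new probabilities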