Closed: topepo closed this 2 years ago
I've added the ability to call extract(bart_fit, "trees") and written a short vignette on how to interpret the results.
Thank you. This is excellent.
One final question: are the trees used for prediction those from the last MCMC iteration? I assume that the prediction is based on a static set of trees and not on some sampling of the posterior distribution of trees.
The trees that are used when calling the predict function are all of the trees from every posterior sample drawn when the model is fit. This corresponds to the trees used when passing a test argument and obtaining predictions as the sampler is running. New samples are not drawn each time the predict function is called.
If a model is not fit with keepTrees = TRUE, only the current set of trees is saved, and those trees do correspond to the final posterior sample after model fitting. The R predict generic is disabled in this scenario, but the corresponding low-level implementation does exist.
Thanks. Just to clarify with an example, if I use
library(dbarts)
set.seed(1)
fit <- bart(
  mtcars[, -1],
  mtcars$mpg,
  keeptrees = TRUE,
  ntree = 11,
  nchain = 3,
  ndpost = 20,
  nskip = 10,
  verbose = FALSE
)
I get back 11 * 3 * 20 = 660 trees. Are new (mean) predictions the averages of the predictions from these 660 trees?
You would get, for each individual, 20 * 3 draws from the posterior distribution of the sum of their 11 trees. Something like $\hat{\mu}_{il} = \sum_{k=1}^{11} T_{kl}(x_i)$, where $i$ is a person index, $l$ is the sample index (running from 1 to 20 x 3), and $T_{kl}$ is a draw from the posterior of tree $k$ in sample $l$. The mean prediction for individual $i$ is then the average of $\hat{\mu}_{il}$ over the 60 samples, not an average over the 660 individual trees.
library(dbarts)
set.seed(1)
fit <- bart(
  mtcars[, -1],
  mtcars$mpg,
  keeptrees = TRUE,
  ntree = 11,
  nchain = 3,
  ndpost = 20,
  nskip = 10,
  verbose = FALSE
)
mu <- predict(fit, mtcars)
stopifnot(dim(mu) == c(3 * 20, nrow(mtcars)))
trees <- extract(fit, "trees")
dim(trees)
# "trees" are actually all nodes
# [1] 2318 6
stopifnot(with(trees, length(unique(interaction(chain, sample, tree)))) == 660)
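Because predict returns one row per posterior draw and one column per individual, point predictions and intervals can be read off column-wise. The following is a minimal base R sketch of that step; `mu_demo` is a simulated stand-in with the same shape convention as `mu` above, not output from a fitted model, and all names in it are hypothetical.

```r
# Stand-in for the matrix returned by predict(): 3 * 20 = 60 posterior draws
# (rows) for 4 individuals (columns). Values are simulated, not from a BART fit.
set.seed(1)
n_draws <- 3 * 20
n_obs <- 4
mu_demo <- matrix(rnorm(n_draws * n_obs, mean = 20, sd = 2), n_draws, n_obs)

# Posterior mean prediction for each individual: average over the draws.
post_mean <- colMeans(mu_demo)

# 95% posterior interval for each individual from the empirical quantiles.
post_interval <- apply(mu_demo, 2, quantile, probs = c(0.025, 0.975))
```

With the real fit, the same two calls applied to `mu` would give one mean and one interval per row of `mtcars`.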
In order to get the predictions of each sample of each tree for each individual (which is an array of (n.chains x n.samples x n.trees) x n.individuals), the trees have to be traversed.
getPredictionsForTree <- function(tree, x) {
  predictions <- rep(NA_real_, nrow(x))
  # Nodes are stored depth-first: a node, then its left subtree, then its right.
  # The helper returns the number of nodes in the subtree it consumed.
  getPredictionsForTreeRecursive <- function(tree, indices) {
    # var == -1 marks a leaf; its value is the prediction for these rows
    if (tree$var[1] == -1) {
      predictions[indices] <<- tree$value[1]
      return(1)
    }
    # Internal node: value is the cut point for variable var
    goesLeft <- x[indices, tree$var[1]] <= tree$value[1]
    headOfLeftBranch <- tree[-1,]
    n_nodes.left <- getPredictionsForTreeRecursive(
      headOfLeftBranch, indices[goesLeft])
    # The right branch starts after the root and all left-branch nodes
    headOfRightBranch <- tree[seq.int(2 + n_nodes.left, nrow(tree)),]
    n_nodes.right <- getPredictionsForTreeRecursive(
      headOfRightBranch, indices[!goesLeft])
    return(1 + n_nodes.left + n_nodes.right)
  }
  getPredictionsForTreeRecursive(tree, seq_len(nrow(x)))
  predictions
}
# Apply the traversal to every (chain, sample, tree) combination
allPredictions <- by(
  trees,
  trees[,c("chain", "sample", "tree")],
  getPredictionsForTree,
  x = fit$fit$data@x
)
stopifnot(dim(allPredictions) == c(3, 20, 11))
stopifnot(length(allPredictions[[1]]) == nrow(mtcars))
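To connect the per-tree predictions back to the sum-of-trees draws, the tree dimension can be summed out. The sketch below does this in base R on a simulated stand-in, `per_tree`, which mimics the shape of `allPredictions` (a chain x sample x tree list array with one vector per cell); the names and values are hypothetical. Whether the summed values from a real fit match `predict` exactly, including the scale of the response and the row order, is an assumption worth checking rather than something shown here.

```r
# Stand-in for allPredictions: a 3 x 20 x 11 list array whose cells each hold
# one prediction per individual. Values are simulated, not from a BART fit.
set.seed(1)
n_chain <- 3; n_sample <- 20; n_tree <- 11; n_obs <- 5
per_tree <- replicate(n_chain * n_sample * n_tree, rnorm(n_obs), simplify = FALSE)
dim(per_tree) <- c(n_chain, n_sample, n_tree)

# unlist() concatenates the cells in column-major order, so the individual
# index varies fastest, followed by chain, sample, and tree.
arr <- array(unlist(per_tree), c(n_obs, n_chain, n_sample, n_tree))

# Sum over the tree dimension: one sum-of-trees draw per chain, sample,
# and individual. The result has dimensions chain x sample x individual.
draws <- apply(arr, c(2, 3, 1), sum)
```

The same `array(unlist(...))` step would presumably work on the `by` object itself, provided every chain, sample, and tree combination is present.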
Is there any way to access the tree structures (assuming keeptrees = TRUE)? It would be really helpful to be able to do some characterizations (beyond varcount). For example, the tree size is printed, but it would be good to be able to compute some of these things.
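Summaries of this kind can be computed from the node table that extract(fit, "trees") returns, where each row is a node and var == -1 marks a leaf. Here is a minimal base R sketch on a small hand-built table; `nodes` is a hypothetical stand-in that keeps only the chain, sample, tree, var, and value columns, and its contents are made up for illustration.

```r
# Hypothetical node table for three trees in one posterior sample.
# Tree 1: one split, two leaves. Tree 2: a single leaf.
# Tree 3: two splits, three leaves. Nodes are listed depth-first.
nodes <- data.frame(
  chain  = rep(1, 9),
  sample = rep(1, 9),
  tree   = c(1, 1, 1, 2, 3, 3, 3, 3, 3),
  var    = c(1, -1, -1, -1, 2, 1, -1, -1, -1),
  value  = c(3.2, 0.1, -0.2, 0.05, 150, 2.9, 0.3, -0.1, 0.2)
)

# One group per (chain, sample, tree) combination.
key <- interaction(nodes$chain, nodes$sample, nodes$tree, drop = TRUE)

# Tree size measured as total nodes and as number of leaves.
n_nodes  <- tapply(nodes$var, key, length)
n_leaves <- tapply(nodes$var, key, function(v) sum(v == -1))

# How often each variable index is used as a split across all trees.
split_counts <- table(nodes$var[nodes$var != -1])
```

On this table, `n_nodes` is 3, 1, 5 and `n_leaves` is 2, 1, 3; averaging either over samples would give a posterior summary of tree size.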