vdorie / dbarts

Discrete Bayesian Additive Regression Trees Sampler

fitted trees #44

Closed topepo closed 2 years ago

topepo commented 2 years ago

Is there any way to access the tree structures (assuming keeptrees = TRUE)?

It would be really helpful to be able to do some characterizations (beyond varcount). For example, the tree size is printed, but it would be good to be able to compute some of these things directly.

vdorie commented 2 years ago

I've added the ability to call extract(bart_fit, "trees") and have written a short vignette on how to interpret the results.

topepo commented 2 years ago

Thank you. This is excellent.

One final question... are the trees used for prediction those from the last MCMC iteration? I assume that the prediction is based on a static set of trees and not some sampling of the posterior distribution of trees.

vdorie commented 2 years ago

The trees that are used when calling the predict function are all of the trees from every posterior sample drawn when the model is fit. These correspond to the trees used when passing a test argument and obtaining predictions as the sampler is running. New samples are not drawn each time the predict function is called.

If a model is not fit with keepTrees = TRUE, only the current set of trees is saved and those trees do correspond to the final posterior sample after model fitting. The R predict generic is disabled in this scenario, but the corresponding low-level implementation does exist.

topepo commented 2 years ago

Thanks. Just to clarify with an example, if I use

library(dbarts)
set.seed(1)
fit <-
  bart(
    mtcars[, -1],
    mtcars$mpg,
    keeptrees = TRUE,
    ntree = 11,
    nchain = 3,
    ndpost = 20,
    nskip = 10,
    verbose = FALSE
  )

I get back 11 * 3 * 20 = 660 trees. Are new (mean) predictions the averages of the predictions from these 660 trees?

vdorie commented 2 years ago

You would get, for each individual, 20 * 3 draws from the posterior distribution of the sum of their 11 trees. Something like $\hat{\mu}_{il} = \sum_{k=1}^{11} T_{kl}(x_i)$, where $i$ is a person index, $l$ is the sample index (running from 1 to 20 x 3), and $T_{kl}$ is a draw from the posterior of tree $k$ in sample $l$.

library(dbarts)
set.seed(1)

fit <-
  bart(
    mtcars[, -1],
    mtcars$mpg,
    keeptrees = TRUE,
    ntree = 11,
    nchain = 3,
    ndpost = 20,
    nskip = 10,
    verbose = FALSE
  )

mu <- predict(fit, mtcars)
stopifnot(dim(mu) == c(3 * 20, nrow(mtcars)))

trees <- extract(fit, "trees")
dim(trees)
# "trees" are actually all nodes
# [1] 2318    6

stopifnot(with(trees, length(unique(interaction(chain, sample, tree)))) == 660)
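To make the distinction concrete, here is a minimal base R sketch with simulated numbers (not dbarts output). The array `tree_preds` and its dimension order are assumptions made for illustration, and any centering or rescaling the package may apply to tree values is ignored. It sums over the 11 trees within each draw and only then averages over the 60 draws, so the posterior mean is not a plain average over 660 trees.

```r
# Simulated stand-in for per-tree predictions; these are random numbers,
# not values produced by dbarts.
set.seed(1)
n_chains  <- 3
n_samples <- 20
n_trees   <- 11
n_obs     <- 5

# tree_preds[c, s, k, i]: contribution of tree k in draw (c, s) for individual i
tree_preds <- array(
    rnorm(n_chains * n_samples * n_trees * n_obs),
    dim = c(n_chains, n_samples, n_trees, n_obs)
)

# Step 1: within each posterior draw, sum over the trees (3 x 20 x 5 array)
draw_preds <- apply(tree_preds, c(1, 2, 4), sum)

# Step 2: average over the 3 * 20 draws to get one value per individual
post_mean <- apply(draw_preds, 3, mean)
```

Because both steps are linear, a plain average over all 660 per-tree values differs from `post_mean` by a factor of `n_trees`.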

In order to get the predictions of each sample of each tree for each individual (which is an array of (n.chains x n.samples x n.trees) x n.individuals), the trees have to be traversed.

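getPredictionsForTree <- function(tree, x) {
    predictions <- rep(NA_real_, nrow(x))

    # Nodes are stored depth-first, so the first row of `tree` is always
    # the root of the current subtree. Returns the subtree's node count.
    getPredictionsForTreeRecursive <- function(tree, indices) {
        if (tree$var[1] == -1) {
            # Leaf: `value` is the prediction; <<- assigns into the
            # enclosing function's `predictions` vector
            predictions[indices] <<- tree$value[1]
            return(1)
        }

        # Internal node: `value` is the split point for variable `var`
        goesLeft <- x[indices, tree$var[1]] <= tree$value[1]
        headOfLeftBranch <- tree[-1,]
        n_nodes.left <- getPredictionsForTreeRecursive(
            headOfLeftBranch, indices[goesLeft])

        # The right branch starts after the root and the whole left subtree
        headOfRightBranch <- tree[seq.int(2 + n_nodes.left, nrow(tree)),]
        n_nodes.right <- getPredictionsForTreeRecursive(
            headOfRightBranch, indices[!goesLeft])

        return(1 + n_nodes.left + n_nodes.right)
    }

    getPredictionsForTreeRecursive(tree, seq_len(nrow(x)))

    predictions
}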

allPredictions <- by(
    trees,
    trees[,c("chain", "sample", "tree")],
    getPredictionsForTree,
    x = fit$fit$data@x
)

stopifnot(dim(allPredictions) == c(3, 20, 11))

stopifnot(length(allPredictions[[1]]) == nrow(mtcars))
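For the kind of characterization asked about at the top of the thread (tree size and similar), the node table can also be summarized directly. The sketch below uses a hand-built toy table rather than real extract(fit, "trees") output. It assumes only what the traversal code above relies on: the rows of each tree are stored depth-first with the left branch first, `var == -1` marks a leaf, and `chain`, `sample`, and `tree` identify the tree. `toyTrees` and `summarizeTree` are hypothetical names.

```r
# Toy node table with two trees (hand-built, not dbarts output).
# Tree 1: root split, left leaf, right split with two leaves. Tree 2: one leaf.
toyTrees <- data.frame(
    chain  = 1,
    sample = 1,
    tree   = c(1, 1, 1, 1, 1, 2),
    var    = c(1, -1, 2, -1, -1, -1),
    value  = c(0.5, -0.2, 1.0, 0.1, 0.3, 0.05)
)

summarizeTree <- function(tree) {
    maxDepth <- 0L

    # Visits the subtree rooted at row `pos` and returns the index of the
    # first row that lies after that subtree.
    walk <- function(pos, depth) {
        maxDepth <<- max(maxDepth, depth)
        if (tree$var[pos] == -1) return(pos + 1L)  # leaf
        pos <- walk(pos + 1L, depth + 1L)          # left subtree
        walk(pos, depth + 1L)                      # right subtree
    }
    walk(1L, 0L)

    data.frame(
        n_nodes  = nrow(tree),
        n_leaves = sum(tree$var == -1),
        depth    = maxDepth
    )
}

# One summary row per (chain, sample, tree) combination
groups <- split(toyTrees, toyTrees[c("chain", "sample", "tree")], drop = TRUE)
sizes <- do.call(rbind, lapply(groups, summarizeTree))
```

With real output, the same `split` and `lapply` pattern should apply to the `trees` data frame, provided the column layout matches these assumptions.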