Closed bgreenwell closed 2 years ago
A function to remove (or keep?) some of the trees would need to change `child.nodeIDs`, `split.varIDs`, and `split.values` in the `$forest` part of the object. It should also change `predictions` and `prediction.error`.

PR very welcome! :)
Thanks for the feedback @mnwright, seems rather straightforward. Example given below using the fitted random forest from the example above. If you're good with it, I'll put together a PR with documentation, etc. However, I don't see how the `predictions` and `prediction.error` components can be updated properly, since they seem to be returned from the underlying C code and are based on the aggregated predictions from the full forest. For now, I have it just spit out a warning regarding those components.
```r
# Generic in case it can be extended to other random forest packages
deforest <- function(object, which.trees = NULL) {
  UseMethod("deforest")
}

deforest.ranger <- function(object, which.trees = NULL, warn = TRUE) {
  # Warn users about `predictions` and `prediction.error` components
  if (isTRUE(warn)) {
    warning("The `predictions` and `prediction.error` components of the ",
            "returned object are no longer correct as they correspond to the ",
            "original forest (with all trees).", call. = FALSE)
  }
  # "Remove trees" by removing the necessary components from the `forest` object
  object$forest$child.nodeIDs[which.trees] <- NULL
  object$forest$split.values[which.trees] <- NULL
  object$forest$split.varIDs[which.trees] <- NULL
  # Update `num.trees` components so `predict.ranger()` works
  object$forest$num.trees <- object$num.trees <-
    length(object$forest$child.nodeIDs)
  # Return "deforested" forest
  object
}
```
```r
# Remove first 999 trees (leaving only the last tree)
rfo.deforest <- deforest(rfo, which.trees = 1:999)
# Warning message:
# The `predictions` and `prediction.error` components of the returned object are
# no longer correct as they correspond to the original forest (with all trees).

# Stack Ames training data 500 times (N = 1,025,500)
ames.big <- do.call("rbind", args = replicate(500, ames.trn, simplify = FALSE))

# Check scoring times
system.time({p <- predict(rfo, data = ames.big)})
#    user  system elapsed
#  372.25    2.64   36.81
system.time({p.deforest <- predict(rfo.deforest, data = ames.big)})
#    user  system elapsed
#    3.71    0.53    3.92

# Sanity check
p <- predict(rfo, data = ames.trn[1L:10L, ], predict.all = TRUE)$predictions
p.deforest <- predict(rfo.deforest, data = ames.trn[1L:10L, ],
                      predict.all = TRUE)$predictions
identical(p[, -(1:999), drop = FALSE], p.deforest)
# [1] TRUE
```
Looks good. For `predictions` and `prediction.error` we would need `keep.inbag = TRUE` to re-calculate in R. Maybe set them to NA or something in addition to the warning?
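For illustration, a rough sketch of how those components might be re-computed in R for the surviving trees, assuming a regression forest that was fit with `keep.inbag = TRUE` (the objects `rfo`, `rfo.deforest`, and `ames.trn`, and the response column `Sale_Price`, are assumed from the example above; this is untested pseudocode-level R, not a definitive implementation):

```r
# In-bag counts per tree as an n-by-num.trees matrix; keep only surviving trees
inbag <- simplify2array(rfo$inbag.counts)[, -(1:999), drop = FALSE]

# Per-tree predictions on the training data from the deforested model
p.all <- predict(rfo.deforest, data = ames.trn, predict.all = TRUE)$predictions

# OOB prediction for each row: average over trees where that row was out of bag
p.all[inbag > 0] <- NA
preds <- rowMeans(p.all, na.rm = TRUE)  # NaN if a row is in-bag for every tree

# OOB MSE (regression); `Sale_Price` is an assumed response name
pred.error <- mean((preds - ames.trn$Sale_Price)^2, na.rm = TRUE)
```

With few remaining trees, some rows may be in-bag for all of them, which is presumably why setting the components to `NA` (plus the warning) is the safer default.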
Hi @mnwright,

I'm writing to request a new feature in ranger, if feasible. In particular, it would be extremely useful to be able to remove/trim specific trees from the fitted `"ranger"` object or, at the very least, to be able to choose which trees are used in making predictions (rather than just the first `num.trees` trees). The primary reason is to be able to produce a potentially more accurate and leaner model.

The more general idea is discussed in Friedman and Popescu (2003), which they refer to as "importance sampled learning ensembles" (ISLE); the idea is also briefly discussed in chapter 16 of The Elements of Statistical Learning. The basic idea is to use the LASSO to post-process a tree ensemble in the hopes of producing a much smaller model that's faster to score without sacrificing much in the way of accuracy and, in some cases, improving it.

A basic example using ranger is given below. However, it seems that the true benefit (i.e., a smaller and faster-scoring model) could only be realized if we can actually discard the trees zeroed out by the LASSO. As far as I'm aware, this is not possible in ranger, or any other tree-based ensemble package in R. Thoughts?