imbs-hl / ranger

A Fast Implementation of Random Forests
http://imbs-hl.github.io/ranger/

Feature request: option to remove/delete trees from "ranger" object #568

Closed (bgreenwell closed this 2 years ago)

bgreenwell commented 3 years ago

Hi @mnwright ,

I'm writing to request a new feature in ranger, if feasible. In particular, it would be extremely useful to be able to remove/trim specific trees from a fitted "ranger" object, or, at the very least, to be able to choose which trees are used in making predictions (rather than just the first num.trees trees). The primary motivation is to produce a potentially more accurate and leaner model. The general idea is discussed in Friedman and Popescu (2003), who refer to it as "importance sampled learning ensembles" (ISLE); it is also briefly discussed in chapter 16 of The Elements of Statistical Learning. The basic idea is to use the LASSO to post-process a tree ensemble in the hope of producing a much smaller model that's faster to score without sacrificing much in the way of accuracy, and in some cases even improving it.
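
Concretely (following chapter 16 of The Elements of Statistical Learning, with squared-error loss to match the regression example below), the post-processing step solves

$$\min_{\alpha_0, \alpha_1, \ldots, \alpha_M} \sum_{i=1}^N \Big( y_i - \alpha_0 - \sum_{m=1}^M \alpha_m T_m(x_i) \Big)^2 + \lambda \sum_{m=1}^M \lvert \alpha_m \rvert,$$

where $T_1(x), \ldots, T_M(x)$ are the trees in the ensemble; any tree with a zero coefficient at the chosen $\lambda$ could then, in principle, be dropped from the model entirely.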

A basic example using ranger is given below. However, it seems that the true benefit (i.e., a smaller and faster scoring model) could only be realized if we can actually discard the trees zeroed-out by the LASSO. As far as I'm aware, this is not possible in ranger, or any other tree-based ensemble package in R. Thoughts?

#
# Importance sampled learning ensemble (ISLE)
#
# An example of using the LASSO to post-process tree ensembles; in this case, a
# random forest fit using the {ranger} package.
#
# Source: https://statweb.stanford.edu/~jhf/ftp/isle.pdf (also see
# chapter 16 of The Elements of Statistical Learning)
#
# Required pkgs: AmesHousing, glmnet, ranger
#

################################################################################
# Setup
################################################################################

# Load required packages
library(glmnet)  # for post-processing
library(ranger)  # for random forest algorithm

# Load the Ames housing data
ames <- as.data.frame(AmesHousing::make_ames())
ames$Sale_Price <- ames$Sale_Price / 1000  # rescale response
set.seed(4919)
id <- sample.int(nrow(ames), size = floor(0.7 * nrow(ames)))
ames.trn <- ames[id, ]
ames.tst <- ames[-id, ]
xtst <- subset(ames.tst, select = -Sale_Price)  # features only

# Helper function for computing MSE as a function of number of trees
mse <- function(object, X, y) {
  p <- predict(object, data = X, predict.all = TRUE)$predictions
  sapply(seq_len(ncol(p)), FUN = function(i) {
    pred <- rowMeans(p[, seq_len(i), drop = FALSE])  # average of first i trees
    mean((pred - y) ^ 2)
  })
}

################################################################################
# Fit random forests
################################################################################

# Fit a default random forest (RFO)
set.seed(942)  # for reproducibility
system.time({
  rfo <- ranger(Sale_Price ~ ., data = ames.trn, num.trees = 1000)
})
#  user  system elapsed 
# 5.179   0.200   1.595

# Fit a random forest using shallow (depth-4) trees on 5% samples (RFO.4.5)
set.seed(1021)
system.time({
  rfo.4.5 <- ranger(Sale_Price ~ ., data = ames.trn, num.trees = 1000, 
                    max.depth = 4, sample.fraction = 0.05)
})
#  user  system elapsed 
# 0.268   0.006   0.110

# Test set MSE as a function of the number of trees
mse.rfo <- mse(rfo, X = xtst, y = ames.tst$Sale_Price)
mse.rfo.4.5 <- mse(rfo.4.5, X = xtst, y = ames.tst$Sale_Price)

################################################################################
# Post-process random forests
################################################################################

# LASSO-based post-processing function
postProcess <- function(X, y, newX, newy, offset = NULL, ...) {
  # Fit the LASSO path via coordinate descent; lower.limits = 0 constrains
  # the coefficients to be non-negative, so each retained tree contributes
  # positively to the ensemble
  lasso <- glmnet(X, y = y, intercept = TRUE, lower.limits = 0,
                  standardize = FALSE, offset = offset, ...)
  mse <- assess.glmnet(lasso, newx = newX, newy = newy, 
                       family = "gaussian", newoffset = offset)$mse
  non.zero <- predict(lasso, newx = newX, type = "nonzero", 
                      newoffset = offset)
  ntree <- sapply(non.zero, FUN = length)
  res <- as.data.frame(cbind("ntree" = ntree, "mse" = mse))
  res <- res[order(res[["ntree"]], decreasing = FALSE), ]
  # Return minimum MSE for each value of ntree
  aggregate(mse ~ ntree, data = res, FUN = min)
}

# Post-process RFO ensemble
preds.trn <- predict(rfo, data = ames.trn, predict.all = TRUE)$predictions
preds.tst <- predict(rfo, data = ames.tst, predict.all = TRUE)$predictions
rfo.post <- postProcess(preds.trn, y = ames.trn$Sale_Price, 
                        newX = preds.tst, newy = ames.tst$Sale_Price)

# Post-process RFO.4.5 ensemble
preds.trn.4.5 <- predict(rfo.4.5, data = ames.trn, predict.all = TRUE)$predictions
preds.tst.4.5 <- predict(rfo.4.5, data = ames.tst, predict.all = TRUE)$predictions
rfo.4.5.post <- postProcess(preds.trn.4.5, y = ames.trn$Sale_Price,
                            newX = preds.tst.4.5, newy = ames.tst$Sale_Price)

# Plot results
palette("Okabe-Ito")
plot(mse.rfo, type = "l", ylim = range(mse.rfo, mse.rfo.4.5),
     las = 1, xlab = "Number of trees", ylab = "Test MSE")
lines(mse.rfo.4.5, col = 2)
lines(rfo.post, lty = 2)
lines(rfo.4.5.post, col = 2, lty = 2)
legend("topright", legend = c("RF", "RF.4.5", "RF (post)", "RF.4.5 (post)"),
       col = c(1, 2, 1, 2), lty = c(1, 1, 2, 2), inset = 0.01, bty = "n")
palette("default")

[Figure: test-set MSE as a function of the number of trees for each forest and its LASSO post-processed version]

mnwright commented 3 years ago

A function to remove (or keep?) some of the trees would need to change child.nodeIDs, split.varIDs and split.values in the $forest part of the object. It should also change predictions and prediction.error.
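
For reference, each of these components is stored in $forest as a list with one element per tree, so removing trees amounts to subsetting those lists in parallel. A quick sanity check against the rfo object fit above:

# Each forest component is a per-tree list (using `rfo` from the example above)
length(rfo$forest$child.nodeIDs)  # [1] 1000
length(rfo$forest$split.varIDs)   # [1] 1000
length(rfo$forest$split.values)   # [1] 1000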

PR very welcome! :)

bgreenwell commented 3 years ago

Thanks for the feedback @mnwright; this seems rather straightforward. An example is given below using the fitted random forest from the example above. If you're good with it, I'll put together a PR with documentation, etc. However, I don't see how the predictions and prediction.error components can be updated properly, since they're returned from the underlying C++ code and are based on the aggregated predictions from the full forest. For now, I just have it spit out a warning regarding those components.

# Generic in case it can be extended to other random forest packages
deforest <- function(object, which.trees = NULL) {
  UseMethod("deforest")
}
deforest.ranger <- function(object, which.trees = NULL, warn = TRUE) {

  # Warn users about `predictions` and `prediction.error` components
  if (isTRUE(warn)) {
    warning("The `predictions` and `prediction.error` components of the ",
            "returned object are no longer correct as they correspond to the ",
            "original forest (with all trees).", call. = FALSE)
  }

  # "Remove trees" by removing necessary components from `forest` object
  object$forest$child.nodeIDs[which.trees] <- NULL
  object$forest$split.values[which.trees] <- NULL
  object$forest$split.varIDs[which.trees] <- NULL

  # Update `num.trees` components so `predict.ranger()` works
  object$forest$num.trees <- object$num.trees <- 
    length(object$forest$child.nodeIDs)

  # Return "deforested" forest
  object

}

# Remove first 999 trees (leaving only the last tree)
rfo.deforest <- deforest(rfo, which.trees = 1:999)
# Warning message:
# The `predictions` and `prediction.error` components of the returned object are 
# no longer correct as they correspond to the original forest (with all trees).

# Stack Ames training data 500 times (N = 1,025,500)
ames.big <- do.call("rbind", args = replicate(500, ames.trn, simplify = FALSE))

# Check scoring times
system.time({p <- predict(rfo, data = ames.big)})
#   user  system elapsed 
# 372.25    2.64   36.81
system.time({p.deforest <- predict(rfo.deforest, data = ames.big)})
# user  system elapsed 
# 3.71    0.53    3.92

# Sanity check
p <- predict(rfo, data = ames.trn[1L:10L, ], predict.all = TRUE)$predictions
p.deforest <- predict(rfo.deforest, data = ames.trn[1L:10L, ], predict.all = TRUE)$predictions
identical(p[, -(1:999), drop = FALSE], p.deforest)
# [1] TRUE
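
To close the loop on the original request, here's a rough sketch (illustrative only, not a polished API) of using the LASSO path to decide which trees to deforest; keep.trees and drop.trees are just made-up names here. One caveat: the deforested forest weights the kept trees equally, whereas the LASSO model weights them by its fitted coefficients, so the two won't give identical predictions.

# Sketch: keep the trees with nonzero LASSO coefficients at the best step
# along the path and deforest the rest (reusing objects from above)
lasso <- glmnet(preds.trn, y = ames.trn$Sale_Price, intercept = TRUE,
                lower.limits = 0, standardize = FALSE)
mse.path <- assess.glmnet(lasso, newx = preds.tst, newy = ames.tst$Sale_Price,
                          family = "gaussian")$mse
best <- which.min(mse.path)  # path step with the lowest test MSE
keep.trees <- predict(lasso, type = "nonzero")[[best]]
drop.trees <- setdiff(seq_len(rfo$num.trees), keep.trees)
rfo.small <- deforest(rfo, which.trees = drop.trees)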

mnwright commented 3 years ago

Looks good. For predictions and prediction.error, we would need keep.inbag = TRUE to re-calculate them in R. Maybe set them to NA or something, in addition to the warning?
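
Something along these lines might work for the re-calculation (a rough sketch only; it assumes the forest is re-fit with keep.inbag = TRUE, and kept is a hypothetical index vector of the trees retained after deforesting):

# Re-compute the OOB components in R from the in-bag counts (sketch)
rfo.ib <- ranger(Sale_Price ~ ., data = ames.trn, num.trees = 1000,
                 keep.inbag = TRUE)
kept <- 501:1000  # e.g., suppose these are the trees that were retained
inbag <- simplify2array(rfo.ib$inbag.counts)  # n x num.trees matrix of counts
preds <- predict(rfo.ib, data = ames.trn, predict.all = TRUE)$predictions
oob <- ifelse(inbag[, kept] == 0, preds[, kept], NA)  # keep OOB predictions only
rfo.ib$predictions <- rowMeans(oob, na.rm = TRUE)
rfo.ib$prediction.error <- mean(
  (rfo.ib$predictions - ames.trn$Sale_Price) ^ 2, na.rm = TRUE
)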