ModelOriented / randomForestExplainer

A set of tools to understand what is happening inside a Random Forest
https://ModelOriented.github.io/randomForestExplainer/

Exponentially faster tree depth #38

Closed mayer79 closed 8 months ago

mayer79 commented 8 months ago

This PR brings an exponential speed-up in calculating the minimal depth distribution. The gain is largest for very deep trees, such as those in random forests fitted on larger data sets.

The trick is to loop over tree depth instead of looping over tree nodes.

As such, it solves #34 and builds on PR #35, which unfortunately brought a speed-up mainly for small trees.
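
To illustrate the idea of looping over tree depth, here is a minimal sketch in base R. It is not the code from this PR: the toy `tree` data frame, its column names (`left`, `right`, `split_var`), and the helper `node_depths()` are hypothetical and chosen only for illustration. It assumes one row per node, the root in row 1, and `0` marking a missing child.

```r
# Toy tree in a hypothetical layout (not the package's internal format):
# one row per node, row 1 is the root, left/right hold child row numbers
# (0 = no child), split_var is NA for leaves.
tree <- data.frame(
  left      = c(2, 4, 6, 0, 0, 8, 0, 0, 0),
  right     = c(3, 5, 7, 0, 0, 9, 0, 0, 0),
  split_var = c("clarity", "carat", "color", NA, NA, "cut", NA, NA, NA)
)

# Assign a depth to every node by walking the tree one level at a time.
node_depths <- function(tree) {
  depth <- rep(NA_integer_, nrow(tree))
  frontier <- 1L  # nodes on the current level, starting with the root
  d <- 0L
  while (length(frontier) > 0L) {
    depth[frontier] <- d
    # Collect the children of all nodes on this level in one vectorized step.
    children <- c(tree$left[frontier], tree$right[frontier])
    frontier <- children[children > 0]
    d <- d + 1L
  }
  depth
}

depth <- node_depths(tree)

# Minimal depth per variable: the shallowest node that splits on it.
is_split <- !is.na(tree$split_var)
min_depth <- tapply(depth[is_split], tree$split_var[is_split], min)
print(min_depth)
```

In this sketch the `while` loop runs once per tree level (four times for the toy tree) rather than once per node, and each iteration handles all nodes of that level with vectorized indexing.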

library(randomForest)
library(randomForestExplainer)
library(ranger)
library(ggplot2)  # provides the diamonds data set

set.seed(12)

# Random forest
fit <- randomForest(price ~ carat + color + cut + clarity, data = diamonds, ntree = 100)
system.time(  # 24s before this PR -> 0.6s with it
  out <- min_depth_distribution(fit)
)
head(out)
#   tree variable minimal_depth
# 1    1    carat             2
# 2    1  clarity             0
# 3    1    color             2
# 4    1      cut             3
# 5    2    carat             2
# 6    2  clarity             3

# Ranger (seems to grow much deeper trees)
fit2 <- ranger(
  price~carat+color+cut+clarity, data = diamonds,
  num.trees = 100,
  max.depth = 10, # without this, the original depth calculation won't stop
  seed = 1
)
system.time(  # 19s before this PR -> 0.1s with it
  out <- min_depth_distribution(fit2)
)
head(out)
#   tree variable minimal_depth
# 1    1    carat             1
# 2    1  clarity             0
# 3    1    color             2
# 4    1      cut             2
# 5    2    carat             2

pbiecek commented 8 months ago

binary search wins! thanks