JuliaAI / DecisionTree.jl

Julia implementation of Decision Tree (CART) and Random Forest algorithms

Is the out-of-bag error of RandomForestClassifier implementable? #223

Closed: bentriom closed this issue 1 year ago

bentriom commented 1 year ago

Hello,

Thanks for this Julia implementation of decision trees.

I tried to compute the out-of-bag error of a RandomForestClassifier. It is a widely used estimate of prediction error for ensemble methods. I haven't found any function that implements it, so I tried to implement it myself. I dove into the code of build_forest and saw that neither the indices of the bootstrap samples nor the random number generator is stored in the mutable struct Ensemble (reachable via RandomForestClassifier.ensemble).
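To make the missing bookkeeping concrete, here is a minimal sketch in plain Julia (standard library only, not DecisionTree.jl code) of what would have to be kept per tree: the bootstrap (in-bag) indices, from which each tree's out-of-bag rows follow. The function `bootstrap_split` and its sample-size argument `m` are hypothetical names chosen for illustration.

```julia
using Random, Statistics

# Hypothetical helper: draw m row indices from 1:n with replacement (the
# in-bag sample) and return them together with the rows never drawn (OOB).
function bootstrap_split(rng::AbstractRNG, n::Int, m::Int = n)
    inbag = rand(rng, 1:n, m)      # sampling with replacement
    isinbag = falses(n)
    isinbag[inbag] .= true
    oob = findall(!, isinbag)      # rows absent from this bootstrap sample
    return inbag, oob
end

rng = MersenneTwister(123)
n = 150
inbags = Vector{Vector{Int}}()     # per-tree in-bag indices (what OOB needs)
oobs = Vector{Vector{Int}}()       # per-tree out-of-bag indices
for t in 1:10
    inbag, oob = bootstrap_split(rng, n)
    push!(inbags, inbag)
    push!(oobs, oob)
end

# With m = n, roughly (1 - 1/n)^n ≈ 1/e ≈ 36.8% of rows are OOB per tree.
println(mean(length.(oobs)) / n)
```

Storing the vectors in `inbags` (or, alternatively, the RNG state that generated them) is exactly the information an out-of-bag estimate needs after fitting.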

Did I miss something that would give me the bootstrap sample indices of each tree in a fitted forest? If not, is this information deliberately not stored, or could this be a new feature for the package?

Thanks.

ablaom commented 1 year ago

@bentriom Thanks for digging into the weeds of DecisionTree.jl.

I agree with your assessment: implementing out-of-bag error estimates for the random forest models would require a non-trivial complication of the current design (and a decent amount of work, taking testing into account).

However, the concept of out-of-bag error estimates makes sense for any homogeneous ensemble model (not just forests), so I think this package isn't the best place to implement it.
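Because an out-of-bag estimate only needs "fit on the in-bag rows, predict on the rest", it can be written once for any bagged base learner. The following is a minimal, hypothetical plain-Julia sketch, not the MLJEnsembles implementation: `oob_error`, `fit_centroid` and `predict_centroid` are illustrative names, the base learner is a nearest-centroid classifier standing in for a decision tree, predictions are aggregated by majority vote, and the measure is the misclassification rate.

```julia
using Random, Statistics

# Label with the most votes in a Dict(label => count).
majority(d) = first(reduce((a, b) -> last(b) > last(a) ? b : a, collect(d)))

# Generic OOB misclassification rate for a bagged ensemble. `fit(X, y)` must
# return a model and `predict(model, X)` a vector of labels (caller-supplied).
function oob_error(fit, predict, X::AbstractMatrix, y::AbstractVector;
                   n_models::Int = 50, rng::AbstractRNG = MersenneTwister(0))
    n = length(y)
    votes = [Dict{eltype(y),Int}() for _ in 1:n]  # OOB votes collected per row
    for _ in 1:n_models
        inbag = rand(rng, 1:n, n)                 # bootstrap sample
        isinbag = falses(n); isinbag[inbag] .= true
        oob = findall(!, isinbag)
        isempty(oob) && continue
        model = fit(X[inbag, :], y[inbag])        # train on in-bag rows only
        for (i, label) in zip(oob, predict(model, X[oob, :]))
            votes[i][label] = get(votes[i], label, 0) + 1
        end
    end
    scored = findall(!isempty, votes)             # rows OOB at least once
    isempty(scored) && return NaN
    return count(i -> majority(votes[i]) != y[i], scored) / length(scored)
end

# Hypothetical base learner for the demo: nearest class centroid.
function fit_centroid(X, y)
    labels = unique(y)
    centroids = [vec(mean(X[y .== l, :]; dims = 1)) for l in labels]
    return (labels, centroids)
end

function predict_centroid(model, X)
    labels, centroids = model
    return [labels[argmin([sum(abs2, r .- c) for c in centroids])] for r in eachrow(X)]
end

# Two well-separated Gaussian blobs as toy data.
data_rng = MersenneTwister(42)
X = vcat(randn(data_rng, 50, 2), randn(data_rng, 50, 2) .+ 4.0)
y = vcat(fill(:a, 50), fill(:b, 50))
err = oob_error(fit_centroid, predict_centroid, X, y;
                n_models = 25, rng = MersenneTwister(1))
println("OOB misclassification rate: ", err)
```

Using a different base learner only means passing different `fit` and `predict` functions, which is the model-generic design that `out_of_bag_measure` in MLJEnsembles follows.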

In fact, MLJEnsembles.jl already provides the functionality you want. Here's an example:

```julia
using MLJ # or `MLJBase, MLJModels, MLJEnsembles` for minimal install

DecisionTreeClassifier = @iload DecisionTreeClassifier pkg=DecisionTree

atom = DecisionTreeClassifier()
model = EnsembleModel(
    atom;
    bagging_fraction=0.6,
    rng=123,
    out_of_bag_measure = [log_loss, brier_score],
)

X, y = @load_iris # a table and a vector

mach = machine(model, X, y) |> fit!

# julia> report(mach).oob_measurements
# 2-element Vector{Float64}:
#   2.1866483056064396
#  -0.12133333333333318
```

There may be some performance hit from using the MLJ interface (thanks to your post, I discovered that the multithreading option in MLJEnsembles is broken). However, in my view, energy spent improving that model-generic code is better invested than energy spent adding features to DecisionTree.jl. (As an author/maintainer of MLJ, I am of course biased.)

What do you think?

bentriom commented 1 year ago

Hello,

Thank you for this thorough answer. Through your post I discovered this excellent feature of EnsembleModel. I agree with you: it is better to improve the more generic method!

Thanks again.