JuliaAI / DecisionTree.jl

Julia implementation of Decision Tree (CART) and Random Forest algorithms

Custom stopping criteria and loss functions #211

Status: Open. fipelle opened this issue 1 year ago.

fipelle commented 1 year ago

Hi,

I can't seem to find the documentation for creating custom stopping criteria (ideally for ensembles) and loss functions. Could you please point me in the right direction? Thanks!

ablaom commented 1 year ago

I'm not sure this is provided by this package, but you can get it using the MLJ wrapper:

- IteratedModel docs
- MNIST / Flux example of iterative model control

Is this what you're after?

ablaom commented 1 year ago

Mmm... I see warm restart has not been implemented for the wrapper, which will make run time very slow for large numbers of iterations. I've posted https://github.com/JuliaAI/MLJDecisionTreeInterface.jl/issues/40 in response.
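To make the run-time concern concrete, here is a rough back-of-the-envelope sketch. It assumes (hypothetically) that training cost is proportional to the number of trees fitted, and that without warm restart each control cycle refits the whole forest from scratch. The function names are made up for illustration.

```julia
# Hypothetical cost model: count the trees trained when a forest is grown
# to `n_max` trees in increments of `step` trees.

# Without warm restart, every cycle refits the whole forest from scratch,
# so we pay for forests of size step, 2step, ..., n_max.
trees_trained_cold(n_max::Int, step::Int) = sum(step:step:n_max)

# With warm restart, each cycle only fits the `step` new trees.
trees_trained_warm(n_max::Int, step::Int) = length(step:step:n_max) * step

println(trees_trained_cold(100, 10))  # 550
println(trees_trained_warm(100, 10))  # 100
```

Under these assumptions the cold-start cost grows quadratically in the number of control cycles, while warm restart keeps it linear.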

fipelle commented 1 year ago

I am trying to use a random forest classifier with:

Similar features are implemented in other packages such as LightGBM. See, for instance, the links below:

I was hoping to be able to do something similar with DecisionTree.jl directly.

ablaom commented 1 year ago

Yes, I understand. I just don't think that functionality exists here. I'll leave the issue open, and perhaps someone will add it.

For my part, I'd rather prioritise model-generic solutions for controlling iterative models, which is what MLJIteration provides. That way we avoid a lot of duplicated effort.
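As a rough illustration of what "model-generic" means here, the sketch below is a hypothetical early-stopping loop that needs only two callbacks: one that trains the model for one more cycle (for a forest, adding trees) and one that returns an out-of-sample loss. It is similar in spirit to MLJIteration's `NumberSinceBest` control, but it is not MLJIteration's implementation, and all names are invented for the example.

```julia
# Hypothetical model-generic early stopping (not MLJIteration's code).
# `train_more!()` advances the model by one control cycle;
# `out_of_sample_loss()` returns a validation loss for the current state.
# Stops once the best loss was seen `patience` cycles ago, or after `max_cycles`.
function train_with_early_stopping!(train_more!, out_of_sample_loss;
                                    patience::Int = 3, max_cycles::Int = 100)
    best_loss = Inf
    best_cycle = 0
    cycle = 0
    while cycle < max_cycles
        cycle += 1
        train_more!()
        loss = out_of_sample_loss()
        if loss < best_loss
            best_loss, best_cycle = loss, cycle
        elseif cycle - best_cycle >= patience
            break  # no improvement during the last `patience` cycles
        end
    end
    return (cycles = cycle, best_cycle = best_cycle, best_loss = best_loss)
end

# Usage with a fake model whose validation loss bottoms out at cycle 4.
losses = [0.9, 0.7, 0.6, 0.55, 0.56, 0.58, 0.57, 0.6, 0.61]
state = Ref(0)
result = train_with_early_stopping!(() -> state[] += 1,
                                    () -> losses[state[]]; patience = 3)
println(result)  # (cycles = 7, best_cycle = 4, best_loss = 0.55)
```

The loop knows nothing about trees or forests, which is the point: the same control logic works for any model that can be trained incrementally and scored.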

fipelle commented 1 year ago

@ablaom I think I figured out how to do it using the native APIs.

In the case of classification trees, this is easy enough. All you need to do is something along the lines of `build_tree(labels, features, loss=(ns, n)->custom_loss(ns, n, args...))`. In the case of the Gini impurity, `ns` is the vector of (possibly weighted) sample counts per class at the node being evaluated, and `n` is their total, `sum(ns)`, rather than the number of classes. Of course, correct me if I am wrong.
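A minimal sketch of such loss functions in plain Julia follows. It assumes the `(ns, n)` signature used in the call above and takes `n` to be the total (possibly weighted) count at the node, `sum(ns)`, which is what the Gini formula requires. `gini_loss`, `tsallis_loss` and `my_loss` are hypothetical names, not part of DecisionTree.jl.

```julia
# Losses using the (ns, n) convention: `ns` holds the class counts at a node
# and `n` is their total (assumed to equal sum(ns)).

# Gini impurity, 1 - sum(p_k^2), written as sum(k * (n - k)) / n^2.
function gini_loss(ns::AbstractVector{<:Real}, n::Real)
    s = zero(float(n))
    for k in ns
        s += k * (n - k)
    end
    return s / (n * n)
end

# Hypothetical parameterised loss: Tsallis entropy of order q,
# (1 - sum(p_k^q)) / (q - 1). Setting q = 2 recovers the Gini impurity.
function tsallis_loss(ns::AbstractVector{<:Real}, n::Real, q::Real)
    q == 1 && error("q = 1 is the Shannon limit; use an entropy loss instead")
    s = sum(k -> (k / n)^q, ns)
    return (1 - s) / (q - 1)
end

# Extra arguments are captured in a closure, so the result has the (ns, n) shape.
my_loss = (ns, n) -> tsallis_loss(ns, n, 1.5)

println(gini_loss([5, 5], 10))          # 0.5
println(tsallis_loss([5, 5], 10, 2.0))  # 0.5
```

Assuming the keyword call shown above, either function could then be passed as, e.g., `build_tree(labels, features; loss = my_loss)`. Capturing extra parameters in a closure keeps the two-argument shape the tree builder expects.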

In the case of a random forest, things are a little more complicated. Is there any way around writing a custom `build_forest` function with a user-defined loss instead of the default `util.entropy(ns, n, entropy_terms)`? Changes to the bootstrap samples (when needed) would, I suppose, have to be implemented by modifying `inds` at https://github.com/JuliaAI/DecisionTree.jl/blob/f57a15633f5aadadfc408a8d5e42836e1f011c3f/src/classification/main.jl#L378 and https://github.com/JuliaAI/DecisionTree.jl/blob/f57a15633f5aadadfc408a8d5e42836e1f011c3f/src/classification/main.jl#L394.

EDIT

I think it would be nice to extend DecisionTree, or to have a small package with more flexible versions of `build_forest` that accommodate custom usage. What do you think? I am happy to create a new package if you'd prefer to keep things separate (a bit like StatsPlots.jl).
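To make the idea concrete, here is a minimal sketch of what a more flexible forest builder could look like, with both the per-tree fitting function and the bootstrap sampler passed in as arguments. `build_forest_custom`, `uniform_bootstrap` and `majority_stub` are hypothetical names. The stub stands in for a real tree so the example runs without DecisionTree.jl, and the sketch omits `build_forest` options such as `partial_sampling` and per-split feature subsampling.

```julia
using Random

# Default sampler: an ordinary bootstrap of the same size as the data.
uniform_bootstrap(rng, labels) = rand(rng, 1:length(labels), length(labels))

# Hypothetical flexible forest builder (not DecisionTree.jl's API).
# `fit_tree(labels, features)` fits one tree, so the caller chooses the loss;
# `sample_inds(rng, labels)` returns the row indices used for that tree.
function build_forest_custom(fit_tree, labels::AbstractVector, features::AbstractMatrix;
                             n_trees::Int = 10,
                             sample_inds = uniform_bootstrap,
                             rng::AbstractRNG = Random.default_rng())
    return map(1:n_trees) do _
        inds = sample_inds(rng, labels)  # pluggable resampling step
        fit_tree(labels[inds], features[inds, :])
    end
end

# Stand-in for a real tree: returns the majority label of its sample.
function majority_stub(l, f)
    labs = unique(l)
    counts = [count(==(c), l) for c in labs]
    return labs[argmax(counts)]
end

rng = MersenneTwister(1)
labels = vcat(fill("a", 30), fill("b", 10))
features = randn(rng, 40, 3)

forest = build_forest_custom(majority_stub, labels, features; n_trees = 5, rng = rng)

# A custom sampler that only ever uses the first 20 rows (all labelled "a").
forest_a = build_forest_custom(majority_stub, labels, features;
                               n_trees = 3, sample_inds = (rng, l) -> collect(1:20))
println(forest_a)  # ["a", "a", "a"]
```

With DecisionTree.jl loaded, `fit_tree` could be something like `(l, f) -> build_tree(l, f; loss = my_custom_loss)` (assuming `build_tree` accepts a `loss` keyword, as described above, with `my_custom_loss` a hypothetical user function). Predictions would still need to be aggregated across trees, e.g. by majority vote.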

ablaom commented 1 year ago

> In the case of a random forest, things are a little more complicated. Is there any way around writing a custom build_forest function with a user-defined loss instead of the default util.entropy(ns, n, entropy_terms)?

Yes, I also recently discovered that the loss parameter is only exposed for single trees, and not forests. I'd definitely support fixing this and will open an issue.

fipelle commented 1 year ago

@ablaom I have almost finished writing a custom implementation that allows for custom bootstrapping as well (e.g., stratified sampling). Do you think it would be best to keep it separate or would you accept a pull request with it as well?
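For reference, a stratified bootstrap can be sketched in a few lines of plain Julia: resample with replacement within each class, so every bootstrap sample keeps the original class counts. `stratified_bootstrap_inds` is a hypothetical helper, not a DecisionTree.jl or MLJ function; its output would stand in for the uniformly drawn `inds` used when building each tree.

```julia
using Random

# Hypothetical stratified bootstrap: resample with replacement *within* each
# class, so each class keeps its original count in every bootstrap sample.
function stratified_bootstrap_inds(rng::AbstractRNG, labels::AbstractVector)
    inds = Int[]
    for c in unique(labels)
        class_rows = findall(==(c), labels)
        append!(inds, rand(rng, class_rows, length(class_rows)))
    end
    return shuffle!(rng, inds)  # interleave the classes
end

rng = MersenneTwister(42)
labels = vcat(fill(:a, 8), fill(:b, 2))
inds = stratified_bootstrap_inds(rng, labels)
println(count(==(:b), labels[inds]))  # 2
```

With an ordinary bootstrap, a rare class can be missing from some samples entirely; stratifying guarantees it appears in each tree's training data with its original frequency.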

ablaom commented 1 year ago

Glad to hear about the progress. To reduce the maintenance burden on this package, I'd prefer not to add model-generic functionality to the package itself. MLJ and other toolboxes already provide things like stratified resampling.

For example:

```julia
using MLJ

X, y = make_blobs(centers=5)

Tree = @load DecisionTreeClassifier pkg=DecisionTree

tree = Tree()

julia> evaluate(
       tree,
       X,
       y,
       resampling=StratifiedCV(nfolds=5),
       measure=LogLoss(),
       )
PerformanceEvaluation object with these fields:
  measure, operation, measurement, per_fold,
  per_observation, fitted_params_per_fold,
  report_per_fold, train_test_rows
Extract:
┌────────────────────────────────┬───────────┬─────────────┬─────────┬──────────────────────────────┐
│ measure                        │ operation │ measurement │ 1.96*SE │ per_fold                     │
├────────────────────────────────┼───────────┼─────────────┼─────────┼──────────────────────────────┤
│ LogLoss(                       │ predict   │ 5.05        │ 2.3     │ [5.41, 7.21, 1.8, 3.6, 7.21] │
│   tol = 2.220446049250313e-16) │           │             │         │                              │
└────────────────────────────────┴───────────┴─────────────┴─────────┴──────────────────────────────┘
```

If MLJ is missing a feature you need, I suggest opening an issue there, and maybe even helping to provide it. The impact will be greater and the maintenance burden lower.

ablaom commented 1 year ago

@fipelle When this PR merges you will be able to (efficiently) control early stopping (and more) through the MLJ interface. A RandomForestClassifier example is given in the PR. Another example is this notebook.