JuliaAI / MLJDecisionTreeInterface.jl

MIT License

Plotting trees with TreeRecipe.jl #56

Open ablaom opened 6 months ago

ablaom commented 6 months ago

The following example shows how to manually plot the trees learned in DecisionTree.jl:

https://github.com/JuliaAI/TreeRecipe.jl/blob/master/examples/DecisionTree_iris.jl

Currently, the way to integrate a plot recipe in MLJ.jl is not documented, but is sketched in this comment.

So, can we somehow put this together to arrange that a workflow like this generates a plot of a decision tree?

edited again (x2):

using MLJBase
using Plots                 # <---- added in edit
import MLJDecisionTreeInterface
tree = MLJDecisionTreeInterface.DecisionTreeClassifier()
X, y = @load_iris
mach = machine(tree, X, y) |> fit!
plot(mach, 0.8, 0.7; size = (1400, 600))   # <---- added in edit

Note: It used to be that you made RecipesBase.jl a dependency, to avoid a full Plots.jl dependency. But now the recipes live in Plots.jl and you are expected to make Plots.jl a weak dependency. You can see an example of this here.
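For reference, with Julia's package extensions (Julia ≥ 1.9), declaring Plots.jl as a weak dependency could look roughly like the following in Project.toml. The extension name and file layout here are illustrative, not the package's actual setup:

```toml
[weakdeps]
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

[extensions]
# hypothetical extension module, implemented in ext/PlotsExt.jl;
# it is loaded automatically only when the user also loads Plots
PlotsExt = "Plots"
```

The recipe/overload code then lives in the extension module, so users who never load Plots pay no extra load time.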

adarshpalaskar1 commented 6 months ago

Hello, I went through the RecipesBase documentation and needed some help understanding the plot recipe's integration. I had some questions:

  1. Should I add the code for the recipe to the MLJDecisionTree.jl file itself, or somewhere else? If the former, should I use the example code you mentioned above for plotting directly in the recipe? (I am unable to adapt the code that works for the DecisionTreeClassifier model in the example to the MLJ machine.)

  2. I cannot pass the Machine as the recipe argument, as it is not part of the current dependencies (if writing the recipe in the MLJDecisionTree.jl file). What do you think should be done here?

Also, please let me know if these questions make sense or if I'm thinking in the wrong direction😅

ablaom commented 6 months ago

I've looked into this a bit further. I have an idea how to do it but it's a bit involved. The first step is to replace the current fitresult output of the fit methods for DecisionTreeClassifier and DecisionTreeRegressor models with wrapped versions. We need this because we are going to overload Plots.plot(fitresult, ...) for appropriate fitresult types.

So, we create a new struct

struct DecisionTreeClassifierFitResult{T,C,I}
    tree::T
    classes_seen::C
    integers_seen::I
    features::Vector{Symbol}
end

and instead of fit(::DecisionTreeClassifier, ...) returning fitresult = (tree, classes_seen, integers_seen, features) we will return DecisionTreeClassifierFitResult(tree, classes_seen, integers_seen, features).
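A minimal, self-contained sketch of this wrapping step (the struct and field names follow the definition above; `make_fitresult` is a hypothetical stand-in for the tail of the real `fit`, and the values passed in are dummies rather than anything produced by DecisionTree.jl):

```julia
# Wrapper replacing the old 4-tuple fitresult (as defined above).
struct DecisionTreeClassifierFitResult{T,C,I}
    tree::T
    classes_seen::C
    integers_seen::I
    features::Vector{Symbol}
end

# Hypothetical stand-in for the end of `fit`: wrap instead of
# returning the bare tuple.
function make_fitresult(tree, classes_seen, integers_seen, features)
    # previously: fitresult = (tree, classes_seen, integers_seen, features)
    return DecisionTreeClassifierFitResult(tree, classes_seen,
                                           integers_seen, features)
end

# dummy values standing in for what DecisionTree.jl would produce
fitresult = make_fitresult(:dummy_tree, ["setosa", "versicolor"],
                           UInt32[1, 2], [:sepal_length, :sepal_width])
```

Because `plot` can then be overloaded on `DecisionTreeClassifierFitResult` specifically, dispatch no longer needs access to the machine's internals.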

We will have to modify the definitions of predict(::DecisionTreeClassifier, fitresult, ...), fitted_params(::DecisionTreeClassifier, fitresult), and feature_importances(::TreeModel, ...) accordingly, so that they first unwrap the fitresult.
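The unwrapping could be concentrated in one hypothetical helper, so each existing method body needs only one extra line at the top. A self-contained sketch (the struct is repeated so the example runs on its own; `predict_sketch` is a placeholder for the real method, not its actual logic):

```julia
struct DecisionTreeClassifierFitResult{T,C,I}
    tree::T
    classes_seen::C
    integers_seen::I
    features::Vector{Symbol}
end

# hypothetical helper: recover the old tuple form in one place
unwrap(fr::DecisionTreeClassifierFitResult) =
    (fr.tree, fr.classes_seen, fr.integers_seen, fr.features)

# e.g. a predict-like method would then begin like this
# (the body below is a placeholder, not the real prediction logic):
function predict_sketch(fr::DecisionTreeClassifierFitResult, Xnew)
    tree, classes_seen, integers_seen, features = unwrap(fr)
    # ... existing prediction code, unchanged, using the unwrapped values ...
    return tree
end

fr = DecisionTreeClassifierFitResult(:t, ["a"], UInt32[1], [:x])
```

Keeping `unwrap` in one place also makes it easy to audit that every overloaded method was updated consistently.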

We do something similar for DecisionTreeRegressor, whose fitresult has a different form.

We should be careful that none of these changes breaks anything. Since fitresult is private (public access is through the fitted_params method we are fixing) this should not be a problem.

@adarshpalaskar1 You want to have a go at a PR to do this internal wrapping?

adarshpalaskar1 commented 6 months ago

Sorry for the delayed response. I went through the implementation steps you provided, and I am eager to work on a PR.

I added the above-mentioned changes for the DecisionTreeClassifier, but I am facing an issue while plotting.

In the example mentioned above (https://github.com/JuliaAI/TreeRecipe.jl/blob/master/examples/DecisionTree_iris.jl), we have:

julia> typeof(dtree)
Node{Float64, String}

wrapped tree in example:

julia> typeof(wt)
InfoNode{Float64, String}



Whereas in the case of MLJDecisionTreeInterface, we have:

julia> typeof(fitted_params(mach).raw_tree.node)
DecisionTree.Node{Float64, UInt32}

wrapped tree in fitted_params:

julia> typeof(fitted_params(mach).tree)
DecisionTree.InfoNode{Float64, UInt32}

As a result, we get:

julia> plot(fitted_params(mach).tree)
ERROR: Cannot convert DecisionTree.InfoNode{Float64, UInt32} to series data for plotting

and similarly, after adding a recipe for wrapping:

julia> plot(mach)
ERROR: Cannot convert DecisionTree.InfoNode{Float64, UInt32} to series data for plotting

How can I handle this mismatch in the datatypes?