martin-borkovec / ggparty


Plotting trees from caret #47

Closed rafabelokurows closed 2 years ago

rafabelokurows commented 3 years ago

Hi, first of all, thanks for this package, it works perfectly.

I just have a question, which may turn into a suggestion: is there a way to plot trees trained with caret using ggparty?

I currently have a framework in place that uses caret to preprocess, train and evaluate models, and I know for a fact that it supports the methods "ctree" and "ctree2". It would be amazing if we could plot those trees with ggparty.

Thanks in advance, keep up the good work!

martin-borkovec commented 3 years ago

Thanks, glad you like the package :)

Could you provide an example?

If it's partykit's ctree, I think it should actually work.

rafabelokurows commented 3 years ago

Sure, here you go.

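library(caret)    # provides train()
library(ggparty)
data(iris)

# Train a conditional inference tree through caret
model <- train(Species ~ ., data = iris, method = "ctree")

# Attempt 1: pass the train object itself
ggparty(model) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar()

# Attempt 2: pass the underlying fitted model
ggparty(model$finalModel) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar()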


I'm training a caret model with "ctree" or "ctree2" as the method and, unless I'm missing something here, I'm not able to plot either of those objects with ggparty.

martin-borkovec commented 2 years ago

I'm very sorry for the super late response.

I've finally found time to look into it, and sadly I don't think it's easily solvable. The problem seems to stem from the fact that caret's "ctree" method uses the old ctree function from the package "party", which creates objects of the class "BinaryTree".

However, ggparty only works with objects of the class "party", which is the class of the output of the new ctree function in the package "partykit".

Yes, a bit confusing...
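
To see the mismatch directly, the classes of the two fitted objects can be compared. This is only a small sketch: it assumes the packages caret, party and partykit are all installed, and the object names caret_fit and pk_fit are made up for illustration.

```r
library(caret)     # train()
library(partykit)  # the newer ctree()

data(iris)

# caret's built-in "ctree" method fits the tree with the older party package
caret_fit <- train(Species ~ ., data = iris, method = "ctree")
class(caret_fit$finalModel)              # class of the object caret hands back
inherits(caret_fit$finalModel, "party")  # ggparty needs this to be TRUE

# Calling partykit::ctree() directly, bypassing caret
pk_fit <- partykit::ctree(Species ~ ., data = iris)
class(pk_fit)
inherits(pk_fit, "party")
```

If the second inherits() call is TRUE and the first is not, that confirms ggparty can only handle the tree fitted by partykit.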

So, if this is still relevant for you, I think the easiest way would be to write your own custom model for caret's train() that uses the new ctree function. It should be possible by following this tutorial: http://topepo.github.io/caret/using-your-own-model-in-train.html

You can look at how the "ctree" method for caret::train() is defined by running caret::getModelInfo("ctree"). I hope that you can reuse much of that code and that it only needs some small adjustments. Alternatively, you can ask the caret developers to add support for the new partykit::ctree function.
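
As a rough starting point, a custom model along those lines might look like the sketch below. It is untested against every caret version and makes several assumptions: the list name ctree_partykit, the label text and the tuning grid are invented here (only loosely modelled on the built-in "ctree" method), mincriterion is the only tuned parameter, and case weights are ignored.

```r
library(caret)
library(partykit)
library(ggparty)

# Hypothetical custom caret model wrapping partykit::ctree().
# The list elements follow caret's "using your own model" tutorial.
ctree_partykit <- list(
  label = "Conditional Inference Tree (partykit)",
  library = "partykit",
  type = c("Classification", "Regression"),
  parameters = data.frame(
    parameter = "mincriterion",
    class = "numeric",
    label = "1 - P-Value Threshold"
  ),
  # Assumed tuning grid: evenly spaced mincriterion values
  grid = function(x, y, len = NULL, search = "grid") {
    data.frame(mincriterion = seq(from = 0.99, to = 0.01, length.out = len))
  },
  # Fit with the new ctree(); wts is ignored in this sketch
  fit = function(x, y, wts, param, lev, last, weights, classProbs, ...) {
    dat <- as.data.frame(x)
    dat$.outcome <- y
    partykit::ctree(
      .outcome ~ ., data = dat,
      control = partykit::ctree_control(mincriterion = param$mincriterion)
    )
  },
  predict = function(modelFit, newdata, preProc = NULL, submodels = NULL) {
    predict(modelFit, newdata = as.data.frame(newdata), type = "response")
  },
  prob = function(modelFit, newdata, preProc = NULL, submodels = NULL) {
    as.data.frame(
      predict(modelFit, newdata = as.data.frame(newdata), type = "prob")
    )
  },
  # Order candidates from the simplest tree to the most complex one
  sort = function(x) x[order(-x$mincriterion), ]
)

data(iris)
model <- train(
  Species ~ ., data = iris,
  method = ctree_partykit,
  trControl = trainControl(method = "cv", number = 5)
)

# finalModel should now be a partykit object that ggparty accepts
ggparty(model$finalModel) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar()
```

Because the fit function builds the formula itself, the response shows up in the tree under the placeholder name .outcome rather than Species.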

My understanding is that party::ctree() is rather outdated, and therefore I don't think I will implement support for party's BinaryTree objects in ggparty.