parrt / dtreeviz

A python library for decision tree visualization and model interpretation.
MIT License

tfdf.keras.CartModel support? #312

Open Jeff2722 opened 7 months ago

Jeff2722 commented 7 months ago

Hi, thanks for creating this great package. I'm interested in using it for some research and have a question. I was originally using DecisionTreeClassificationModel from sklearn, but I have a lot of string categorical features and saw that the tensorflow_decision_forests package does not require any encoding of these features.

So I tried visualizing a CART model from tfdf.keras.CartModel, but I saw that only GradientBoostedTreesModel and RandomForestModel from tensorflow_decision_forests.keras are supported in dtreeviz.model().

Is there any way to view a tree created by tfdf.keras.CartModel with dtreeviz.model()? Is there maybe some way to trick it into considering it as a single tree from tensorflow_decision_forests.keras.RandomForestModel? (I might be way off on that though)

Edit: for now I've been using a workaround, which is to force the RandomForestModel to create only 1 CART tree using all possible features and all training examples with no sampling: tfdf.keras.RandomForestModel(bootstrap_training_dataset=False, categorical_algorithm="CART", num_candidate_attributes_ratio=1.0)
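For readers who want to try the workaround, here is a minimal sketch of it. The first three hyperparameters are the ones quoted in the comment above. `num_trees=1` is an assumption added here, since the comment says the forest is limited to one tree but does not show how. The helper name `build_single_tree_forest` is hypothetical, and the import is deferred so the settings can be inspected without TensorFlow Decision Forests installed.

```python
# Hyperparameters quoted in the comment above: no bootstrap sampling,
# CART-style categorical splits, and every feature considered at each split.
single_tree_kwargs = {
    "bootstrap_training_dataset": False,
    "categorical_algorithm": "CART",
    "num_candidate_attributes_ratio": 1.0,
}

# Assumption (not shown in the original comment): cap the forest at one tree.
single_tree_kwargs["num_trees"] = 1


def build_single_tree_forest():
    """Hypothetical helper: build a RandomForestModel that holds one tree."""
    # Imported lazily so the settings above can be read without the library.
    import tensorflow_decision_forests as tfdf

    return tfdf.keras.RandomForestModel(**single_tree_kwargs)
```

After fitting, the resulting model is a RandomForestModel, so it passes the model-type check that dtreeviz currently applies.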

Thanks

tlapusan commented 7 months ago

Hi Jeff, thanks for your feedback.

Indeed, for TF we currently support only ["tensorflow_decision_forests.keras.RandomForestModel", "tensorflow_decision_forests.keras.GradientBoostedTreesModel"]. There is also a check for this (I assume you got an error when trying to use CartModel).

If CartModel has the same interface as the others, it should work, but that hard-coded check needs to change: shadow_decision_tree.py, line 461.

If you can do this, it would be awesome :) and a PR would be more than appreciated if it works.
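As a rough illustration of what relaxing such a check could look like, here is a sketch of a name-based allow-list. It is not the actual code from shadow_decision_tree.py: the names `SUPPORTED_TF_MODELS` and `is_supported_tf_model` are hypothetical, and the `CartModel` entry assumes its type string follows the same pattern as the two supported models.

```python
# The first two entries are the model types listed above as supported.
# The CartModel entry is an assumption about how its type name is reported.
SUPPORTED_TF_MODELS = [
    "tensorflow_decision_forests.keras.RandomForestModel",
    "tensorflow_decision_forests.keras.GradientBoostedTreesModel",
    "tensorflow_decision_forests.keras.CartModel",
]


def is_supported_tf_model(tree_model) -> bool:
    """Return True if the model's type name matches an allow-listed entry."""
    # str(type(obj)) looks like "<class 'package.module.ClassName'>".
    type_name = str(type(tree_model))
    return any(name in type_name for name in SUPPORTED_TF_MODELS)
```

Matching on the type name as a string avoids importing TensorFlow Decision Forests just to run the check. Whether CartModel then visualizes correctly still depends on it exposing the same interface as the other two models.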

Thanks, Tudor