parrt / dtreeviz

A python library for decision tree visualization and model interpretation.
MIT License

Catvar support for tf-df #261

Closed: parrt closed this 1 year ago

parrt commented 1 year ago

@tlapusan, please take a look at this branch that I started for you. If you run testing/tf_df_catvars.py, it should show you a classifier tree with bar charts for the leaves, not pie charts. Pie chart leaves crash at the moment.

I think that, in order to support string categorical variables, we need to do the following:

  1. Detect and convert strings to integers in ShadowTensorflowTree.__init__ via model(), and record the string-to-integer mapping for any categorical columns. I created and filled self.catvar_maps.
  2. For now, we can simply not show a wedge for any catvar; this will probably work for both classifiers and regressors.
  3. There's no such thing as a "split value" for categoricals in the general case, as TensorFlow trees might be doing a set-membership test. Per the previous point, I have simply detected this and turned off split wedge display for strings.
  4. We currently do our own walking of the tree to compute paths from the root to the leaves, which means we do not properly handle categoricals for TensorFlow. We need to ask the node for its set membership or, better yet, ask the underlying TF node to make the decision for us.
  5. Labels on tree edges can currently only be "less than" or "greater than"; at the very least, we should not turn them on beneath a categorical test for the moment.
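Step 1 could look roughly like the following sketch. It assumes the training data is available as a plain dict mapping column names to lists of values; `build_catvar_maps` and `encode_columns` are hypothetical helpers, not part of dtreeviz, and only illustrate the kind of string-to-integer mapping that self.catvar_maps would hold.

```python
from typing import Any, Dict, List


def build_catvar_maps(columns: Dict[str, List[Any]]) -> Dict[str, Dict[str, int]]:
    """Return {column_name: {category_string: integer_code}} for string columns.

    Hypothetical helper: a column counts as categorical only if every value
    is a string, and codes are assigned in sorted order so the mapping is
    deterministic across runs.
    """
    catvar_maps = {}
    for name, values in columns.items():
        if values and all(isinstance(v, str) for v in values):
            categories = sorted(set(values))
            catvar_maps[name] = {cat: code for code, cat in enumerate(categories)}
    return catvar_maps


def encode_columns(columns: Dict[str, List[Any]],
                   catvar_maps: Dict[str, Dict[str, int]]) -> Dict[str, List[Any]]:
    """Replace string values with their integer codes; copy other columns unchanged."""
    return {
        name: [catvar_maps[name][v] for v in values] if name in catvar_maps else list(values)
        for name, values in columns.items()
    }


data = {"Sex": ["male", "female", "female"], "Age": [22.0, 38.0, 26.0]}
maps = build_catvar_maps(data)
print(maps)                         # {'Sex': {'female': 0, 'male': 1}}
print(encode_columns(data, maps))   # {'Sex': [1, 0, 0], 'Age': [22.0, 38.0, 26.0]}
```

Keeping the mapping around (rather than discarding it after encoding) would also let the visualization translate integer codes back into the original strings when labeling nodes.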

Feel free to just cut and paste from this into your own branch. I'm not sure how to add you to this branch.

I put a reference to @Tudor at a place in the code where we call node.splits(), which likely needs attention as well.

The good news is that this first example works with very minimal changes. Pie charts crash, and I'm sure it will crash if we try to display a path from root to leaf. Haha. But it shouldn't be too bad.

tlapusan commented 1 year ago

@parrt, I think we could handle categorical variables with string values without encoding them into (and decoding them from) integer values.

I managed to modify the current code to do this (with some hardcoded stuff :); I still need to understand it better).

[Screenshot 2023-02-11 at 13:27:07]

"2. For now, we can simply not show a wedge for any catvar; this will probably work for both classifiers and regressors." This is how we currently display the categorical split nodes for classification and regression (for classification, I still have to check why we get those weird values; they should be ...):

[Screenshots 2023-02-11 at 13:40:32 and 13:40:48]

"3. There's no such thing as a "split value" for categoricals in the general case as it might be doing set membership in tensor flow trees. Per the previous bullet point, I have simply detected this and turned off split wedge display for strings." Indeed, for categorical splits we use the test 'x in get_node_split(node_id)', where get_node_split(node_id) returns a set of values. So we already handle categorical splits in dtreeviz :) Please check https://github.com/parrt/dtreeviz/blob/5c1cc6d251b0447249e5211687bfe0db11bc7c9b/dtreeviz/models/lightgbm_decision_tree.py#L223
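To make the difference between the two kinds of split concrete, here is a minimal sketch of a single routing decision. It assumes a numeric split is stored as a threshold and a categorical split as a set of category values, mirroring the 'x in get_node_split(node_id)' test quoted above. The `goes_left` helper and the convention for which branch counts as "left" are assumptions for illustration, not dtreeviz's actual code. Because the categorical case is plain set membership, it behaves the same for integer and string categories.

```python
from typing import AbstractSet, Any, Union

# Hypothetical split representation: a float threshold for numeric features,
# a set of category values for categorical features.
Split = Union[float, AbstractSet[Any]]


def goes_left(x: Any, split: Split) -> bool:
    """Decide which child a feature value x is routed to (assumed convention)."""
    if isinstance(split, (set, frozenset)):
        # Categorical split: membership test; there is no threshold to draw a wedge at
        return x in split
    # Numeric split: threshold comparison (assumed: values below the threshold go left)
    return x < split


print(goes_left(3.5, 5.0))                       # True
print(goes_left("male", frozenset({"female"})))  # False
print(goes_left(2, {1, 2}))                      # True
```

The membership branch never needs an ordering on the categories, which is why there is no meaningful "split value" to display for it.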

"4. We currently do our own walking of the tree in order to compute paths from the root to leaves, which means we do not properly handle categoricals for tensorflow. ..." We do handle categorical splits properly, as long as the category values are integers :) I'm working on supporting strings as well: https://github.com/parrt/dtreeviz/blob/5c1cc6d251b0447249e5211687bfe0db11bc7c9b/dtreeviz/models/tensorflow_decision_tree.py#L193
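A root-to-leaf walk that honors both split types might look like the sketch below. The flat node arrays (`children_left`, `children_right`, `feature`, `split`) loosely follow the layout scikit-learn uses for its fitted trees, but the toy tree and the `predict_path` helper are hypothetical; they are not how dtreeviz or TF-DF actually store nodes. Note that the string categories are used directly, with no integer encoding, in line with the suggestion earlier in this thread.

```python
from typing import Any, Dict, List, Optional, Sequence

LEAF = -1  # marker for "no child" (assumed convention)


def predict_path(
    x: Dict[str, Any],
    children_left: Sequence[int],
    children_right: Sequence[int],
    feature: Sequence[Optional[str]],
    split: Sequence[Any],
) -> List[int]:
    """Return the node ids visited from the root (node 0) down to a leaf."""
    node = 0
    path = [node]
    while children_left[node] != LEAF:
        value = x[feature[node]]
        test = split[node]
        if isinstance(test, (set, frozenset)):
            go_left = value in test   # categorical: set membership, works for strings
        else:
            go_left = value < test    # numeric: threshold (assumed "<" goes left)
        node = children_left[node] if go_left else children_right[node]
        path.append(node)
    return path


# Toy tree: node 0 splits on the string feature "Sex", node 1 on numeric "Age".
#          0: Sex in {"female"}?
#         /                     \
#    1: Age < 10.0?            2: leaf
#     /          \
#  3: leaf      4: leaf
children_left = [1, 3, LEAF, LEAF, LEAF]
children_right = [2, 4, LEAF, LEAF, LEAF]
feature = ["Sex", "Age", None, None, None]
split = [{"female"}, 10.0, None, None, None]

print(predict_path({"Sex": "female", "Age": 30.0}, children_left, children_right, feature, split))  # [0, 1, 4]
print(predict_path({"Sex": "male", "Age": 5.0}, children_left, children_right, feature, split))     # [0, 2]
```

Asking each node for its own test (set or threshold) keeps the walk independent of how the model encodes categories, which is roughly what point 4 asks for.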

Here is the viz from TF, where Sex_label is considered categorical by the TF model.

[Screenshot 2023-02-11 at 13:52:10]