parrt / dtreeviz

A python library for decision tree visualization and model interpretation.
MIT License

TypeError: list indices must be integers or slices, not numpy.float64 #304

Open MohammedAlsayed opened 1 year ago

MohammedAlsayed commented 1 year ago

Hi Guys, thanks for creating this amazing library.

I'm using the XGBoost classifier with sklearn, and I'm getting this error because my target column is of type numpy.float64, even though it only contains 1.0 and 0.0 as a binary target.

If we cast the column to int, the problem is resolved. I can open a pull request to fix this.

File ~/miniconda3/envs/bmt/lib/python3.11/site-packages/dtreeviz/trees.py:602, in DTreeVizAPI.view(self, precision, orientation, instance_orientation, show_root_edge_labels, show_node_labels, show_just_path, fancy, histtype, leaftype, highlight_path, x, max_X_features_LR, max_X_features_TD, depth_range_to_display, label_fontsize, ticks_fontsize, fontname, title, title_fontsize, colors, scale)
    600         continue
    601 if self.shadow_tree.is_classifier():
--> 602     _class_leaf_viz(node, colors=color_values,
    603                     filename=os.path.join(tmp, f"leaf{node.id}_{os.getpid()}.svg"),
    604                     graph_colors=colors,
    605                     fontname=fontname,
    606                     leaftype=leaftype)
    607     leaves.append(class_leaf_node(node))
    608 else:
    609     # for now, always gen leaf

File ~/miniconda3/envs/bmt/lib/python3.11/site-packages/dtreeviz/trees.py:1256, in _class_leaf_viz(node, colors, filename, graph_colors, fontname, leaftype)
   1253 size = min(size, maxsize)
   1255 # we visually need n=1 and n=9 to appear different but diff between 300 and 400 is no big deal
-> 1256 counts = node.class_counts()
   1257 prediction = node.prediction_name()
   1259 # when using another dataset than the training dataset, some leaves could have 0 samples.
   1260 # Trying to make a pie chart will raise some deprecation

File ~/miniconda3/envs/bmt/lib/python3.11/site-packages/dtreeviz/models/shadow_decision_tree.py:578, in ShadowDecTreeNode.class_counts(self)
    576 if self.isclassifier():
    577     if self.shadow_tree.get_class_weight() is None:
--> 578         return np.array(np.round(self.shadow_tree.get_node_nsamples_by_class(self.id)), dtype=int)
    579     else:
    580         return np.round(
    581             self.shadow_tree.get_node_nsamples_by_class(self.id) / self.shadow_tree.get_class_weights()).astype(
    582             int)

File ~/miniconda3/envs/bmt/lib/python3.11/site-packages/dtreeviz/models/xgb_decision_tree.py:202, in ShadowXGBDTree.get_node_nsamples_by_class(self, id)
    200 all_nodes = self.internal + self.leaves
    201 if self.is_classifier():
--> 202     node_value = [node.n_sample_classes() for node in all_nodes if node.id == id]
    203     return node_value[0]

File ~/miniconda3/envs/bmt/lib/python3.11/site-packages/dtreeviz/models/xgb_decision_tree.py:202, in <listcomp>(.0)
    200 all_nodes = self.internal + self.leaves
    201 if self.is_classifier():
--> 202     node_value = [node.n_sample_classes() for node in all_nodes if node.id == id]
    203     return node_value[0]

File ~/miniconda3/envs/bmt/lib/python3.11/site-packages/dtreeviz/models/shadow_decision_tree.py:528, in ShadowDecTreeNode.n_sample_classes(self)
    525 unique, counts = np.unique(node_y_data, return_counts=True)
    527 for i in range(len(unique)):
--> 528     node_values[unique[i]] = counts[i]
    530 return node_values

TypeError: list indices must be integers or slices, not numpy.float64
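
The last frame of the traceback can likely be reproduced outside dtreeviz. The following is a minimal sketch using only NumPy and made-up data; the way `node_values` is initialised here is an assumption, since the traceback only shows that it is a list indexed by the values returned from `np.unique`. It also shows the cast-to-int workaround described above.

```python
import numpy as np

# Made-up binary target stored as float64, as in the report above.
y = np.array([1.0, 0.0, 1.0, 1.0, 0.0])

# Same pattern as the failing line: a plain Python list indexed by the
# class values returned from np.unique (assumed initialisation).
unique, counts = np.unique(y, return_counts=True)
node_values = [0] * len(unique)
try:
    node_values[unique[0]] = counts[0]
    float_index_error = None
except TypeError as exc:
    # Lists cannot be indexed with numpy.float64 values.
    float_index_error = str(exc)

# Workaround: confirm the labels are whole numbers, then cast to int.
assert np.array_equal(y, np.round(y))
y_int = y.astype(int)

# With integer labels the same indexing pattern works.
unique_int, counts_int = np.unique(y_int, return_counts=True)
for i in range(len(unique_int)):
    node_values[unique_int[i]] = int(counts_int[i])

print(float_index_error)
print(node_values)
```

Casting before fitting the model keeps the labels used for training and for visualization identical, so nothing else in the pipeline should need to change.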

tlapusan commented 1 year ago

Hi @MohammedAlsayed,

Thanks for your feedback and appreciation for the library. For a binary classification problem, the classes should be converted to the integer values 0 and 1 on the client side by default.

Not sure if we should add this conversion into the library.

Thanks, Tudor.

MohammedAlsayed commented 1 year ago

Hi @tlapusan,

Thanks for your reply.

I never faced this issue with other libraries like sklearn or xgboost in binary classification, even though my target is of type float. I think this library should be consistent with them. It doesn't make sense that it works perfectly in other libraries and then crashes here.

However, if there is no desire to add this feature, then the error message should at least be more precise, since it took me some time to figure out what the issue was.

tlapusan commented 1 year ago

Indeed @MohammedAlsayed, the error message should be more precise. If you can make a PR for this and display a more relevant error message it would help.
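
One possible shape for such a check is sketched below. `validate_class_labels` is a hypothetical helper, not an existing dtreeviz function, and where it would be called inside the library is left open; the point is only to show an error message that names the dtype and suggests the fix.

```python
import numpy as np


def validate_class_labels(y):
    """Hypothetical helper (not part of dtreeviz): reject non-integer class labels."""
    y = np.asarray(y)
    if not np.issubdtype(y.dtype, np.integer):
        raise ValueError(
            f"Classification targets must be integers (0, 1, ...), "
            f"but got dtype {y.dtype}. "
            "Cast the target before fitting, for example with y.astype(int)."
        )
    return y


# Usage: float labels fail early with an explicit message.
try:
    validate_class_labels(np.array([1.0, 0.0, 1.0]))
except ValueError as exc:
    print(exc)
```

Raising early on the dtype, rather than failing later on a list index, would point the user directly at the target column.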

MohammedAlsayed commented 1 year ago

@tlapusan I think I need to be added as a collaborator to be able to open a pull request. Should I target the master or the dev branch?

tlapusan commented 1 year ago

Can you describe a little more what you mean by collaborator? Anyone should have permission to open a PR. Yes, into the dev branch. Thanks.