parrt / dtreeviz

A python library for decision tree visualization and model interpretation.
MIT License

DTreeViz crashes if the decision tree was built with objects that are interpretable as numbers. #317

Open lgi1sgm opened 7 months ago

lgi1sgm commented 7 months ago

sklearn.tree.DecisionTreeClassifier casts the parameter X to dtype=np.float32 (see the documentation), so it works with the data provided in the example.

But DTreeViz does not work with this data and crashes in a call to np.linspace in the function get_split_node_heights().

Example

#!/usr/bin/env python

import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text
import dtreeviz

X = pd.DataFrame({'feature_1' : [10, 2, 5], 'feature_2': ['1', '1', '4']})
y = pd.Series([0, 0, 1], name='label')

# X, y = make_blobs(n_samples=10, n_features=2, centers=3)

d = DecisionTreeClassifier()
d.fit(X, y)

print(export_text(d))

dtreeviz_model = dtreeviz.model(d, X_train=X, y_train=y)

dtreeviz_render = dtreeviz_model.view()  # this will crash
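
A possible workaround, sketched here as an assumption rather than a confirmed fix: cast X to float yourself before handing it to dtreeviz, mirroring the conversion scikit-learn applies internally. The name X_numeric is made up for this illustration, and the snippet only shows the cast. Whether dtreeviz then renders without crashing has not been verified here.

```python
import numpy as np
import pandas as pd

# Same training data as in the example above: feature_2 holds numeric strings.
X = pd.DataFrame({'feature_1': [10, 2, 5], 'feature_2': ['1', '1', '4']})

# Cast every column to float32, as scikit-learn does internally for X.
# X_numeric is a hypothetical name used only in this sketch.
X_numeric = X.astype(np.float32)

print(X.dtypes)          # dtypes before the cast
print(X_numeric.dtypes)  # every column is float32 after the cast

# Assumption: passing the casted frame avoids the crash, e.g.
# dtreeviz_model = dtreeviz.model(d, X_train=X_numeric, y_train=y)
```

Note that the cast raises a ValueError if a column contains strings that cannot be parsed as numbers, so this only applies to data like the example, where every value is interpretable as a number.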