intel / scikit-learn-intelex

Intel(R) Extension for Scikit-learn is a seamless way to speed up your Scikit-learn application
https://intel.github.io/scikit-learn-intelex/
Apache License 2.0

Cannot visualize a tree plot with ExtraTrees and Randomforest classifiers #1919

Open YoochanMyung opened 2 months ago

YoochanMyung commented 2 months ago

Describe the bug: After patching with Intelex, export_graphviz() raises ValueError: cannot convert float NaN to integer on the ExtraTrees and RandomForest classifiers. The traceback ends in:

python3.10/site-packages/sklearn/tree/_export.py:258, in <listcomp>(.0)
    254     alpha = (value - self.colors["bounds"][0]) / (
    255         self.colors["bounds"][1] - self.colors["bounds"][0]
    256     )
    257 # compute the color as alpha against white
--> 258 color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color]
    259 # Return html color code in #RRGGBB format
    260 return "#%2x%2x%2x" % tuple(color)

ValueError: cannot convert float NaN to integer
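The failing lines compute a fill-color alpha as (value - lower bound) / (upper bound - lower bound) and then round the blended color to an integer. The sketch below reproduces that arithmetic in isolation. It assumes, without confirmation from this thread, that the patched tree ends up with equal lower and upper bounds. The helpers fill_alpha and to_rgb are hypothetical mirrors of lines 254-258, and the base color is arbitrary. Dividing 0.0 by 0.0 in NumPy gives NaN, and int() on NaN raises exactly this error.

```python
import numpy as np

def fill_alpha(value, bounds):
    # Hypothetical mirror of traceback lines 254-256: normalize value into [0, 1]
    return (value - bounds[0]) / (bounds[1] - bounds[0])

def to_rgb(alpha, color=(229, 129, 57)):
    # Hypothetical mirror of traceback line 258: blend an (arbitrary) base color against white
    return [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color]

# Assumed degenerate case: min and max node values are identical
value = np.float64(0.0)
bounds = (np.float64(0.0), np.float64(0.0))

with np.errstate(invalid="ignore"):  # silence NumPy's 0/0 RuntimeWarning
    alpha = fill_alpha(value, bounds)
print(alpha)  # nan

try:
    to_rgb(alpha)
except ValueError as exc:
    print(exc)  # cannot convert float NaN to integer
```

If the bounds differ, alpha stays finite and the color conversion succeeds. That suggests looking at what the patched tree reports to export_graphviz, rather than at the plotting code itself.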

With stock scikit-learn, classifier.estimators_[0].classes_ is [0.0 1.0], but after patching with Intelex it is 0. Could this be related to the following code?

https://github.com/intel/scikit-learn-intelex/blob/01def265ba59d7d4e1eb2e5944d938e274d1bde8/sklearnex/ensemble/_forest.py#L478
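To check every tree in a fitted forest instead of printing them one by one, a small helper can flag sub-estimators whose classes_ looks wrong. This is a minimal sketch, assuming (as in stock scikit-learn) that each tree is fit on the forest's encoded labels, so its classes_ should equal 0, 1, ..., n_classes - 1 as floats. The function name mismatched_tree_indices is hypothetical, and the usage below uses stand-in objects rather than a real fitted model.

```python
from types import SimpleNamespace

import numpy as np

def mismatched_tree_indices(forest):
    """Return indices of sub-estimators whose classes_ is not 0..n_classes-1.

    Assumes each tree was fit on encoded labels, as stock scikit-learn forests do.
    """
    expected = np.arange(len(forest.classes_), dtype=float)
    bad = []
    for i, est in enumerate(forest.estimators_):
        # atleast_1d turns a scalar such as 0 into array([0.]) so it can be compared
        got = np.atleast_1d(np.asarray(est.classes_, dtype=float))
        if not np.array_equal(got, expected):  # also False when shapes differ
            bad.append(i)
    return bad

# Stand-in objects (hypothetical): one healthy tree, one reporting classes_ = 0
ok_tree = SimpleNamespace(classes_=np.array([0.0, 1.0]))
bad_tree = SimpleNamespace(classes_=0)
forest = SimpleNamespace(classes_=np.array([0, 1]), estimators_=[ok_tree, bad_tree])
print(mismatched_tree_indices(forest))  # [1]
```

Calling mismatched_tree_indices(clf) on the forest from the reproduction below would be expected to return an empty list with stock scikit-learn. Based on the behavior reported here, it would flag the patched trees.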

To Reproduce

from sklearnex import patch_sklearn
patch_sklearn()

from sklearn.ensemble import ExtraTreesClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Load the dataset
data = load_breast_cancer()
X = data['data']
y = data['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

# Train the model
clf = ExtraTreesClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

from sklearn.tree import export_graphviz
import graphviz

# Export a single tree from the forest
tree = clf.estimators_[0]

# Print classes_ for every tree in the forest
for each in clf.estimators_:
    print(each.classes_)

dot_data = export_graphviz(tree, out_file=None, 
                           feature_names=data.feature_names,  
                           class_names=data.target_names,
                           filled=True, rounded=True,  
                           special_characters=True)  
graph = graphviz.Source(dot_data)  
graph.render("extratree")  # Saves the tree as a .pdf file

# Display the tree
graph

Expected behavior: Each tree's classes_ prints the class labels ([0 1]) and the tree plot is shown.

Output/Screenshots

Environment:

icfaust commented 2 months ago

Hey @YoochanMyung, I wanted to let you know that I am working on your issues and will get back to you in the next few days.