parrt / dtreeviz

A python library for decision tree visualization and model interpretation.
MIT License

Dtreeviz recognizes an XGBoost classification model as a regression model #129

Open HyukdongKim opened 3 years ago

HyukdongKim commented 3 years ago

(I used Google Translate to write this. If any expression sounds rude, I apologize in advance.)

First of all, thank you for making a great library!

I tried to visualize an XGBoost classification model, but dtreeviz produced a regression-style visualization instead.

This is my code.

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

import xgboost as xgb

dtrain = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
dtest = xgb.DMatrix(X_test, label=y_test, enable_categorical=True)

params = {
    'max_depth': 3,
    'eta': 0.3, 
    'objective': 'multi:softprob',
    'eval_metric': 'mlogloss',
    'num_class': 3, 
}
num_round = 100  # the number of training iterations

bst = xgb.train(params=params, dtrain=dtrain, num_boost_round=num_round)

from dtreeviz import trees

import graphviz

from dtreeviz.models.shadow_decision_tree import ShadowDecTree
from dtreeviz.models.xgb_decision_tree import ShadowXGBDTree

trees.dtreeviz(tree_model=bst,
               x_data=X_train,
               y_data=y_train,
               target_name='class',
               feature_names=['f0', 'f1', 'f2', 'f3'],
               histtype='barstacked',
               tree_index=1,
               class_names=list(iris.target_names),
               )

I expected distribution graphs and pie charts, but my code generated scatter plots.

[Image: the resulting dtreeviz output, rendered as scatter plots]

Is there anything I've done wrong? Or is there a temporary solution?

[ Dependencies ]

tlapusan commented 3 years ago

Hi @HyukdongKim, this is a bug in the library. Thanks for reporting it.

It's caused by how dtreeviz decides whether a tree is a classifier or a regressor. I've fixed that, but there still seem to be a few other issues. I'll follow up once those are fixed.
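
To illustrate the kind of classifier-versus-regressor decision involved, here is a minimal sketch. It is not the actual dtreeviz fix. It assumes the training objective is available as a string (for example `multi:softprob` from the `params` above) and treats any objective starting with `binary:` or `multi:` as classification. The helper names `objective_from_config` and `is_classification_objective` are hypothetical, and the JSON layout is assumed to match what xgboost's `Booster.save_config()` returns.

```python
import json

# Objective prefixes treated as classification in this sketch.
# This rule is an assumption for illustration, not dtreeviz's actual logic.
CLASSIFICATION_PREFIXES = ("binary:", "multi:")


def objective_from_config(config_json: str) -> str:
    """Extract the objective name from a JSON config string.

    Assumes the layout config["learner"]["objective"]["name"], as produced
    by xgboost's Booster.save_config().
    """
    config = json.loads(config_json)
    return config["learner"]["objective"]["name"]


def is_classification_objective(objective: str) -> bool:
    """Return True if the objective name looks like a classification objective."""
    return objective.startswith(CLASSIFICATION_PREFIXES)


# Hand-written stand-in for bst.save_config() output, trimmed to the relevant keys.
example_config = json.dumps(
    {"learner": {"objective": {"name": "multi:softprob"}}}
)

objective = objective_from_config(example_config)
print(objective, is_classification_objective(objective))
print("reg:squarederror", is_classification_objective("reg:squarederror"))
```

With a trained booster, the stand-in string could presumably be replaced by `bst.save_config()`. Ambiguous objectives such as `reg:logistic` would need a separate decision and are counted as regression here.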