dmlc / xgboost

Scalable, Portable and Distributed Gradient Boosting (GBDT, GBRT or GBM) Library, for Python, R, Java, Scala, C++ and more. Runs on single machine, Hadoop, Spark, Dask, Flink and DataFlow
https://xgboost.readthedocs.io/en/stable/
Apache License 2.0

Correctness of visualization #10829

Open Marchlak opened 1 week ago

Marchlak commented 1 week ago

Hello, I'm developing a decision tree visualization library, https://github.com/mljar/supertree, and would appreciate feedback on whether my visualization approach for XGBoost is correct. I've compared my library with dtreeviz, and in dtreeviz the data in each node's histogram appears to be split according to the root node's feature (based on my observations). In contrast, my implementation splits the data at each node according to that node's own feature, as extracted from booster.get_dump(). I would greatly appreciate guidance on the correct visualization approach for your library. Code from my compare notebook (a sketch of the node-wise splitting I mean follows the notebook code):

import xgboost as xgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score

# Load the iris data and keep the metadata for the visualizers
iris = load_iris()
X = iris.data
y = iris.target
features = iris.feature_names
target = 'species'
class_names = iris.target_names

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# multi:softmax with 3 classes grows one tree per class per boosting round,
# so n_estimators=50 produces 150 trees; tree indices below index that flat list
xgb_classifier = xgb.XGBClassifier(
    objective='multi:softmax',
    num_class=3,
    max_depth=5,
    learning_rate=0.3,
    n_estimators=50,
    random_state=42,
)

xgb_classifier.fit(X_train, y_train)

from xgboost import plot_tree
import matplotlib.pyplot as plt

# plot_tree creates its own Axes when none is given, so pass one
# explicitly to control the figure size
fig, ax = plt.subplots(figsize=(30, 20))
plot_tree(xgb_classifier, num_trees=20, ax=ax)
plt.show()

from dtreeviz import model

viz_model = model(
    xgb_classifier.get_booster(),
    X_train=X_train,
    y_train=y_train,
    feature_names=features,
    target_name=target,
    class_names=list(class_names),
    tree_index=5  
)

viz_model.view()

from supertree import SuperTree

st = SuperTree(
    xgb_classifier, 
    X_train, 
    y_train, 
    iris.feature_names, 
    iris.target_names
)
# Visualize the tree
st.show_tree(which_tree=2)
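
For reference, here is a minimal sketch of the node-wise splitting I describe above. node_subsets is my own illustrative helper, not part of either library; it assumes dense data with no missing values and that no feature names were set (so splits look like 'f2').

import json
import numpy as np

def node_subsets(booster, X, tree_index):
    # Parse one tree from the JSON dump and route every training row
    # through it: a row goes left when feature_value < split_condition,
    # which is xgboost's default decision rule.
    tree = json.loads(booster.get_dump(dump_format='json')[tree_index])
    subsets = {}

    def walk(node, rows):
        subsets[node['nodeid']] = rows
        if 'children' in node:
            col = int(node['split'][1:])   # 'f2' -> column 2
            go_left = X[rows, col] < node['split_condition']
            left, right = rows[go_left], rows[~go_left]
            for child in node['children']:
                walk(child, left if child['nodeid'] == node['yes'] else right)

    walk(tree, np.arange(X.shape[0]))
    return subsets

# Row indices reaching each node of tree 5 (the tree shown by dtreeviz above)
subsets = node_subsets(xgb_classifier.get_booster(), X_train, 5)
print({nid: len(rows) for nid, rows in subsets.items()})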

trivialfis commented 1 week ago

Hi, I still need to look into the source code of either project.

> the data in each node's histogram appears to be split according to the root node's feature

Could you elaborate on what this means? Is there a histogram when plotting a tree, and does that histogram contain data that is split by the root node's (split) feature?

If you want to compare against xgboost's own plot_tree function, you can dump the tree in dot format and render it with graphviz yourself.
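
Roughly like this (a minimal sketch; graphviz below is the Python package of that name, and 20 indexes into the booster's flat list of trees):

import graphviz
import xgboost as xgb

booster = xgb_classifier.get_booster()
# Dump all trees in dot format and render one of them yourself
dot = booster.get_dump(dump_format='dot')[20]
graphviz.Source(dot).render('tree_20', format='png')

# Or use the built-in helper, which does the same dump under the hood
xgb.to_graphviz(xgb_classifier, num_trees=20).render('tree_20', format='png')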

Marchlak commented 1 week ago

> Could you elaborate on what this means? Is there a histogram when plotting a tree, and does that histogram contain data that is split by the root node's (split) feature?

That's basically what I mean. Here are screenshots comparing the visualizations; the feature indices map to f0 - sepal length (cm), f1 - sepal width (cm), f2 - petal length (cm), f3 - petal width (cm).

[graphviz screenshot]

[dtreeviz screenshot]

[supertree screenshot (my library)]
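
Side note on the f0-f3 labels: as far as I know, xgboost falls back to f0..fN when the model is trained on a plain numpy array. A minimal sketch, under that assumption, of attaching the real names so subsequent dumps and plots show them:

booster = xgb_classifier.get_booster()
# feature_names is a settable property on Booster; the order must
# match the columns of the training array
booster.feature_names = iris.feature_names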

trivialfis commented 1 week ago

Ah, that's much clearer! Thank you for sharing. I will look into it after sorting out some of the ongoing work here.