parrt / dtreeviz

A python library for decision tree visualization and model interpretation.
MIT License

KeyError when using decision_boundaries function #238

Closed 0ptimista closed 1 year ago

0ptimista commented 1 year ago

I have a trained DecisionTreeClassifier model with 2 features, and it displays fine when I use dtreeviz.model() to inspect it.

[Screenshot: CleanShot 2023-01-06 at 15 00 12]

But when I try decision_boundaries(), it throws a KeyError and draws only the decision boundaries, without the data points. I want those points:

KeyError                                  Traceback (most recent call last)
Cell In[34], line 1
----> 1 decision_boundaries(
      2     dt_wx,
      3     X_train,
      4     y_train,
      5     feature_names=list(X_train.columns),
      6     target_name='error_rate',
      7     class_names=["OK", "Problem"],
      8 )

File /opt/homebrew/Caskroom/miniconda/base/envs/data-science-py311/lib/python3.11/site-packages/dtreeviz/, in decision_boundaries(model, X, y, ntiles, tile_fraction, binary_threshold, show, feature_names, target_name, class_names, markers, boundary_marker, boundary_markersize, fontsize, fontname, dot_w, yshift, sigma, colors, ranges, figsize, ax)
     97     decision_boundaries_univar(model=model, x=X, y=y,
     98                                ntiles=ntiles,
     99                                binary_threshold=binary_threshold,
    110                                figsize=figsize,
    111                                ax=ax)
    112 elif len(X.shape) == 2 and X.shape[1] == 2:
--> 113     decision_boundaries_bivar(model=model, X=X, y=y,
    114                               ntiles=ntiles, tile_fraction=tile_fraction,
    115                               binary_threshold=binary_threshold,
    116                               show=show,
    117                               feature_names=feature_names, target_name=target_name,
    207                lw=.5)
    208     # Show misclassified markers (can't have alpha per marker so do in 2 calls)
    209     bad_x = x_[class_X_pred[i] != class_values[i],:]


parrt commented 1 year ago

Can you send data + small program? I can debug.

0ptimista commented 1 year ago

Is it OK if I send those to the email address on your GitHub profile? Or is there a better way?

parrt commented 1 year ago

You can probably attach them here if they’re not too big, but my email is OK as well.

0ptimista commented 1 year ago

I tried this in a Jupyter Notebook.

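import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from dtreeviz import decision_boundaries
import dtreeviz

data = pd.read_csv('sample.csv')

# Assumption: the lines defining X and y were lost from this post.
# 'stat' is the target column named later in the thread; the remaining
# columns are taken to be the two features.
X = data.drop(columns=['stat'])
y = data['stat']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

# Assumption: the fit call was truncated in the post; fitting on the
# training split here.
dt_wx = DecisionTreeClassifier(max_depth=6).fit(X_train.values, y_train.values)

# This works and shows the tree.
viz = dtreeviz.model(
    dt_wx, X_train, y_train,
    class_names=["OK", "Problem"],
)

# This raises the KeyError shown in the traceback above.
decision_boundaries(
    dt_wx, X_train, y_train,
    feature_names=list(X_train.columns),
    target_name='error_rate',
    class_names=["OK", "Problem"],
)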

Thanks for helping, Professor!

tlapusan commented 1 year ago

@parrt just a hint: I did a little debugging of the code, and the error is generated because:

  1. The y class values are [1 2].
  2. color_map is {1: '#FEFEBB', 2: '#a1dab4'}.
  3. At line 202 it tries to get c=color_map[i], where i = 0. The dict color_map doesn't contain the key 0.
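
The three points above can be reproduced in isolation. This is only a sketch of the mechanism, not the dtreeviz source: `color_map` and `class_values` mirror the names in the list, while the two helper functions and the loop shape are assumptions made for illustration.

```python
# Labels as in the report: the classes are 1 and 2, not 0 and 1.
y = [1, 2, 2, 1, 2]
class_values = sorted(set(y))          # [1, 2]

# Colour lookup keyed by the actual label values (colours from the list above).
color_map = {1: '#FEFEBB', 2: '#a1dab4'}


def colors_by_position(class_values, color_map):
    # Hypothetical loop shape: index the dict with the position i,
    # which only works when the labels happen to be 0..n-1.
    return [color_map[i] for i in range(len(class_values))]


def colors_by_label(class_values, color_map):
    # Looking up by the label value itself works for any label set.
    return [color_map[v] for v in class_values]


try:
    colors_by_position(class_values, color_map)
    failed_key = None
except KeyError as err:
    failed_key = err.args[0]           # the key that was missing

ok_colors = colors_by_label(class_values, color_map)
print(failed_key, ok_colors)
```

With labels [1, 2] the positional lookup fails on its first iteration, because position 0 is not a key of the dict, which matches the i = 0 observation above.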

parrt commented 1 year ago

I think we make an assumption that class values all start from zero, right?

tlapusan commented 1 year ago

@parrt I guess so; I am not very familiar with that part of the implementation.

parrt commented 1 year ago

OK @0ptimista, the issue is that class labels have to start from zero, but the labels in this case are [1,2]. It must be very common to keep everything indexed from zero, so for now I'm going to simply add code to indicate this is an error.

parrt commented 1 year ago

You can probably do something like y=data['stat']-1
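
Subtracting 1 works when the labels are exactly 1..n. For other label sets (for example [2, 5] or strings), a more general remapping may be needed. The following is a minimal sketch; the helper name `to_zero_based` is hypothetical and not part of dtreeviz.

```python
def to_zero_based(labels):
    """Map arbitrary class labels to 0..n-1, preserving sorted order.

    Returns the remapped labels and the list of original classes, so that
    classes[k] is the original label for new label k.
    """
    classes = sorted(set(labels))
    index_of = {c: i for i, c in enumerate(classes)}
    return [index_of[v] for v in labels], classes


# Labels as in this thread: 1 and 2 become 0 and 1.
y_zero, classes = to_zero_based([1, 2, 2, 1])
print(y_zero, classes)
```

For the data in this thread, something like `y_zero, classes = to_zero_based(data['stat'])` should give labels usable with decision_boundaries(), and `classes` records which original value each new label stands for. scikit-learn's `LabelEncoder.fit_transform` performs a similar remapping.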

parrt commented 1 year ago

I am adding functionality to emit an error when the labels do not start from zero:

Traceback (most recent call last):
  File "/Users/parrt/github/dtreeviz/", line 24, in <module>
  File "/Users/parrt/github/dtreeviz/dtreeviz/", line 478, in view
    raise ValueError("Target label values (for now) must be 0..n-1 for n labels")
ValueError: Target label values (for now) must be 0..n-1 for n labels
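
A check of this kind can be sketched as follows. This is not the code added to dtreeviz: the function name `check_zero_based_labels` is hypothetical, and only the error message is copied from the traceback above.

```python
def check_zero_based_labels(y):
    """Raise ValueError unless the distinct labels in y are exactly 0..n-1."""
    classes = sorted(set(y))
    if classes != list(range(len(classes))):
        raise ValueError(
            "Target label values (for now) must be 0..n-1 for n labels"
        )
    return classes


# Zero-based labels pass the check.
accepted = check_zero_based_labels([0, 1, 1, 0])

# Labels [1, 2], as in this issue, are rejected.
try:
    check_zero_based_labels([1, 2, 2, 1])
    message = None
except ValueError as err:
    message = str(err)

print(accepted, message)
```

Running such a check before plotting turns the confusing KeyError into an explicit message about the label range.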

0ptimista commented 1 year ago

@parrt @tlapusan I set my classes to start from 0 as suggested, and now I can see those points.

The new ValueError above is a really good hint. Again, thanks for helping!