parrt / dtreeviz

A python library for decision tree visualization and model interpretation.
MIT License

ShadowLightGBMTree doesn't show some trees #192

Open baligoyem opened 2 years ago

baligoyem commented 2 years ago

When I try to plot trees from my LightGBM model, an error sometimes occurs and the tree isn't shown. For instance, I can't see the tree with index 0, but when I set the index to 1 for the same model, there is no problem and the tree with index 1 is displayed.

Moreover, wrapping the call in a "try-except pass" block doesn't help: I still can't suppress the error message below.

light_dtree = ShadowLightGBMTree(gbm_dtreeviz, tree_index=0, x_data=X_train_tree, y_data=y_train_tree, feature_names=list(X_tree), target_name="quality_status", class_names=target)
dtreeviz(light_dtree, scale=1.75)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [4], in <module>
      1 light_dtree = ShadowLightGBMTree(gbm_dtreeviz, tree_index=0, x_data=X_train_tree, y_data=y_train_tree, feature_names=list(X_tree), target_name="quality_status", class_names=target)
----> 2 dtreeviz(light_dtree, scale=1.75)

File /opt/tljh/user/lib/python3.9/site-packages/dtreeviz/trees.py:842, in dtreeviz(tree_model, x_data, y_data, feature_names, target_name, class_names, tree_index, precision, orientation, instance_orientation, show_root_edge_labels, show_node_labels, show_just_path, fancy, histtype, highlight_path, X, max_X_features_LR, max_X_features_TD, depth_range_to_display, label_fontsize, ticks_fontsize, fontname, title, title_fontsize, colors, scale)
    840         continue
    841 if shadow_tree.is_classifier():
--> 842     class_leaf_viz(node, colors=color_values,
    843                    filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg",
    844                    graph_colors=colors,
    845                    fontname=fontname)
    846     leaves.append(class_leaf_node(node))
    847 else:
    848     # for now, always gen leaf

File /opt/tljh/user/lib/python3.9/site-packages/dtreeviz/trees.py:1102, in class_leaf_viz(node, colors, filename, graph_colors, fontname)
   1100 counts = node.class_counts()
   1101 prediction = node.prediction_name()
-> 1102 draw_piechart(counts, size=size, colors=colors, filename=filename, label=f"n={nsamples}\n{prediction}",
   1103               graph_colors=graph_colors, fontname=fontname)

File /opt/tljh/user/lib/python3.9/site-packages/dtreeviz/trees.py:1323, in draw_piechart(counts, size, colors, filename, label, fontname, graph_colors)
   1320 ax.set_ylim(0, size - 10 * tweak)
   1321 # frame=True needed for some reason to fit pie properly (ugh)
   1322 # had to tweak the crap out of this to get tight box around piechart :(
-> 1323 wedges, _ = ax.pie(counts, center=(size / 2 - 6 * tweak, size / 2 - 6 * tweak), radius=size / 2, colors=colors,
   1324                    shadow=False, frame=True)
   1325 for w in wedges:
   1326     w.set_linewidth(.5)

File ~/.local/lib/python3.9/site-packages/matplotlib/__init__.py:1414, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
   1411 @functools.wraps(func)
   1412 def inner(ax, *args, data=None, **kwargs):
   1413     if data is None:
-> 1414         return func(ax, *map(sanitize_sequence, args), **kwargs)
   1416     bound = new_sig.bind(ax, *args, **kwargs)
   1417     auto_label = (bound.arguments.get(label_namer)
   1418                   or bound.kwargs.get(label_namer))

File ~/.local/lib/python3.9/site-packages/matplotlib/axes/_axes.py:3098, in Axes.pie(self, x, explode, labels, colors, autopct, pctdistance, shadow, labeldistance, startangle, radius, counterclock, wedgeprops, textprops, center, frame, rotatelabels, normalize)
   3095 x += expl * math.cos(thetam)
   3096 y += expl * math.sin(thetam)
-> 3098 w = mpatches.Wedge((x, y), radius, 360. * min(theta1, theta2),
   3099                    360. * max(theta1, theta2),
   3100                    facecolor=get_next_color(),
   3101                    clip_on=False,
   3102                    label=label)
   3103 w.set(**wedgeprops)
   3104 slices.append(w)

File ~/.local/lib/python3.9/site-packages/matplotlib/patches.py:1164, in Wedge.__init__(self, center, r, theta1, theta2, width, **kwargs)
   1162 self.theta1, self.theta2 = theta1, theta2
   1163 self._patch_transform = transforms.IdentityTransform()
-> 1164 self._recompute_path()

File ~/.local/lib/python3.9/site-packages/matplotlib/patches.py:1176, in Wedge._recompute_path(self)
   1173     connector = Path.LINETO
   1175 # Form the outer ring
-> 1176 arc = Path.arc(theta1, theta2)
   1178 if self.width is not None:
   1179     # Partial annulus needs to draw the outer ring
   1180     # followed by a reversed and scaled inner ring
   1181     v1 = arc.vertices

File ~/.local/lib/python3.9/site-packages/matplotlib/path.py:947, in Path.arc(cls, theta1, theta2, n, is_wedge)
    945 # number of curve segments to make
    946 if n is None:
--> 947     n = int(2 ** np.ceil((eta2 - eta1) / halfpi))
    948 if n < 1:
    949     raise ValueError("n must be >= 1 or None")

ValueError: cannot convert float NaN to integer
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /opt/tljh/user/lib/python3.9/site-packages/IPython/core/formatters.py:339, in BaseFormatter.__call__(self, obj)
    337     pass
    338 else:
--> 339     return printer(obj)
    340 # Finally look for special method names
    341 method = get_real_method(obj, self.print_method)

File /opt/tljh/user/lib/python3.9/site-packages/IPython/core/pylabtools.py:151, in print_figure(fig, fmt, bbox_inches, base64, **kwargs)
    148     from matplotlib.backend_bases import FigureCanvasBase
    149     FigureCanvasBase(fig)
--> 151 fig.canvas.print_figure(bytes_io, **kw)
    152 data = bytes_io.getvalue()
    153 if fmt == 'svg':

File ~/.local/lib/python3.9/site-packages/matplotlib/backend_bases.py:2299, in FigureCanvasBase.print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, bbox_extra_artists, backend, **kwargs)
   2297 if bbox_inches:
   2298     if bbox_inches == "tight":
-> 2299         bbox_inches = self.figure.get_tightbbox(
   2300             renderer, bbox_extra_artists=bbox_extra_artists)
   2301         if pad_inches is None:
   2302             pad_inches = rcParams['savefig.pad_inches']

File ~/.local/lib/python3.9/site-packages/matplotlib/figure.py:1684, in FigureBase.get_tightbbox(self, renderer, bbox_extra_artists)
   1681     artists = bbox_extra_artists
   1683 for a in artists:
-> 1684     bbox = a.get_tightbbox(renderer)
   1685     if bbox is not None and (bbox.width != 0 or bbox.height != 0):
   1686         bb.append(bbox)

File ~/.local/lib/python3.9/site-packages/matplotlib/axes/_base.py:4675, in _AxesBase.get_tightbbox(self, renderer, call_axes_locator, bbox_extra_artists, for_layout_only)
   4671     if np.all(clip_extent.extents == axbbox.extents):
   4672         # clip extent is inside the Axes bbox so don't check
   4673         # this artist
   4674         continue
-> 4675 bbox = a.get_tightbbox(renderer)
   4676 if (bbox is not None
   4677         and 0 < bbox.width < np.inf
   4678         and 0 < bbox.height < np.inf):
   4679     bb.append(bbox)

File ~/.local/lib/python3.9/site-packages/matplotlib/artist.py:356, in Artist.get_tightbbox(self, renderer)
    341 def get_tightbbox(self, renderer):
    342     """
    343     Like `.Artist.get_window_extent`, but includes any clipping.
    344 
   (...)
    354         The enclosing bounding box (in figure pixel coordinates).
    355     """
--> 356     bbox = self.get_window_extent(renderer)
    357     if self.get_clip_on():
    358         clip_box = self.get_clip_box()

File ~/.local/lib/python3.9/site-packages/matplotlib/patches.py:629, in Patch.get_window_extent(self, renderer)
    628 def get_window_extent(self, renderer=None):
--> 629     return self.get_path().get_extents(self.get_transform())

File ~/.local/lib/python3.9/site-packages/matplotlib/path.py:631, in Path.get_extents(self, transform, **kwargs)
    629         # as can the ends of the curve
    630         xys.append(curve([0, *dzeros, 1]))
--> 631     xys = np.concatenate(xys)
    632 if len(xys):
    633     return Bbox([xys.min(axis=0), xys.max(axis=0)])

File <__array_function__ internals>:5, in concatenate(*args, **kwargs)

ValueError: need at least one array to concatenate
<Figure size 56.88x56.88 with 1 Axes>
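The second traceback does not come from the `dtreeviz(...)` call itself but from IPython's display machinery (`formatters.py` calling `print_figure`). It appears that the small pie-chart figure opened for the failing leaf stays open, and the notebook tries to render it after the cell finishes, outside any `try` block, which would explain why "try-except pass" does not silence it. A minimal sketch of a workaround, assuming a pyplot-based notebook backend, is to close the half-built figures when the call fails. `draw_or_skip` and `broken_plot` are hypothetical helpers, not part of dtreeviz.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this sketch runs as a script; omit in a notebook
import matplotlib.pyplot as plt


def draw_or_skip(draw, *args, **kwargs):
    """Call draw(*args, **kwargs). On ValueError, close every open pyplot figure
    so the notebook does not try to render a half-built one afterwards.
    Hypothetical helper, not part of dtreeviz."""
    try:
        return draw(*args, **kwargs)
    except ValueError as exc:
        plt.close("all")  # drop partially drawn figures, e.g. a broken leaf pie chart
        print(f"skipped: {exc}")
        return None


def broken_plot():
    # Stand-in for a failing dtreeviz call: open a figure, then fail mid-draw.
    plt.subplots(1, 1)
    raise ValueError("cannot convert float NaN to integer")


result = draw_or_skip(broken_plot)
print(result, plt.get_fignums())
```

In the notebook above, the call would be `draw_or_skip(dtreeviz, light_dtree, scale=1.75)`. Note that `plt.close("all")` also closes any other open figures, and catching every `ValueError` can hide unrelated problems, so this is a stopgap rather than a fix.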
parrt commented 2 years ago

I'm not sure, but it could again be something about the data you are passing. Could you come up with a trivial data set that reproduces the problem, so we can test it with a self-contained test?

tlapusan commented 2 years ago

Hi @baligoyem,

I see two errors in your stack trace: "ValueError: cannot convert float NaN to integer" and "ValueError: need at least one array to concatenate". They seem to be related to your training set. If you can share a Google Colab notebook or any other shareable notebook, we can investigate it better.

Another option would be to run the LightGBM notebook from this repo and check whether you get the same issue. That might help you figure out where the issue comes from.
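If the data is the culprit, one way to check is to count how many rows land in each leaf of the tree being drawn, assuming dtreeviz derives each leaf's class counts from the rows passed to it. LightGBM's `Booster.predict(X, pred_leaf=True)` returns, for each row, the leaf index reached in every tree; a leaf that no row reaches would have nothing to put in its pie chart. The helper `empty_leaves` is a hypothetical sketch working on that array, and the toy array stands in for real model output.

```python
import numpy as np


def empty_leaves(leaf_ids, tree_index, num_leaves):
    """Return the leaf indices of one tree that no row reaches.

    leaf_ids: 2-D array shaped (n_rows, n_trees), e.g. the output of
        booster.predict(X, pred_leaf=True) in LightGBM.
    tree_index: column of leaf_ids to inspect.
    num_leaves: number of leaves in that tree, e.g.
        booster.dump_model()["tree_info"][tree_index]["num_leaves"].
    """
    leaf_ids = np.asarray(leaf_ids, dtype=np.int64)
    # minlength makes bincount report a zero for leaves that receive no rows
    counts = np.bincount(leaf_ids[:, tree_index], minlength=num_leaves)
    return np.flatnonzero(counts == 0).tolist()


# Toy stand-in for pred_leaf output: 5 rows, 2 trees with 4 leaves each.
leaf_ids = np.array([[0, 1], [0, 1], [2, 3], [2, 0], [0, 3]])
print(empty_leaves(leaf_ids, tree_index=0, num_leaves=4))  # [1, 3]
print(empty_leaves(leaf_ids, tree_index=1, num_leaves=4))  # [2]
```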

baligoyem commented 2 years ago

Hi @tlapusan and @parrt ,

Sorry for my late reply. I will share the data and my notebook soon. First I need to mask the data, and then wait until I get the same error again :)

tlapusan commented 2 years ago

You could also run the lightgbm notebook from this repo to check if you have the same issue or not

baligoyem commented 2 years ago

Actually, I've already used this repo. Thanks for your suggestion.

wukan1986 commented 1 year ago

Hi @parrt @tlapusan

I have the same problem!

  File "D:\Users\Kan\miniconda3\envs\torch1\lib\site-packages\matplotlib\path.py", line 947, in arc
    n = int(2 ** np.ceil((eta2 - eta1) / halfpi))
ValueError: cannot convert float NaN to integer

I debugged it: at line 1097 of trees.py, in _class_leaf_viz:

counts = node.class_counts()

counts is [0, 0], which causes the error.

My data is too big, sorry, so I can't upload it, but you can reproduce it like this:

counts = node.class_counts() 
counts[:] = 0

Thank you very much!
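The arithmetic behind "cannot convert float NaN to integer" can be reproduced without dtreeviz. This sketch only loosely mimics the steps visible in the traceback (normalizing the wedge sizes, then `int(2 ** np.ceil(...))` in `Path.arc`); it is not matplotlib's actual code, and `arc_segments` is a hypothetical name. With counts of `[0, 0]`, normalization divides 0 by 0, the wedge angle becomes NaN, and converting NaN to `int` fails. The `RuntimeWarning: invalid value encountered in divide` in the later report is consistent with that division.

```python
import numpy as np


def arc_segments(counts):
    """Normalize wedge sizes, then compute a segment count for the first wedge
    with the same int(2 ** ceil(...)) expression that fails in the traceback."""
    x = np.asarray(counts, dtype=float)
    with np.errstate(invalid="ignore"):  # silence the 0/0 RuntimeWarning for the demo
        fractions = x / x.sum()  # [0, 0] -> [nan, nan]
    angle = 2 * np.pi * fractions[0]  # sweep of the first wedge in radians
    return int(2 ** np.ceil(angle / (np.pi / 2)))  # int(nan) raises ValueError


print(arc_segments([2, 1]))  # 8
try:
    arc_segments([0, 0])
except ValueError as exc:
    print(exc)  # cannot convert float NaN to integer
```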

parrt commented 1 year ago

Thanks @wukan1986, could you try cutting your data down? Also, please upgrade to dtreeviz 2.0. Thanks!

wukan1986 commented 1 year ago

Hi @parrt

I installed dtreeviz yesterday, so it is 2.0. I cut the data from 1 GB to 44 KB by deleting a lot of rows and features.

Please download the code and data from this train.zip.

d:\Users\Kan\miniconda3\envs\torch1\lib\site-packages\matplotlib\axes\_axes.py:3056: RuntimeWarning: invalid value encountered in divide

ValueError                                Traceback (most recent call last)
d:\GitHub\my_quant\tests\train.py in line 2
      79 # %%
----> 80 v = viz_model.view(fancy=True)
      81 v.show()

File d:\Users\Kan\miniconda3\envs\torch1\lib\site-packages\dtreeviz\trees.py:555, in DTreeVizAPI.view(self, precision, orientation, instance_orientation, show_root_edge_labels, show_node_labels, show_just_path, fancy, histtype, highlight_path, x, max_X_features_LR, max_X_features_TD, depth_range_to_display, label_fontsize, ticks_fontsize, fontname, title, title_fontsize, colors, scale)
    553         continue
    554 if self.shadow_tree.is_classifier():
--> 555     _class_leaf_viz(node, colors=color_values,
    556                     filename=f"{tmp}/leaf{node.id}_{os.getpid()}.svg",
    557                     graph_colors=colors,
    558                     fontname=fontname)
    559     leaves.append(class_leaf_node(node))
    560 else:
    561     # for now, always gen leaf

File d:\Users\Kan\miniconda3\envs\torch1\lib\site-packages\dtreeviz\trees.py:1099, in _class_leaf_viz(node, colors, filename, graph_colors, fontname)
   1097 counts = node.class_counts()
   1098 prediction = node.prediction_name()
-> 1099 _draw_piechart(counts, size=size, colors=colors, filename=filename, label=f"n={nsamples}\n{prediction}",
   1100                graph_colors=graph_colors, fontname=fontname)

File d:\Users\Kan\miniconda3\envs\torch1\lib\site-packages\dtreeviz\trees.py:1318, in _draw_piechart(counts, size, colors, filename, label, fontname, graph_colors)
...
--> 947     n = int(2 ** np.ceil((eta2 - eta1) / halfpi))
    948 if n < 1:
    949     raise ValueError("n must be >= 1 or None")

ValueError: cannot convert float NaN to integer

ValueError                                Traceback (most recent call last)
File d:\Users\Kan\miniconda3\envs\torch1\lib\site-packages\IPython\core\formatters.py:338, in BaseFormatter.__call__(self, obj)
    336     pass
    337 else:
--> 338     return printer(obj)
    339 # Finally look for special method names
    340 method = get_real_method(obj, self.print_method)

File d:\Users\Kan\miniconda3\envs\torch1\lib\site-packages\IPython\core\pylabtools.py:152, in print_figure(fig, fmt, bbox_inches, base64, **kwargs)
    149     from matplotlib.backend_bases import FigureCanvasBase
    150     FigureCanvasBase(fig)
--> 152 fig.canvas.print_figure(bytes_io, **kw)
    153 data = bytes_io.getvalue()
    154 if fmt == 'svg':

File d:\Users\Kan\miniconda3\envs\torch1\lib\site-packages\matplotlib\backend_bases.py:2299, in FigureCanvasBase.print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, bbox_extra_artists, backend, **kwargs)
   2297 if bbox_inches:
   2298     if bbox_inches == "tight":
-> 2299         bbox_inches = self.figure.get_tightbbox(
   2300             renderer, bbox_extra_artists=bbox_extra_artists)
   2301         if pad_inches is None:
   2302             pad_inches = rcParams['savefig.pad_inches']

File d:\Users\Kan\miniconda3\envs\torch1\lib\site-packages\matplotlib\figure.py:1684, in FigureBase.get_tightbbox(self, renderer, bbox_extra_artists)
...
    633     return Bbox([xys.min(axis=0), xys.max(axis=0)])

File <__array_function__ internals>:180, in concatenate(*args, **kwargs)

ValueError: need at least one array to concatenate
<Figure size 37x37 with 1 Axes>

Thank you very much

tlapusan commented 1 year ago

Hi @wukan1986, thanks for sending the code and data. It helps with debugging.

I tried your code and it works for me.

[Screenshot 2023-01-01 at 16 44 27]

Versions: dtreeviz 2.0, lightgbm==3.3.3

tlapusan commented 1 year ago

I see the error is raised somewhere inside matplotlib... I'm using matplotlib==3.4.2.

wukan1986 commented 1 year ago

My matplotlib is 3.5.3.

I will try it in a Google Colab notebook.

wukan1986 commented 1 year ago

Hi @tlapusan @parrt

I updated matplotlib to 3.6.2 and got the same error:

https://colab.research.google.com/drive/1EmihQj9WpkY3dLHH9CSq5h7OGtlDQ0dk?usp=sharing

parrt commented 1 year ago

I wonder if this is caused by differing package versions on Colab vs. JupyterLab.
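To compare environments such as Colab and a local JupyterLab side by side, each participant could run a small snippet like this and paste the output. It uses only the standard library's `importlib.metadata` (Python 3.8+); the default package list is an assumption based on the libraries mentioned in this thread, and `report_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def report_versions(packages=("dtreeviz", "lightgbm", "matplotlib", "numpy")):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in this environment
    return found


print(report_versions())
```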

parrt commented 1 year ago

@tlapusan his version is lightgbm-3.3.3; mine is 3.2.1.

Running his .py, I get:

beast:~/Downloads/train $ python train.py 
Traceback (most recent call last):
  File "/Users/parrt/Downloads/train/train.py", line 57, in <module>
    lgb.log_evaluation(10),
AttributeError: module 'lightgbm' has no attribute 'log_evaluation'

Damn, and upgrading doesn't work because I get an M1 silicon error on my Mac.

tlapusan commented 1 year ago

Yes, I also needed to upgrade lightgbm to solve that error.

wukan1986 commented 1 year ago

log_evaluation can be removed; it just prints the evaluation log every 10 rounds.
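As far as I know, `lgb.log_evaluation` was added in LightGBM 3.3.0, while older 3.x releases such as 3.2.1 provide `lgb.print_evaluation` instead, which would explain the `AttributeError` on the 3.2.1 install. Instead of deleting the callback, a sketch could pick whichever one is available. `logging_callbacks` is a hypothetical helper, and the `SimpleNamespace` objects are stand-ins for the real module so the example runs without LightGBM installed.

```python
from types import SimpleNamespace


def logging_callbacks(lgb, period=10):
    """Return a list holding an evaluation-logging callback for whichever
    LightGBM API is present: log_evaluation (newer) or print_evaluation (older)."""
    for name in ("log_evaluation", "print_evaluation"):
        factory = getattr(lgb, name, None)
        if factory is not None:
            return [factory(period)]
    return []  # no logging callback found; training still works without it


# Stand-ins for old and new LightGBM modules.
old_lgb = SimpleNamespace(print_evaluation=lambda period: f"print every {period}")
new_lgb = SimpleNamespace(log_evaluation=lambda period: f"log every {period}")
print(logging_callbacks(old_lgb))  # ['print every 10']
print(logging_callbacks(new_lgb))  # ['log every 10']
```

With the real library this would be used as `import lightgbm as lgb` followed by `lgb.train(params, train_set, callbacks=logging_callbacks(lgb))`.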