csinva / imodels

Interpretable ML package 🔍 for concise, transparent, and accurate predictive modeling (sklearn-compatible).
https://csinva.io/imodels
MIT License

model.plot() doesn't work #204

Closed: jacons closed this issue 2 months ago

jacons commented 2 months ago

Hi!

I was trying to plot the trees of a fitted FIGS model with figs.get_model().plot() when the following error appeared:

Is there some incompatibility with scikit-learn? Which sklearn version should I use?

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[47], line 1
----> 1 figs.get_model().plot()

File ~\anaconda3\envs\FIGS_test\lib\site-packages\imodels\tree\figs.py:631, in FIGS.plot(self, cols, feature_names, filename, label, impurity, tree_number, dpi, fig_size)
    629     ax = axs
    630 try:
--> 631     dt = extract_sklearn_tree_from_figs(
    632         self, i if tree_number is None else tree_number, n_classes
    633     )
    634     plot_tree(
    635         dt,
    636         ax=ax,
   (...)
    639         impurity=impurity,
    640     )
    641 except IndexError:

File ~\anaconda3\envs\FIGS_test\lib\site-packages\imodels\tree\viz_utils.py:115, in extract_sklearn_tree_from_figs(figs, tree_num, n_classes, with_leaf_predictions)
    113 tree = Tree(n_features=n_features, n_classes=n_classes_array, n_outputs=n_outputs)
    114 # https://github.com/scikit-learn/scikit-learn/blob/3850935ea610b5231720fdf865c837aeff79ab1b/sklearn/tree/_tree.pyx#L677
--> 115 tree.__setstate__(_state)
    117 # add the tree_ for the dt __setstate__()
    118 # note the trailing underscore also trips the sklearn_is_fitted protections
    119 _state['tree_'] = tree

File sklearn\tree\_tree.pyx:728, in sklearn.tree._tree.Tree.__setstate__()

File sklearn\tree\_tree.pyx:1434, in sklearn.tree._tree._check_node_ndarray()

ValueError: node array from the pickle has an incompatible dtype:
- expected: {'names': ['left_child', 'right_child', 'feature', 'threshold', 'impurity', 'n_node_samples', 'weighted_n_node_samples', 'missing_go_to_left'], 'formats': ['<i8', '<i8', '<i8', '<f8', '<f8', '<i8', '<f8', 'u1'], 'offsets': [0, 8, 16, 24, 32, 40, 48, 56], 'itemsize': 64}
- got     : [('left_child', '<i8'), ('right_child', '<i8'), ('feature', '<i8'), ('threshold', '<f8'), ('impurity', '<f8'), ('n_node_samples', '<i8'), ('weighted_n_node_samples', '<f8')]
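The last two lines of the error compare the node layout scikit-learn expects with the one it received. The state handed to Tree.__setstate__ has only seven fields and lacks missing_go_to_left. To illustrate that mismatch, here is a minimal numpy sketch of a hypothetical helper, upgrade_node_array, which is not part of imodels or scikit-learn. It copies an old-layout node array into the eight-field layout quoted in the error. It assumes that filling missing_go_to_left with 0 is acceptable, which seems safe only for trees that never route missing values. It is an illustration, not a drop-in patch for imodels.

```python
import numpy as np

# Eight-field node layout that scikit-learn expects, copied from the error message.
NEW_NODE_DTYPE = np.dtype({
    "names": ["left_child", "right_child", "feature", "threshold", "impurity",
              "n_node_samples", "weighted_n_node_samples", "missing_go_to_left"],
    "formats": ["<i8", "<i8", "<i8", "<f8", "<f8", "<i8", "<f8", "u1"],
    "offsets": [0, 8, 16, 24, 32, 40, 48, 56],
    "itemsize": 64,
})

# Seven-field layout the pickle state actually had ("got" line of the error).
OLD_NODE_DTYPE = np.dtype([
    ("left_child", "<i8"), ("right_child", "<i8"), ("feature", "<i8"),
    ("threshold", "<f8"), ("impurity", "<f8"), ("n_node_samples", "<i8"),
    ("weighted_n_node_samples", "<f8"),
])


def upgrade_node_array(old_nodes: np.ndarray) -> np.ndarray:
    """Hypothetical helper: copy an old-layout node array into NEW_NODE_DTYPE.

    Fields present in the old array are copied one by one. missing_go_to_left
    has no source field, so it stays 0 from np.zeros (an assumption).
    """
    new_nodes = np.zeros(old_nodes.shape, dtype=NEW_NODE_DTYPE)
    for name in old_nodes.dtype.names:
        new_nodes[name] = old_nodes[name]
    return new_nodes


# A single leaf node: -1 children mark a leaf, -2 marks an undefined feature/threshold.
old = np.array([(-1, -1, -2, -2.0, 0.5, 10, 10.0)], dtype=OLD_NODE_DTYPE)
new = upgrade_node_array(old)
print(new.dtype == NEW_NODE_DTYPE)
```

Copying field by field rather than reinterpreting the raw bytes keeps the conversion independent of padding, since the new layout is 64 bytes wide while the old one is 56.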

Bye

jacons commented 2 months ago

The solution discussed in this pyod issue worked for me:

https://github.com/yzhao062/pyod/issues/519
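On the version question: the missing_go_to_left field in the expected layout arrived with scikit-learn 1.3, and earlier releases use the seven-field layout shown in the "got" line. Pinning scikit-learn below 1.3 (for example pip install "scikit-learn<1.3") is therefore one way to avoid the mismatch while the installed imodels still builds the older layout. The sketch below uses a hypothetical helper, uses_missing_go_to_left_layout. It only parses a version string and assumes the 1.3 cutoff. It does not inspect scikit-learn's private internals.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def uses_missing_go_to_left_layout(sklearn_version: str) -> bool:
    """Hypothetical helper: True if this scikit-learn version is 1.3 or newer.

    Assumption: 1.3 is the first release whose tree nodes carry missing_go_to_left.
    Only the leading "major.minor" part of the version string is examined.
    """
    match = re.match(r"(\d+)\.(\d+)", sklearn_version)
    if match is None:
        raise ValueError(f"unrecognised version string: {sklearn_version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (1, 3)


if __name__ == "__main__":
    try:
        installed = version("scikit-learn")
    except PackageNotFoundError:
        print("scikit-learn is not installed")
    else:
        print(installed, "uses new node layout:", uses_missing_go_to_left_layout(installed))
```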