csinva / imodels

Interpretable ML package 🔍 for concise, transparent, and accurate predictive modeling (sklearn-compatible).
https://csinva.io/imodels
MIT License

GOSDT unable to print tree #76

Closed. ja16005 closed this issue 2 years ago.

ja16005 commented 2 years ago

I am using the `OptimalTreeClassifier` model from GOSDT, as shown in the repository's own example, but I am unable to print the tree. It throws `AttributeError: 'OptimalTreeClassifier' object has no attribute 'classes_'`, as shown below.

[Screenshot (211)]

Running this throws the following error internally:

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_33/809247974.py in <module>
----> 1 test_classification_binary()

/tmp/ipykernel_33/280002779.py in test_classification_binary()
     23     # test acc
     24     acc_train = np.mean(preds == new_y)
---> 25     print(type(m),m, 'final acc', acc_train)
     26     assert acc_train > 0.8, 'acc greater than 0.8'

/opt/conda/lib/python3.7/site-packages/imodels/tree/cart_wrapper.py in __str__(self)
     58             return 'GreedyTree:\n' + export_text(self, feature_names=self.feature_names, show_weights=True)
     59         else:
---> 60             return 'GreedyTree:\n' + export_text(self, show_weights=True)
     61 
     62 

/opt/conda/lib/python3.7/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

/opt/conda/lib/python3.7/site-packages/sklearn/tree/_export.py in export_text(decision_tree, feature_names, max_depth, spacing, decimals, show_weights)
    872     tree_ = decision_tree.tree_
    873     if is_classifier(decision_tree):
--> 874         class_names = decision_tree.classes_
    875     right_child_fmt = "{} {} <= {}\n"
    876     left_child_fmt = "{} {} > {}\n"

AttributeError: 'OptimalTreeClassifier' object has no attribute 'classes_'
```
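The last frame shows sklearn's `export_text` reading `decision_tree.classes_` because the model reports itself as a classifier; for sklearn classifiers that attribute is normally set during `fit`. The `'GreedyTree:'` prefix from `cart_wrapper.py` suggests the model was printed through the greedy-tree printing path without `classes_` ever having been set. The stdlib-only sketch below reproduces that failure mode and shows a guarded alternative. `FallbackTree`, `export_text_like`, and `safe_str` are hypothetical illustrations, not imodels or sklearn code.

```python
class FallbackTree:
    """Hypothetical classifier whose fit() never sets classes_ (the failure mode)."""

    # Hypothetical marker standing in for sklearn's is_classifier() check.
    _estimator_type = "classifier"

    def fit(self, X, y):
        # Training bookkeeping happens, but classes_ is never stored.
        self.n_samples_ = len(X)
        return self


def export_text_like(model):
    """Mimics the step in export_text that raised the error in the traceback."""
    if getattr(model, "_estimator_type", None) == "classifier":
        class_names = model.classes_  # raises AttributeError if never set
        return f"classes: {list(class_names)}"
    return "not a classifier"


def safe_str(model):
    """Guarded printing: report the missing attribute instead of crashing."""
    if not hasattr(model, "classes_"):
        return f"{type(model).__name__}: no classes_ attribute (was fit() completed?)"
    return export_text_like(model)


m = FallbackTree().fit([[0], [1]], [0, 1])
try:
    export_text_like(m)
except AttributeError as err:
    print("AttributeError:", err)
print(safe_str(m))  # → FallbackTree: no classes_ attribute (was fit() completed?)
```

The guard only makes the failure readable; the real fix is for the fitted model to set `classes_`, which is what the maintainers' workaround and fix address.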

csinva commented 2 years ago

Hello! Thanks for bringing this to our attention. This is indeed a strange issue; we'll get to fixing it.

In the meantime, you should be able to avoid this issue by installing the `gosdt` dependency (`pip install gosdt`) and then re-running.
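Before re-running, it can help to confirm that the optional dependency is importable from the same interpreter the notebook uses, since a `pip` on a different environment would leave the error in place. A minimal sketch, assuming the package is imported under the name `gosdt` (matching its pip name); the helper `gosdt_available` is hypothetical and not part of imodels:

```python
import importlib.util
import sys


def gosdt_available() -> bool:
    """Return True if the optional `gosdt` package can be found by this interpreter."""
    # Assumption: the import name is "gosdt", the same as the pip package name.
    return importlib.util.find_spec("gosdt") is not None


if __name__ == "__main__":
    if gosdt_available():
        print("gosdt is installed; re-run the example.")
    else:
        # Install into the interpreter that is actually running this code.
        print("gosdt not found; install it with:")
        print(f"  {sys.executable} -m pip install gosdt")
```

Using `sys.executable -m pip` rather than a bare `pip` targets the running interpreter, which matters in notebook environments like the `/opt/conda` one in the traceback. In a notebook, restart the kernel after installing so the new package is picked up.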

csinva commented 2 years ago

This should be fixed now!