csinva / imodels

Interpretable ML package 🔍 for concise, transparent, and accurate predictive modeling (sklearn-compatible).
https://csinva.io/imodels
MIT License
1.38k stars, 122 forks

HSTreeClassifierCV ignores feature names #192

Closed: FBoyang closed this issue 10 months ago

FBoyang commented 10 months ago

Dear Developers:

I am using HSTreeClassifierCV for some tasks, and I noticed that the model does not use the feature names passed to fit. For example, when running the example you provided:

from sklearn.model_selection import train_test_split
from imodels import get_clean_dataset, HSTreeClassifierCV # import any imodels model here

# prepare data (a sample clinical dataset)
X, y, feature_names = get_clean_dataset('csi_pecarn_pred')
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42)

# fit the model
model = HSTreeClassifierCV(max_leaf_nodes=4)  # initialize a tree model and specify only 4 leaf nodes
model.fit(X_train, y_train, feature_names=feature_names)   # fit model
preds = model.predict(X_test) # discrete predictions: shape is (n_test,)
preds_proba = model.predict_proba(X_test) # predicted probabilities: shape is (n_test, n_classes)
print(model) # print the model

I got:

> ------------------------------
> Decision Tree with Hierarchical Shrinkage
>   Prediction is made by looking at the value in the appropriate leaf of the tree
> ------------------------------
|--- feature_14 <= 0.50
|   |--- feature_4 <= 0.50
|   |   |--- feature_18 <= 0.50
|   |   |   |--- weights: [0.90, 0.10] class: 0.0
|   |   |--- feature_18 >  0.50
|   |   |   |--- weights: [0.70, 0.30] class: 0.0
|   |--- feature_4 >  0.50
|   |   |--- weights: [0.32, 0.68] class: 1.0
|--- feature_14 >  0.50
|   |--- weights: [0.58, 0.42] class: 0.0

The printed tree shows generic placeholders (feature_14, feature_4, feature_18) instead of the actual feature names.

csinva commented 10 months ago

Thanks for your interest and for raising this issue! Sorry about this. We have just fixed it in commit 6891928, which will be included in the next release. Until then, you can work around the issue by installing from source.

Thanks!
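
For readers who cannot install from source, one possible stopgap is to post-process the printed tree and substitute the real names for the generic feature_N placeholders. The sketch below is not part of imodels: rename_features is a hypothetical helper, and the names in the demo are made up for illustration (they are not the real csi_pecarn_pred columns). It assumes the placeholders follow the feature_<index> pattern seen in the output above.

```python
import re


def rename_features(tree_text, feature_names):
    """Replace generic feature_<i> placeholders with feature_names[i].

    Hypothetical helper, not part of imodels. Placeholders whose index
    is out of range are left unchanged.
    """
    def lookup(match):
        idx = int(match.group(1))
        if idx < len(feature_names):
            return str(feature_names[idx])
        return match.group(0)

    # \b ensures feature_14 is matched as a whole token, never as feature_1 + "4".
    return re.sub(r"\bfeature_(\d+)\b", lookup, tree_text)


# Made-up names for illustration only.
names = [f"col_{i}" for i in range(20)]
names[4] = "age_lt_2"

printed = "|--- feature_14 <= 0.50\n|   |--- feature_4 <= 0.50"
renamed = rename_features(printed, names)
print(renamed)
```

With the model from the original report, the call would presumably look like print(rename_features(str(model), feature_names)), assuming feature_names is indexed in the same column order as X. This only changes the displayed text; the fitted model itself is untouched.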