csinva / imodels

Interpretable ML package 🔍 for concise, transparent, and accurate predictive modeling (sklearn-compatible).
https://csinva.io/imodels
MIT License
1.34k stars, 119 forks

Update hierarchical_shrinkage, fix bugs, change attribute name #183

Closed: jckkvs closed this pull request 11 months ago

jckkvs commented 11 months ago

This pull request does not change the algorithm of any estimator. Instead, it fixes several implementation bugs and inconveniences.

1. Change estimator_ to estimator

In scikit-learn, attributes that are set during training conventionally end with an underscore (_). The underscored attribute "estimator_" refers to the trained model, while "estimator" is the estimator object passed in before training. This convention helps users distinguish pre-training parameters from post-training attributes.
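A minimal sketch of this convention, assuming scikit-learn is installed. WrappedTree is a hypothetical wrapper, not imodels code: the constructor stores estimator untouched, and fit stores a fitted clone in estimator_.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, clone
from sklearn.tree import DecisionTreeRegressor


class WrappedTree(RegressorMixin, BaseEstimator):
    """Hypothetical wrapper (not imodels code) showing the naming convention."""

    def __init__(self, estimator):
        # Constructor parameter: stored exactly as passed, never fitted.
        self.estimator = estimator

    def fit(self, X, y):
        # Fitted attribute: trailing underscore, created only inside fit().
        self.estimator_ = clone(self.estimator).fit(X, y)
        return self

    def predict(self, X):
        return self.estimator_.predict(X)


X = np.arange(20, dtype=float).reshape(-1, 1)
y = X.ravel() ** 2
model = WrappedTree(DecisionTreeRegressor(max_depth=2)).fit(X, y)
print(hasattr(model.estimator, "tree_"))   # False: the parameter stays unfitted
print(hasattr(model.estimator_, "tree_"))  # True: the fitted copy
```

Cloning inside fit keeps the user's estimator object unfitted, so the wrapper can be refit or cloned by model-selection tools without side effects.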

2. Change __init__ of HSRegressor, HSClassifier, HSRegressorCV, HSClassifierCV

The scikit-learn 1.3.0 documentation page "Developing scikit-learn estimators" states: "In addition, every keyword argument accepted by __init__ should correspond to an attribute on the instance. Scikit-learn relies on this to find the relevant attributes to set on an estimator when doing model selection."
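In practice, get_params, set_params, and clone (which GridSearchCV uses) depend on this: BaseEstimator reads each __init__ argument back as an attribute of the same name. A minimal sketch with two hypothetical classes and a hypothetical reg_param argument:

```python
from sklearn.base import BaseEstimator, clone


class Good(BaseEstimator):
    def __init__(self, reg_param=1.0):
        self.reg_param = reg_param  # stored under the same name as the argument


class Bad(BaseEstimator):
    def __init__(self, reg_param=1.0):
        self._lambda = reg_param  # stored under a different name


good = clone(Good(reg_param=10.0))
print(good.get_params())  # {'reg_param': 10.0}

try:
    # get_params looks up self.reg_param, which Bad never sets.
    Bad(reg_param=10.0).get_params()
except AttributeError as err:
    print("get_params failed:", err)
```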

Instead of constructing a default estimator inside __init__, scikit-learn commonly leaves estimator=None and calls _validate_estimator during fit to choose the estimator actually used, as in its AdaBoost implementation: https://github.com/scikit-learn/scikit-learn/blob/7f9bad99d6e0a3e8ddf92a7e5561245224dab102/sklearn/ensemble/_weight_boosting.py#L1101C1-L1103C80
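A minimal sketch of that pattern applied to a tree wrapper. The classes ShrinkageBase, ShrinkageRegressor, and ShrinkageClassifier, the helper body, and the max_leaf_nodes=20 default are hypothetical; only the structure (estimator=None in __init__, with the default resolved by a _validate_estimator helper called from fit) follows the scikit-learn code linked above. It also shows how a regressor and a classifier can each supply their own default tree.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, clone
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor


class ShrinkageBase(BaseEstimator):
    """Hypothetical base class: the default tree is chosen at fit time, not in __init__."""

    def __init__(self, estimator=None):
        self.estimator = estimator  # may be None; stored unchanged

    def _validate_estimator(self, default):
        # Use the user's estimator if given, otherwise the task-specific default.
        base = self.estimator if self.estimator is not None else default
        self.estimator_ = clone(base)

    def predict(self, X):
        return self.estimator_.predict(X)


class ShrinkageRegressor(RegressorMixin, ShrinkageBase):
    def fit(self, X, y):
        self._validate_estimator(default=DecisionTreeRegressor(max_leaf_nodes=20))
        self.estimator_.fit(X, y)
        return self


class ShrinkageClassifier(ClassifierMixin, ShrinkageBase):
    def fit(self, X, y):
        self._validate_estimator(default=DecisionTreeClassifier(max_leaf_nodes=20))
        self.estimator_.fit(X, y)
        return self


X = np.random.RandomState(0).rand(50, 3)
reg = ShrinkageRegressor().fit(X, X[:, 0])
clf = ShrinkageClassifier().fit(X, (X[:, 0] > 0.5).astype(int))
print(type(reg.estimator_).__name__, type(clf.estimator_).__name__)
# DecisionTreeRegressor DecisionTreeClassifier
print(reg.estimator, clf.estimator)  # None None
```

Because the default is chosen inside fit, __init__ stays a plain parameter store and clone() reproduces the estimator exactly.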

In HSTree, the parent class, DecisionTreeClassifier was specified as the default estimator. However, because HSTree is used for both regression and classification tasks, we changed this so that both cases are handled. In addition, HSTree now inherits from BaseEstimator, which provides get_params (and set_params) automatically, so the hand-written get_params method is no longer needed.
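As a rough illustration of what inheriting BaseEstimator provides, the hypothetical class below defines no get_params or set_params of its own, yet both work, including the nested estimator__<param> names that GridSearchCV parameter grids use:

```python
from sklearn.base import BaseEstimator
from sklearn.tree import DecisionTreeRegressor


class ShrunkTree(BaseEstimator):
    """Hypothetical estimator: no hand-written get_params/set_params."""

    def __init__(self, estimator=None, reg_param=1.0):
        self.estimator = estimator
        self.reg_param = reg_param


model = ShrunkTree(estimator=DecisionTreeRegressor(max_depth=3), reg_param=5.0)
params = model.get_params(deep=True)
print(params["reg_param"])             # 5.0
print(params["estimator__max_depth"])  # 3

# set_params routes "estimator__..." keys to the nested estimator.
model.set_params(estimator__max_depth=6, reg_param=0.1)
print(model.estimator.max_depth, model.reg_param)  # 6 0.1
```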

3. Change __str__ and __repr__ to avoid a bug. In the current code, printing an estimator such as HSTree before fitting raises an error, for example:
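    from imodels import FIGSRegressor

    estimator = FIGSRegressor()  # unfitted estimator
    print(estimator)  # raised an error here because the model was not fitted yet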

I added code that checks whether the estimator has been fitted before printing it.
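A minimal sketch of such a check, assuming scikit-learn's check_is_fitted and NotFittedError. PrintableTree is hypothetical and uses scikit-learn's export_text as a stand-in for imodels' own tree printer: before fitting, __str__ falls back to the parameter-only repr from BaseEstimator; after fitting, it prints the learned tree.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.exceptions import NotFittedError
from sklearn.tree import DecisionTreeRegressor, export_text
from sklearn.utils.validation import check_is_fitted


class PrintableTree(RegressorMixin, BaseEstimator):
    """Hypothetical estimator whose __str__ is safe to call before fit()."""

    def __init__(self, max_depth=2):
        self.max_depth = max_depth

    def fit(self, X, y):
        self.tree_ = DecisionTreeRegressor(max_depth=self.max_depth).fit(X, y)
        return self

    def predict(self, X):
        return self.tree_.predict(X)

    def __str__(self):
        try:
            # Raises NotFittedError if no fitted attribute (ending in "_") exists.
            check_is_fitted(self)
        except NotFittedError:
            # Not fitted yet: fall back to the parameter-only repr from BaseEstimator.
            return repr(self)
        # Fitted: describe the learned tree structure.
        return export_text(self.tree_)


model = PrintableTree()
print(model)  # no error before fitting
X = np.arange(10, dtype=float).reshape(-1, 1)
model.fit(X, X.ravel())
print(model)  # text dump of the fitted tree
```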

I'm not sure whether putting the tree-structure representation in __str__ is the best approach, but I have kept your existing implementation.

csinva commented 11 months ago

This is wonderful, thank you for cleaning this up!

Is it ready to merge or is it still in flux?

csinva commented 11 months ago

Sorry, I merged this, but after some testing I realized that the FIGS model was not printing properly. You can see the issues if you run the tests with pytest.

jckkvs commented 11 months ago

I will add tests for the printing bugs.