csinva / imodels

Interpretable ML package 🔍 for concise, transparent, and accurate predictive modeling (sklearn-compatible).
https://csinva.io/imodels
MIT License

HSTree sample weight as positional argument is incompatible with scikit API #160

Closed: bbstats closed this issue 1 year ago

bbstats commented 1 year ago

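This line in `HSTree.fit` passes `sample_weight` to the wrapped estimator's `fit` as a positional argument instead of a keyword argument. This causes errors with sklearn-compatible estimators (like CatBoost) whose `fit` does not take `sample_weight` as its third positional parameter.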
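A minimal sketch of the failure mode. It uses Python's `inspect.signature(...).bind` to show, without calling anything, which parameter a third positional argument would land on. Both classes are hypothetical stand-ins, not the real sklearn or CatBoost estimators. The CatBoost-style signature assumes that, as in CatBoost's documented `fit`, `cat_features` comes before `sample_weight`; check the docs for your CatBoost version.

```python
import inspect


class SklearnStyleModel:
    """Hypothetical stand-in with a scikit-learn-like fit signature."""

    def fit(self, X, y, sample_weight=None):
        return self


class CatBoostStyleModel:
    """Hypothetical stand-in where another parameter precedes sample_weight
    (assumption: loosely mirrors CatBoost's fit parameter order)."""

    def fit(self, X, y=None, cat_features=None, sample_weight=None):
        return self


def receiving_param(fit, value, *args, **kwargs):
    """Bind the call arguments to fit's signature (without calling it) and
    return the name of the parameter that would receive `value`."""
    bound = inspect.signature(fit).bind(*args, **kwargs)
    return next(name for name, v in bound.arguments.items() if v is value)


X, y, w = [[0.0], [1.0]], [0.0, 1.0], [1.0, 2.0]

for model in (SklearnStyleModel(), CatBoostStyleModel()):
    print(type(model).__name__,
          "| positional ->", receiving_param(model.fit, w, X, y, w),
          "| keyword ->", receiving_param(model.fit, w, X, y, sample_weight=w))
# SklearnStyleModel | positional -> sample_weight | keyword -> sample_weight
# CatBoostStyleModel | positional -> cat_features | keyword -> sample_weight
```

With the real CatBoost, the misrouted weights would presumably be interpreted as a categorical-feature specification, which is a plausible source of the errors.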

Code that does not work:

```python
import catboost as cb
from imodels import HSTreeRegressorCV

# X_train, y_train, w_train: training features, targets, and per-sample weights
tree = HSTreeRegressorCV(cb.CatBoostRegressor(random_state=42, verbose=0))
tree.fit(X_train, y_train, sample_weight=w_train)
```

To my eyes, the line should instead be:

```python
self.estimator_ = self.estimator_.fit(X, y, *args, sample_weight=sample_weight, **kwargs)
```
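A self-contained sketch of the proposed pattern. `KeywordForwardingWrapper` and `RecordingModel` are hypothetical classes, not part of imodels: the wrapper forwards `sample_weight` by keyword, so an inner model with a CatBoost-like parameter order still receives the weights in the right slot.

```python
class RecordingModel:
    """Hypothetical inner model with a CatBoost-like parameter order;
    it only records what it received."""

    def fit(self, X, y=None, cat_features=None, sample_weight=None):
        self.cat_features_ = cat_features
        self.sample_weight_ = sample_weight
        return self


class KeywordForwardingWrapper:
    """Hypothetical meta-estimator mirroring the proposed line."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, sample_weight=None, *args, **kwargs):
        # Passing sample_weight by name routes it to the inner model's
        # sample_weight parameter wherever that parameter sits.
        self.estimator_ = self.estimator.fit(X, y, *args, sample_weight=sample_weight, **kwargs)
        return self


w = [1.0, 2.0]
wrapper = KeywordForwardingWrapper(RecordingModel())
wrapper.fit([[0.0], [1.0]], [0.0, 1.0], sample_weight=w)
print(wrapper.estimator_.sample_weight_)  # [1.0, 2.0]
print(wrapper.estimator_.cat_features_)   # None
```

One design caveat: forwarding `sample_weight=None` unconditionally would raise a `TypeError` for an inner estimator whose `fit` has no `sample_weight` parameter at all. A variant could add the keyword only when weights are actually given.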

Subclassing `HSTree` with my proposed fix solves the problem:

```python
from copy import deepcopy
from typing import List

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor

from imodels.tree.hierarchical_shrinkage import HSTree
from imodels.util.arguments import check_fit_arguments
from imodels.util.tree import compute_tree_complexity


class ExtHSTree(HSTree):
    def fit(self, X, y, sample_weight=None, *args, **kwargs):
        # remove feature_names if it exists (note: only works as keyword-arg)
        feature_names = kwargs.pop('feature_names', None)  # None returned if not passed
        X, y, feature_names = check_fit_arguments(self, X, y, feature_names)
        # the fix: pass sample_weight by keyword, not positionally
        self.estimator_ = self.estimator_.fit(X, y, *args, sample_weight=sample_weight, **kwargs)
        self._shrink()

        # compute complexity
        if hasattr(self.estimator_, 'tree_'):
            self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
        elif hasattr(self.estimator_, 'estimators_'):
            self.complexity_ = 0
            for i in range(len(self.estimator_.estimators_)):
                t = deepcopy(self.estimator_.estimators_[i])
                if isinstance(t, np.ndarray):
                    assert t.size == 1, 'multiple trees stored under tree_?'
                    t = t[0]
                self.complexity_ += compute_tree_complexity(t.tree_)
        return self


class ExtHSTreeRegressor(ExtHSTree, RegressorMixin):
    ...


class ExtHSTreeRegressorCV(ExtHSTreeRegressor):
    def __init__(self, estimator_: BaseEstimator = None,
                 reg_param_list: List[float] = [0.1, 1, 10, 50, 100, 500],
                 shrinkage_scheme_: str = 'node_based',
                 max_leaf_nodes: int = 20,
                 cv: int = 3, scoring=None, *args, **kwargs):
        """Cross-validation is used to select the best regularization parameter for hierarchical shrinkage.

        Params
        ------
        estimator_
            Sklearn estimator (already initialized).
            If no estimator_ is passed, a sklearn decision tree is used.
        max_leaf_nodes
            If estimator_ is None, then max_leaf_nodes is passed to the default decision tree.
        args, kwargs
            Note: args, kwargs are not used but left so that imodels-experiments can still pass redundant args.
        """
        if estimator_ is None:
            estimator_ = DecisionTreeRegressor(max_leaf_nodes=max_leaf_nodes)
        super().__init__(estimator_, reg_param=None)
        self.reg_param_list = np.array(reg_param_list)
        self.cv = cv
        self.scoring = scoring
        self.shrinkage_scheme_ = shrinkage_scheme_
        # print('estimator', self.estimator_,
        #       'checks.check_is_fitted(estimator)', checks.check_is_fitted(self.estimator_))
        # if checks.check_is_fitted(self.estimator_):
        #     raise Warning('Passed an already fitted estimator,'
        #                   'but shrinking not applied until fit method is called.')

    def fit(self, X, y, *args, **kwargs):
        self.scores_ = []
        for reg_param in self.reg_param_list:
            est = ExtHSTreeRegressor(deepcopy(self.estimator_), reg_param)
            # note: sample_weight is not forwarded to these CV fits
            cv_scores = cross_val_score(est, X, y, cv=self.cv, scoring=self.scoring)
            self.scores_.append(np.mean(cv_scores))
        self.reg_param = self.reg_param_list[np.argmax(self.scores_)]
        # refit on all data; sample_weight (if given) travels in kwargs as a keyword
        super().fit(X, y, *args, **kwargs)
        return self
```

With this replacement, training the model works:

```python
tree = ExtHSTreeRegressorCV(cb.CatBoostRegressor(random_state=42, verbose=0))
tree.fit(X_train, y_train, sample_weight=w_train)
```

csinva commented 1 year ago

Thanks @bbstats for this catch! You're completely right. Do you want to make a PR with this change?

csinva commented 1 year ago

Just pushed this change in this commit!

Will appear in the next release, thanks again!

bbstats commented 1 year ago

Thanks @csinva !!!