csinva / imodels

Interpretable ML package 🔍 for concise, transparent, and accurate predictive modeling (sklearn-compatible).
https://csinva.io/imodels
MIT License
1.4k stars · 124 forks · source link

Handling `NaN` inputs #213

Status: Open · opened by ciberger 3 weeks ago

ciberger commented 3 weeks ago

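Hi! Thanks for this fantastic package. I'm struggling to find a way to handle NaNs. I'm passing a CatBoost model, which handles missing values natively, as the estimator to `HSTreeClassifier`, but `fit` fails with `ValueError: Input contains NaN`. The traceback below shows that the error is raised by imodels' own input validation: `check_fit_arguments` calls scikit-learn's `check_X_y` without allowing NaN. This happens at the start of `HSTree.fit`, before the data ever reaches CatBoost.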

from catboost import CatBoostClassifier
from imodels import HSTreeClassifier

clf = CatBoostClassifier()
model = HSTreeClassifier(estimator_=clf)
model = model.fit(X_train, y_train)

Error message (note: the call in the traceback also passes `feature_names=_FEATURES` to `fit`):

ValueError: Input contains NaN
File <command-4108227421121860>, line 6
      4 clf = CatBoostClassifier()
      5 model = HSTreeClassifier(estimator_=clf)
----> 6 model = model.fit(X_train, y_train, feature_names=_FEATURES)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-23c4ea58-433a-40f0-89be-b3d953b89efe/lib/python3.11/site-packages/imodels/tree/hierarchical_shrinkage.py:82, in HSTree.fit(self, X, y, sample_weight, *args, **kwargs)
     78 def fit(self, X, y, sample_weight=None, *args, **kwargs):
     79     # remove feature_names if it exists (note: only works as keyword-arg)
     80     # None returned if not passed
     81     feature_names = kwargs.pop("feature_names", None)
---> 82     X, y, feature_names = check_fit_arguments(self, X, y, feature_names)
     83     if feature_names is not None:
     84         self.feature_names = feature_names
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-23c4ea58-433a-40f0-89be-b3d953b89efe/lib/python3.11/site-packages/imodels/util/arguments.py:26, in check_fit_arguments(model, X, y, feature_names)
     24 if scipy.sparse.issparse(X):
     25     X = X.toarray()
---> 26 X, y = check_X_y(X, y)
     27 _, model.n_features_in_ = X.shape
     28 assert len(model.feature_names_) == model.n_features_in_, 'feature_names should be same size as X.shape[1]'
File /databricks/python/lib/python3.11/site-packages/sklearn/utils/validation.py:1147, in check_X_y(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, estimator)
   1142         estimator_name = _check_estimator_name(estimator)
   1143     raise ValueError(
   1144         f"{estimator_name} requires y to be passed, but the target y is None"
   1145     )
-> 1147 X = check_array(
   1148     X,
   1149     accept_sparse=accept_sparse,
   1150     accept_large_sparse=accept_large_sparse,
   1151     dtype=dtype,
   1152     order=order,
   1153     copy=copy,
   1154     force_all_finite=force_all_finite,
   1155     ensure_2d=ensure_2d,
   1156     allow_nd=allow_nd,
   1157     ensure_min_samples=ensure_min_samples,
   1158     ensure_min_features=ensure_min_features,
   1159     estimator=estimator,
   1160     input_name="X",
   1161 )
   1163 y = _check_y(y, multi_output=multi_output, y_numeric=y_numeric, estimator=estimator)
   1165 check_consistent_length(X, y)
File /databricks/python/lib/python3.11/site-packages/sklearn/utils/validation.py:959, in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)
    953         raise ValueError(
    954             "Found array with dim %d. %s expected <= 2."
    955             % (array.ndim, estimator_name)
    956         )
    958     if force_all_finite:
--> 959         _assert_all_finite(
    960             array,
    961             input_name=input_name,
    962             estimator_name=estimator_name,
    963             allow_nan=force_all_finite == "allow-nan",
    964         )
    966 if ensure_min_samples > 0:
    967     n_samples = _num_samples(array)
File /databricks/python/lib/python3.11/site-packages/sklearn/utils/validation.py:109, in _assert_all_finite(X, allow_nan, msg_dtype, estimator_name, input_name)
    107 if X.dtype == np.dtype("object") and not allow_nan:
    108     if _object_dtype_isnan(X).any():
--> 109         raise ValueError("Input contains NaN")
    111 # We need only consider float arrays, hence can early return for all else.
    112 if not xp.isdtype(X.dtype, ("real floating", "complex floating")):