koaning / scikit-lego

Extra blocks for scikit-learn pipelines.
https://koaning.github.io/scikit-lego/
MIT License
1.27k stars, 117 forks

[BUG] ZeroInflatedRegressor not compatible with KNeighborsRegressor #536

Closed brian-l-hand closed 2 years ago

brian-l-hand commented 2 years ago

I am attempting to use ZeroInflatedRegressor with KNeighborsRegressor as the regressor. The KNeighborsRegressor fit() method does not have a sample_weight parameter, which results in the following error:

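from sklearn.neighbors import KNeighborsRegressor
from sklearn.tree import DecisionTreeClassifier
from sklego.meta import ZeroInflatedRegressor

ZIR_KNN = ZeroInflatedRegressor(
    classifier=DecisionTreeClassifier(max_depth=1),
    regressor=KNeighborsRegressor(n_neighbors=10)
)
ZIR_KNN.fit(X, y)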
---------------------------------------------------------------------------
NotFittedError                            Traceback (most recent call last)
C:\ProgramData\Anaconda3\envs\CollegeBenchmarks\lib\site-packages\sklego\meta\zero_inflated_regressor.py in fit(self, X, y, sample_weight)
     95             try:
---> 96                 check_is_fitted(self.regressor)
     97                 self.regressor_ = self.regressor

C:\ProgramData\Anaconda3\envs\CollegeBenchmarks\lib\site-packages\sklearn\utils\validation.py in check_is_fitted(estimator, attributes, msg, all_or_any)
   1207     if not fitted:
-> 1208         raise NotFittedError(msg % {"name": type(estimator).__name__})
   1209 

NotFittedError: This KNeighborsRegressor instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_5740/1147313730.py in <module>
----> 1 ZIR_KNN.fit(X, y)

C:\ProgramData\Anaconda3\envs\CollegeBenchmarks\lib\site-packages\sklego\meta\zero_inflated_regressor.py in fit(self, X, y, sample_weight)
     98             except NotFittedError:
     99                 self.regressor_ = clone(self.regressor)
--> 100                 self.regressor_.fit(
    101                     X[non_zero_indices],
    102                     y[non_zero_indices],

TypeError: fit() got an unexpected keyword argument 'sample_weight'

MBrouns commented 2 years ago

That seems like a valid bug indeed! The ZeroInflatedRegressor always tries to pass sample_weight to the regressor's fit(), even when none was given, so any regressor whose fit() lacks that parameter fails. We'll need to add a check for it. Is this something you'd be interested in making a PR for?
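
For illustration, here is a minimal sketch of what such a check could look like. It is not the fix that ended up in scikit-lego. The helper names `accepts_sample_weight` and `fit_regressor`, and the two stub estimators, are hypothetical and exist only to keep the example self-contained. It inspects the signature of fit() and forwards sample_weight only when the parameter is declared. Raising an error when weights are supplied but unsupported is one possible design choice, not necessarily the one the maintainers would pick.

```python
import inspect


def accepts_sample_weight(estimator):
    """Return True if estimator.fit declares a sample_weight parameter."""
    return "sample_weight" in inspect.signature(estimator.fit).parameters


def fit_regressor(regressor, X, y, sample_weight=None):
    """Hypothetical helper: forward sample_weight only when fit() can take it."""
    if accepts_sample_weight(regressor):
        return regressor.fit(X, y, sample_weight=sample_weight)
    if sample_weight is not None:
        # Assumption: fail loudly rather than silently dropping the weights.
        raise ValueError(
            f"{type(regressor).__name__}.fit() does not support sample_weight"
        )
    return regressor.fit(X, y)


class WeightedStub:
    """Stand-in for an estimator whose fit() accepts sample_weight."""

    def fit(self, X, y, sample_weight=None):
        self.received_weight_ = sample_weight
        return self


class PlainStub:
    """Stand-in for an estimator like KNeighborsRegressor: no sample_weight."""

    def fit(self, X, y):
        self.fitted_ = True
        return self


if __name__ == "__main__":
    X, y = [[0.0], [1.0]], [1.0, 2.0]
    print(accepts_sample_weight(WeightedStub()))  # weights are supported
    print(accepts_sample_weight(PlainStub()))     # weights are not supported
    fit_regressor(PlainStub(), X, y)              # fits without a TypeError
```

scikit-learn also ships `sklearn.utils.validation.has_fit_parameter(estimator, "sample_weight")`, which performs a comparable signature check and could replace the hand-written helper here.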