TeamHG-Memex / eli5

A library for debugging/inspecting machine learning classifiers and explaining their predictions
http://eli5.readthedocs.io
MIT License
2.75k stars 331 forks source link

Permutation importance with sample weights and CV folds fails #358

Closed rg2410 closed 4 years ago

rg2410 commented 4 years ago

Hi,

I'm using permutation importance with a scikit-learn Random Forest classifier, passing both sample weights and cross-validation folds. I'm getting an error in the _cv_scores_importances method in permutation_importance.py. I believe it occurs because the sample weights are not sliced along with the predictors and target: X and y are indexed with each fold's train indices (X[train], y[train]), but sample_weight is forwarded at full length through **fit_params.

Here's code that reproduces the problem:

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold  # minimal repro: per-fold sample_weight is not sliced
from eli5.sklearn import PermutationImportance
import numpy as np

data = load_breast_cancer()
X, y = data.data, data.target
kf = StratifiedKFold(n_splits=5)  # a real splitter; KFold(...).get_n_splits(X) only returns the int 5, which check_cv maps to this for classifiers
weights = np.random.rand(len(y))  # one weight per sample (569); values are irrelevant to the failure
model = RandomForestClassifier()

perm = PermutationImportance(model, cv=kf).fit(X, y, sample_weight=weights)

The error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-14-67bb1e3b5a62> in <module>
     11 model = RandomForestClassifier()
     12 
---> 13 perm = PermutationImportance(model, cv=kf).fit(X, y, sample_weight=weights)

~\AppData\Local\Continuum\anaconda3\envs\eli5_perm\lib\site-packages\eli5\sklearn\permutation_importance.py in fit(self, X, y, groups, **fit_params)
    198 
    199         if self.cv not in (None, "prefit"):
--> 200             si = self._cv_scores_importances(X, y, groups=groups, **fit_params)
    201         else:
    202             si = self._non_cv_scores_importances(X, y)

~\AppData\Local\Continuum\anaconda3\envs\eli5_perm\lib\site-packages\eli5\sklearn\permutation_importance.py in _cv_scores_importances(self, X, y, groups, **fit_params)
    214         base_scores = []  # type: List[float]
    215         for train, test in cv.split(X, y, groups):
--> 216             est = clone(self.estimator).fit(X[train], y[train], **fit_params)
    217             score_func = partial(self.scorer_, est)
    218             _base_score, _importances = self._get_score_importances(

~\AppData\Local\Continuum\anaconda3\envs\eli5_perm\lib\site-packages\sklearn\ensemble\forest.py in fit(self, X, y, sample_weight)
    328                     t, self, X, y, sample_weight, i, len(trees),
    329                     verbose=self.verbose, class_weight=self.class_weight)
--> 330                 for i, t in enumerate(trees))
    331 
    332             # Collect newly grown trees

~\AppData\Local\Continuum\anaconda3\envs\eli5_perm\lib\site-packages\joblib\parallel.py in __call__(self, iterable)
   1002             # remaining jobs.
   1003             self._iterating = False
-> 1004             if self.dispatch_one_batch(iterator):
   1005                 self._iterating = self._original_iterator is not None
   1006 

~\AppData\Local\Continuum\anaconda3\envs\eli5_perm\lib\site-packages\joblib\parallel.py in dispatch_one_batch(self, iterator)
    833                 return False
    834             else:
--> 835                 self._dispatch(tasks)
    836                 return True
    837 

~\AppData\Local\Continuum\anaconda3\envs\eli5_perm\lib\site-packages\joblib\parallel.py in _dispatch(self, batch)
    752         with self._lock:
    753             job_idx = len(self._jobs)
--> 754             job = self._backend.apply_async(batch, callback=cb)
    755             # A job can complete so quickly than its callback is
    756             # called before we get here, causing self._jobs to

~\AppData\Local\Continuum\anaconda3\envs\eli5_perm\lib\site-packages\joblib\_parallel_backends.py in apply_async(self, func, callback)
    207     def apply_async(self, func, callback=None):
    208         """Schedule a func to be run"""
--> 209         result = ImmediateResult(func)
    210         if callback:
    211             callback(result)

~\AppData\Local\Continuum\anaconda3\envs\eli5_perm\lib\site-packages\joblib\_parallel_backends.py in __init__(self, batch)
    588         # Don't delay the application, to avoid keeping the input
    589         # arguments in memory
--> 590         self.results = batch()
    591 
    592     def get(self):

~\AppData\Local\Continuum\anaconda3\envs\eli5_perm\lib\site-packages\joblib\parallel.py in __call__(self)
    254         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    255             return [func(*args, **kwargs)
--> 256                     for func, args, kwargs in self.items]
    257 
    258     def __len__(self):

~\AppData\Local\Continuum\anaconda3\envs\eli5_perm\lib\site-packages\joblib\parallel.py in <listcomp>(.0)
    254         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    255             return [func(*args, **kwargs)
--> 256                     for func, args, kwargs in self.items]
    257 
    258     def __len__(self):

~\AppData\Local\Continuum\anaconda3\envs\eli5_perm\lib\site-packages\sklearn\ensemble\forest.py in _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, verbose, class_weight)
    107         indices = _generate_sample_indices(tree.random_state, n_samples)
    108         sample_counts = np.bincount(indices, minlength=n_samples)
--> 109         curr_sample_weight *= sample_counts
    110 
    111         if class_weight == 'subsample':

ValueError: operands could not be broadcast together with shapes (569,) (454,) (569,) 

The versions of the packages I'm using:

eli5==0.10.1
pandas==0.25.3
scikit-learn==0.21.1

Where I think the problem is: https://github.com/TeamHG-Memex/eli5/blob/4839d1927c4a68aeff051935d1d4d8a4fb69b46d/eli5/sklearn/permutation_importance.py#L217-L218

Please let me know if you can reproduce this; if so, it can hopefully be fixed quickly.

Many thanks

lopuhin commented 4 years ago

Fixed by https://github.com/TeamHG-Memex/eli5/pull/359, thanks @rg2410