scikit-learn-contrib / imbalanced-learn

A Python Package to Tackle the Curse of Imbalanced Datasets in Machine Learning
https://imbalanced-learn.org
MIT License

Embedded Pipelines Raise AttributeError on .fit() #162

Closed bmritz closed 8 years ago

bmritz commented 8 years ago

If I create a "hierarchical pipeline" (a pipeline where one step is itself another pipeline), the outer Pipeline raises an AttributeError on .fit(): it sees that the inner imblearn.pipeline.Pipeline object has a .fit_transform() attribute and dispatches to .fit_transform(), which in turn tries to call .transform() on the SMOTE sampler inside the inner pipeline, an attribute that samplers do not have.

The following will reproduce the error:

from imblearn.pipeline import Pipeline
from imblearn.over_sampling import SMOTE
from sklearn.datasets import make_classification
from sklearn import preprocessing
from sklearn.neighbors import KNeighborsClassifier

X, y = make_classification(n_classes=2, class_sep=2, weights=[0.1, 0.9],
    n_informative=3, n_redundant=1, flip_y=0,
    n_features=20, n_clusters_per_class=1,
    n_samples=1000, random_state=10)

Pipeline(steps=[
    ('(std->smt)', Pipeline(steps=[
        ('std', preprocessing.StandardScaler(copy=True, with_mean=True, with_std=True)),
        ('smt', SMOTE(k=5, kind='regular', m=10, n_jobs=-1, out_step=0.5,
                      random_state=None, ratio='auto'))
    ])),
    ('knn', KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                                 metric_params=None, n_jobs=1, n_neighbors=5, p=2,
                                 weights='uniform'))
]).fit(X,y)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-5-c79bf2b79986> in <module>()
      7            metric_params=None, n_jobs=1, n_neighbors=5, p=2,
      8            weights='uniform'))
----> 9 ]).fit(X,y)

/mnt/home/britz/.virtualenvs/snoop/lib/python2.7/site-packages/imblearn/pipeline.py in fit(self, X, y, **fit_params)
    113             the pipeline.
    114         """
--> 115         Xt, yt, fit_params = self._pre_transform(X, y, **fit_params)
    116         self.steps[-1][-1].fit(Xt, yt, **fit_params)
    117         return self

/mnt/home/britz/.virtualenvs/snoop/lib/python2.7/site-packages/imblearn/pipeline.py in _pre_transform(self, X, y, **fit_params)
     92         for name, transform in self.steps[:-1]:
     93             if hasattr(transform, "fit_transform"):
---> 94                 Xt = transform.fit_transform(Xt, yt, **fit_params_steps[name])
     95             elif hasattr(transform, "fit_sample"):
     96                 Xt, yt = transform.fit_sample(Xt, yt, **fit_params_steps[name])

/mnt/home/britz/.virtualenvs/snoop/lib/python2.7/site-packages/imblearn/pipeline.py in fit_transform(self, X, y, **fit_params)
    136             return self.steps[-1][-1].fit_transform(Xt, yt, **fit_params)
    137         else:
--> 138             return self.steps[-1][-1].fit(Xt, yt, **fit_params).transform(Xt)
    139 
    140     @if_delegate_has_method(delegate='_final_estimator')

AttributeError: 'SMOTE' object has no attribute 'transform'

FYI, the .predict() method on the pipeline also raises an exception with embedded pipelines: because the nested Pipeline step exposes a fit_sample attribute, .predict() skips over that step without transforming the data.

glemaitre commented 8 years ago

@chkoar you should be able to offer more insight on the error itself.

However, I don't see the utility of the example. Is it only a dummy example? Embedding the pipeline in another pipeline is equivalent to having a single linear one in the example you gave.

glemaitre commented 8 years ago

@chkoar is the issue due to the fact that a pipeline exposes both fit_transform and fit_sample during the pre_transform step? -> check there
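
For illustration, a hedged sketch of that ambiguity, reusing the objects from the reproduction script above:

from imblearn.pipeline import Pipeline
from imblearn.over_sampling import SMOTE
from sklearn.preprocessing import StandardScaler

# The inner imblearn Pipeline exposes both fit_transform and
# fit_sample, so _pre_transform's hasattr checks (fit_transform
# first, per the traceback above) never reach the sampler branch.
inner = Pipeline(steps=[('std', StandardScaler()), ('smt', SMOTE())])
print(hasattr(inner, 'fit_transform'))  # True
print(hasattr(inner, 'fit_sample'))     # True, but checked second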

bmritz commented 8 years ago

The actual code I posted is a toy example, but I do use embedded pipelines fairly regularly.

I usually use embedded pipelines when trying out different preprocessing schemes and different feature selections. For example, say I wanted to try MinMaxScaler, StandardScaler, and no scaling on a subset of features. I'd set up a pipeline that uses preprocessing.FunctionTransformer to subset the columns, and then, inside a for loop, wrap that pipeline into another pipeline whose second step is the scaler for that iteration of the loop. This way I can create many pipelines off a "base" pipeline and keep track of them fairly easily.
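
A minimal sketch of that pattern (the column-selection helper, the column indices, and the step names here are illustrative, not the actual code):

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import (FunctionTransformer, MinMaxScaler,
                                   StandardScaler)

def first_three_columns(X):
    # Hypothetical column subset, for illustration only.
    return X[:, :3]

# A "base" pipeline that only subsets the columns.
base = Pipeline(steps=[
    ('subset', FunctionTransformer(first_three_columns)),
])

scalers = [('minmax', MinMaxScaler()),
           ('standard', StandardScaler()),
           ('identity', FunctionTransformer())]  # no scaling

# Wrap the base pipeline with a different scaler on each iteration,
# keeping track of the variants by name.
pipelines = {name: Pipeline(steps=[('base', base), ('scale', scaler)])
             for name, scaler in scalers}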

glemaitre commented 8 years ago

@bmritz That makes sense. I was keeping a list of the objects and creating the pipelines inside loops for the same thing.

Thanks for reporting.

chkoar commented 8 years ago

It is more complicated than one may expect, because we built on sklearn's Pipeline object to reuse the linear transformations, as @glemaitre mentioned.

We expect samplers or transformers in the Pipeline, as stated in the docstring, and we should warn the user about this.

Exchanging the if with an elif at L130 of imblearn/pipeline.py solves the problem in fit.

With the current design the samplers work in the training phase only, so the sample method is effectively a training method: it requires the target y. @glemaitre with the old API, where we called transform on samplers without parameters, this case would be easier, I think.

@bmritz did you try to put the sampler in its own step outside of the nested Pipeline?

bmritz commented 8 years ago

Yes, that is what I ended up doing. Because I wanted to understand the effect of resampling on my final model, I resampled outside the pipeline and then created two pipelines off two training sets, one un-resampled and one resampled.
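
A hedged sketch of that workaround, reusing X, y and the estimators from the reproduction script above (fit_sample being the imblearn sampler API here):

from sklearn.pipeline import Pipeline  # the plain sklearn Pipeline suffices

# Resample outside the pipeline, then fit one pipeline per training set.
X_res, y_res = SMOTE(random_state=10).fit_sample(X, y)

def make_pipeline():
    return Pipeline(steps=[
        ('std', preprocessing.StandardScaler()),
        ('knn', KNeighborsClassifier()),
    ])

pipe_plain = make_pipeline().fit(X, y)              # un-resampled
pipe_resampled = make_pipeline().fit(X_res, y_res)  # resampled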

It makes sense that resamplers work only in the training phase; for validation or test data there would be no need to resample, so I see your logic there.


chkoar commented 8 years ago

I resampled outside the pipeline

@bmritz if I am not missing something, in your case I would create an IdentityResampler that simply returns what it is given. Then I would use grid search and validation curves to see how the model performs by varying the resampler parameter of the pipeline:

Pipeline(steps=[
    ('std', preprocessing.StandardScaler(copy=True, with_mean=True, with_std=True)),
    ('resampler', SMOTE(k=5, kind='regular', m=10, n_jobs=-1, out_step=0.5,
                        random_state=None, ratio='auto')),
    ('knn', KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                                 metric_params=None, n_jobs=1, n_neighbors=5, p=2,
                                 weights='uniform'))
]).fit(X,y)

@dvro, @glemaitre Do we need an IdentityResampler ?
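
A minimal sketch of what such an IdentityResampler could look like, following the fit/sample/fit_sample convention of the samplers (hypothetical; not an existing class in the library):

from sklearn.base import BaseEstimator

class IdentityResampler(BaseEstimator):
    """Pass-through sampler: returns X and y unchanged.

    Hypothetical sketch mirroring the imblearn sampler API,
    so it can occupy a sampler step in a pipeline.
    """

    def fit(self, X, y):
        return self

    def sample(self, X, y):
        # No resampling; hand back the data as-is.
        return X, y

    def fit_sample(self, X, y):
        return self.fit(X, y).sample(X, y)

Grid search could then swap the 'resampler' step between SMOTE(...) and IdentityResampler() to compare resampled and un-resampled fits, assuming a scikit-learn version that allows replacing a pipeline step through set_params.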

glemaitre commented 8 years ago

It could make sense to add it, for consistency.

chkoar commented 8 years ago

@bmritz like this.

glemaitre commented 8 years ago

@chkoar can you label this issue so we know when to address it, if needed?

chkoar commented 8 years ago

@glemaitre done in #166