reiinakano / scikit-plot

An intuitive library to add plotting functionality to scikit-learn objects.
MIT License

How to stratify the data when using the classifier factory? #48

Closed by rhiever 7 years ago

rhiever commented 7 years ago

Hello,

Thanks so much for your great work on scikit-plot. I've found it quite useful in my ML workflows.

I work with imbalanced datasets pretty frequently, so it's important for me to be able to stratify my train/test splits. When I use the classifier factory to generate plots directly from the classifier object, I don't see any option for stratifying the data (e.g. in the plot_confusion_matrix function). How can I accomplish this?

reiinakano commented 7 years ago

Hi @rhiever, thanks for your kind words!

By default, the folds are already stratified: for a classifier with a binary or multiclass target, the default CV method is stratified k-fold. You can change this by setting the cv argument, which is documented as follows:

    cv (int, cross-validation generator, iterable, optional): Determines the
        cross-validation strategy to be used for splitting.
        Possible inputs for cv are:
          - None, to use the default 3-fold cross-validation,
          - integer, to specify the number of folds.
          - An object to be used as a cross-validation generator.
          - An iterable yielding train/test splits.
        For integer/None inputs, if ``y`` is binary or multiclass,
        :class:`StratifiedKFold` is used. If the estimator is not a classifier
        or if ``y`` is neither binary nor multiclass, :class:`KFold` is used.
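
The rule in the docstring above can be checked with plain scikit-learn. The following is a minimal sketch, not scikit-plot's own code: it assumes the rule matches scikit-learn's check_cv helper, and the toy labels and variable names are made up for illustration.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, check_cv

# Imbalanced toy labels: 90 negatives, 10 positives (illustrative data only).
y = np.array([0] * 90 + [1] * 10)
X = np.arange(len(y)).reshape(-1, 1)

# An integer cv with a classifier and a binary y resolves to StratifiedKFold.
resolved = check_cv(5, y, classifier=True)
print(type(resolved).__name__)

# For a non-classifier estimator, the same integer resolves to plain KFold.
resolved_reg = check_cv(5, y, classifier=False)
print(type(resolved_reg).__name__)

# An explicit splitter object can be passed as cv to control shuffling and seeding.
explicit_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# Count the positives that land in each test fold; stratification spreads
# the 10 positives evenly across the 5 folds.
positives_per_fold = [int(y[test_idx].sum()) for _, test_idx in explicit_cv.split(X, y)]
print(positives_per_fold)
```

Passing a splitter object such as explicit_cv, rather than an integer, is the way to get shuffled or seeded stratified folds.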

As an aside, I've found that using the Functions API gives me a lot more flexibility than the Factory API (in fact, I'm considering deprecating the Factory API in the future).
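
As a rough sketch of that flexibility, the function-style workflow lets you produce the predictions yourself with whatever splitting strategy you like, then pass the true and predicted labels on for plotting. The example below uses only scikit-learn; the synthetic dataset, the LogisticRegression model, and the variable names are assumptions chosen for illustration, and the plotting call is left as a comment because it depends on the installed scikit-plot version.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Synthetic imbalanced data (roughly 90% / 10%), for illustration only.
X, y = make_classification(
    n_samples=200, n_features=5, weights=[0.9, 0.1], random_state=0
)

# Explicit stratified folds, so every fold keeps roughly the same class ratio.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# Out-of-fold predictions: each sample is predicted by a model that never saw it.
y_pred = cross_val_predict(LogisticRegression(max_iter=1000), X, y, cv=cv)

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y, y_pred)
print(cm)

# Assumption: with scikit-plot installed, y and y_pred could be handed to its
# function-style plot_confusion_matrix, which takes true and predicted labels.
```

Because the split happens outside the plotting call, any strategy (stratified, grouped, a fixed hold-out set) works without the plotting function needing an option for it.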

rhiever commented 7 years ago

OK. I actually started using the Functions API today and may follow suit, as you suggested. Thanks again!