ray-project / tune-sklearn

A drop-in replacement for Scikit-Learn’s GridSearchCV / RandomizedSearchCV -- but with cutting edge hyperparameter tuning techniques.
https://docs.ray.io/en/master/tune/api_docs/sklearn.html
Apache License 2.0

For evaluating multiple scores, use sklearn.model_selection.cross_validate instead #53

sumanthratna closed this issue 4 years ago

sumanthratna commented 4 years ago

Reproducible Example

from tune_sklearn import TuneSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn import datasets
from sklearn.model_selection import train_test_split
from scipy.stats import randint
import numpy as np

digits = datasets.load_digits()
x = digits.data
y = digits.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.2)

clf = RandomForestClassifier(random_state=317, verbose=100)
param_distributions = {
    "n_estimators": (1, 120),
}

tune_search = TuneSearchCV(
    clf,
    param_distributions,
    scoring=(
        'homogeneity_score',
        'completeness_score',
    ),
    verbose=2,
    search_optimization='bayesian'
)

# Fit on the training features and labels; the multi-metric scoring check raises inside fit.
tune_search.fit(x_train, y_train)

pred = tune_search.predict(x_test)
accuracy = np.count_nonzero(
    np.array(pred) == np.array(y_test)) / len(pred)
print(accuracy)

results in:

/private/tmp/venv/lib/python3.8/site-packages/tune_sklearn/tune_basesearch.py:245: UserWarning: Early stopping is not enabled. To enable early stopping, pass in a supported scheduler from Tune and ensure the estimator has `partial_fit`.
  warnings.warn("Early stopping is not enabled. "
Redis failed to start, retrying now.
Traceback (most recent call last):
  File "pls.py", line 38, in <module>
    tune_search.fit(x_train, x_test)
  File "/private/tmp/venv/lib/python3.8/site-packages/tune_sklearn/tune_basesearch.py", line 368, in fit
    result = self._fit(X, y, groups, **fit_params)
  File "/private/tmp/venv/lib/python3.8/site-packages/tune_sklearn/tune_basesearch.py", line 290, in _fit
    self.scoring = check_scoring(self.estimator, scoring=self.scoring)
  File "/private/tmp/venv/lib/python3.8/site-packages/sklearn/utils/validation.py", line 73, in inner_f
    return f(**kwargs)
  File "/private/tmp/venv/lib/python3.8/site-packages/sklearn/metrics/_scorer.py", line 430, in check_scoring
    raise ValueError("For evaluating multiple scores, use "
ValueError: For evaluating multiple scores, use sklearn.model_selection.cross_validate instead. ('homogeneity_score', 'completeness_score', 'v_measure_score', 'adjusted_rand_score', 'adjusted_mutual_info_score') was passed.
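The error message points at sklearn.model_selection.cross_validate, which does accept several scorers at once. The sketch below assumes a scikit-learn version that ships the clustering scorer names used in this issue. It evaluates the same two metrics for a fixed RandomForestClassifier with no hyperparameter tuning, so it only illustrates what the message suggests and is not a replacement for TuneSearchCV.

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate

X, y = datasets.load_digits(return_X_y=True)
clf = RandomForestClassifier(n_estimators=20, random_state=317)

# cross_validate accepts a tuple of scorer names and returns one
# "test_<name>" array per metric, with one entry per CV fold.
results = cross_validate(
    clf,
    X,
    y,
    scoring=("homogeneity_score", "completeness_score"),
    cv=3,
)

for name in ("homogeneity_score", "completeness_score"):
    scores = results[f"test_{name}"]
    print(f"{name}: mean={scores.mean():.3f} over {len(scores)} folds")
```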

Environment

richardliaw commented 4 years ago

BTW @sumanthratna does this happen if you swap out with RandomizedSearchCV?

inventormc commented 4 years ago

Hi @sumanthratna, we don't currently support multimetric scoring, but we're looking into adding this functionality! For now, this should work if you just put one scorer.
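Until multi-metric scoring is supported, one workaround is to tune against a single scorer and compute any other metric afterwards on held-out data. The sketch below uses scikit-learn's RandomizedSearchCV as a stand-in so that it runs without Ray. That TuneSearchCV accepts the same single-string `scoring` argument is an assumption based on the comment above and was not verified here.

```python
from scipy.stats import randint
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import completeness_score, homogeneity_score
from sklearn.model_selection import RandomizedSearchCV, train_test_split

X, y = datasets.load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=317)

# Tune against exactly one scorer. RandomizedSearchCV stands in for
# TuneSearchCV here (assumption: same single-scorer form).
search = RandomizedSearchCV(
    RandomForestClassifier(random_state=317),
    {"n_estimators": randint(1, 120)},  # samples integers in [1, 119]
    n_iter=5,
    scoring="completeness_score",
    cv=3,
    random_state=317,
)
search.fit(X_train, y_train)  # refit=True (default) retrains the best model

# Any additional metric is computed by hand on the held-out split.
pred = search.predict(X_test)
completeness = completeness_score(y_test, pred)
homogeneity = homogeneity_score(y_test, pred)
print(f"completeness={completeness:.3f} homogeneity={homogeneity:.3f}")
```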

sumanthratna commented 4 years ago

> BTW @sumanthratna does this happen if you swap out with RandomizedSearchCV?

No, there's no error.

from sklearn.model_selection import RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn import datasets
from sklearn.model_selection import train_test_split
from scipy.stats import randint
import numpy as np

digits = datasets.load_digits()
x = digits.data
y = digits.target

clf = RandomForestClassifier(random_state=317, verbose=100)
param_distributions = {
    "n_estimators": (1, 120),
}

mysearch = RandomizedSearchCV(
    clf,
    param_distributions,
    scoring=(
        'homogeneity_score',
        'completeness_score',
    ),
    verbose=2,
    refit=False,
)

mysearch.fit(x, y)
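With several scorers, RandomizedSearchCV cannot pick a single "best" model by itself, which is why the example above sets refit=False. The sketch below assumes a recent scikit-learn and shows the two usual ways to consume the results: read the per-metric columns of cv_results_, or pass one scorer's name to refit so that best_params_ and predict become available.

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

X, y = datasets.load_digits(return_X_y=True)

search = RandomizedSearchCV(
    RandomForestClassifier(random_state=317),
    # A plain sequence is treated as discrete candidates, not a range,
    # so this search space has exactly two points.
    {"n_estimators": [1, 120]},
    n_iter=2,
    scoring=("homogeneity_score", "completeness_score"),
    refit="completeness_score",  # with multiple scorers, refit must name one of them (or be False)
    cv=3,
    random_state=317,
)
search.fit(X, y)

# cv_results_ holds one set of columns per scorer, suffixed with its name.
for name in ("homogeneity_score", "completeness_score"):
    print(name, search.cv_results_[f"mean_test_{name}"])

print("best params by completeness:", search.best_params_)
```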

> Hi @sumanthratna, we don't currently support multimetric scoring, but we're looking into adding this functionality! For now, this should work if you just put one scorer.

It looks like the docstring for the scoring argument in TuneSearchCV was copied from the one in RandomizedSearchCV:

https://github.com/ray-project/tune-sklearn/blob/5e8db24724b99c16b660dbdb0f3f54bfa1ecfb81/tune_sklearn/tune_search.py#L83-L92

It might help to remove the reference to multiple metrics. I can open a PR for this if that'd help.

inventormc commented 4 years ago

You're more than welcome to open a PR. We'll be happy to take a look :)