ray-project / tune-sklearn

A drop-in replacement for Scikit-Learn’s GridSearchCV / RandomizedSearchCV -- but with cutting edge hyperparameter tuning techniques.
https://docs.ray.io/en/master/tune/api_docs/sklearn.html
Apache License 2.0
465 stars, 52 forks

`refit=True` is necessary to get `best_params_` #113

Closed: Yard1 closed this issue 3 years ago

Yard1 commented 4 years ago

If the `refit` parameter is set to `False`, there is no way to get the best parameters (`best_params_`).

richardliaw commented 4 years ago

Hmm, right. Is there a recommendation as to what we should do here?

Yard1 commented 4 years ago

If scoring isn't multimetric, the best parameters should always be available, regardless of `refit`. That's how sklearn's tuners handle it. I can take a look at it myself when I have a moment.
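A minimal pure-Python sketch of that rule, assuming a `cv_results_`-style dict of lists. The keys `params`, `rank_test_<metric>`, and `mean_test_<metric>` follow sklearn's naming; the helper `best_from_results` itself is hypothetical. The best index, params, and score are computed whenever scoring is single-metric, even with `refit=False`, and for multimetric scoring only when `refit` names one of the scorers.

```python
from typing import Any, Dict, List, Optional, Union


def best_from_results(results: Dict[str, List[Any]],
                      refit: Union[bool, str],
                      multimetric: bool) -> Optional[Dict[str, Any]]:
    """Hypothetical helper returning best_index_/best_params_/best_score_.

    Mirrors the condition `refit or not multimetric`; callable refit is
    left out of this sketch.
    """
    if multimetric and not refit:
        # Several metrics and none chosen: there is no single "best".
        return None
    if multimetric:
        if not isinstance(refit, str):
            # A bare True is ambiguous when there are several scorers.
            raise ValueError("for multimetric scoring, refit must name a scorer")
        metric = refit
    else:
        # Single-metric results are stored under the suffix "score".
        metric = "score"
    ranks = results[f"rank_test_{metric}"]
    # Rank 1 is best; min() keeps the first index on ties, like argmin().
    best_index = min(range(len(ranks)), key=ranks.__getitem__)
    return {
        "best_index_": best_index,
        "best_params_": results["params"][best_index],
        "best_score_": results[f"mean_test_{metric}"][best_index],
    }
```

For example, with `rank_test_score = [3, 1, 2]` and `refit=False`, the helper still returns index 1 and its params, which is the behavior this issue asks for.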

Yard1 commented 4 years ago

This is the code in sklearn's `BaseSearchCV` that handles this:

        # For multi-metric evaluation, store the best_index_, best_params_ and
        # best_score_ iff refit is one of the scorer names
        # In single metric evaluation, refit_metric is "score"
        if self.refit or not self.multimetric_:
            # If callable, refit is expected to return the index of the best
            # parameter set.
            if callable(self.refit):
                self.best_index_ = self.refit(results)
                if not isinstance(self.best_index_, numbers.Integral):
                    raise TypeError('best_index_ returned is not an integer')
                if (self.best_index_ < 0 or
                   self.best_index_ >= len(results["params"])):
                    raise IndexError('best_index_ index out of range')
            else:
                self.best_index_ = results["rank_test_%s"
                                           % refit_metric].argmin()
                self.best_score_ = results["mean_test_%s" % refit_metric][
                                           self.best_index_]
            self.best_params_ = results["params"][self.best_index_]

        if self.refit:
            # we clone again after setting params in case some
            # of the params are estimators as well.
            self.best_estimator_ = clone(clone(base_estimator).set_params(
                **self.best_params_))
            refit_start_time = time.time()
            if y is not None:
                self.best_estimator_.fit(X, y, **fit_params)
            else:
                self.best_estimator_.fit(X, **fit_params)
            refit_end_time = time.time()
            self.refit_time_ = refit_end_time - refit_start_time
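The callable branch in the sklearn excerpt expects `refit(results)` to return an integer index into `results["params"]`. A minimal sketch of that contract and its validation, assuming the same dict-of-lists layout; `resolve_callable_refit` and `fastest_within_tolerance` are hypothetical names used only for illustration.

```python
import numbers
from typing import Any, Callable, Dict, List


def resolve_callable_refit(refit: Callable[[Dict[str, List[Any]]], Any],
                           results: Dict[str, List[Any]]) -> int:
    """Call a user-supplied refit function and validate its return value
    with the same two checks as the sklearn excerpt."""
    best_index = refit(results)
    if not isinstance(best_index, numbers.Integral):
        raise TypeError("best_index_ returned is not an integer")
    if best_index < 0 or best_index >= len(results["params"]):
        raise IndexError("best_index_ index out of range")
    return best_index


def fastest_within_tolerance(results: Dict[str, List[Any]],
                             tol: float = 0.02) -> int:
    """Hypothetical refit callable: among candidates whose mean test score
    is within `tol` of the best one, pick the lowest mean fit time."""
    scores = results["mean_test_score"]
    best = max(scores)
    candidates = [i for i, s in enumerate(scores) if best - s <= tol]
    return min(candidates, key=lambda i: results["mean_fit_time"][i])
```

Because a callable picks the index itself, `best_params_` can be derived from it even when no `best_score_` is meaningful, which is why the excerpt sets `best_score_` only in the non-callable branch.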

Compare with the code in `TuneBaseSearchCV`, where everything, including `best_params`, is set only inside `if self.refit:`:

        if self.refit:
            best_config = analysis.get_best_config(
                metric=metric, mode="max", scope="last")
            self.best_params = self._clean_config_dict(best_config)

            self.best_estimator_ = clone(self.estimator)
            if self.early_stop_type == EarlyStopping.WARM_START_ENSEMBLE:
                logger.info("tune-sklearn uses `n_estimators` to warm "
                            "start, so this parameter can't be "
                            "set when warm start early stopping. "
                            "`n_estimators` defaults to `max_iters`.")
                if check_is_pipeline(self.estimator):
                    cloned_base_estimator = self.best_estimator_.steps[-1][1]
                    cloned_base_estimator.set_params(
                        **{"n_estimators": self.max_iters})
                else:
                    self.best_params["n_estimators"] = self.max_iters
            self.best_estimator_.set_params(**self.best_params)
            self.best_estimator_.fit(X, y, **fit_params)

            best_result = analysis.get_best_trial(
                metric=metric, mode="max", scope="last").last_result
            self.best_score = float(best_result[metric])

Since one of the recent PRs fixed multimetric scoring, there should be no problem changing `TuneBaseSearchCV` to match sklearn's behavior. I will get it done in the coming days.
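A sketch of the proposed change, with Ray Tune replaced by a plain list of trial records so the logic runs on its own. `best_params_` and `best_score_` are taken from the best trial whenever a single metric is defined, and only fitting `best_estimator_` stays behind `if refit`. `Trial`, `DummyEstimator`, and `finalize_search` are hypothetical stand-ins, not tune-sklearn or Ray APIs; `mode="max"` and `scope="last"` (compare each trial's last reported result) follow the `TuneBaseSearchCV` excerpt.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union


@dataclass
class Trial:
    """Hypothetical stand-in for a Tune trial: its config and last result."""
    config: Dict[str, Any]
    last_result: Dict[str, float]


@dataclass
class DummyEstimator:
    """Hypothetical estimator exposing sklearn-style set_params/fit."""
    params: Dict[str, Any] = field(default_factory=dict)
    fitted: bool = False

    def set_params(self, **params: Any) -> "DummyEstimator":
        self.params.update(params)
        return self

    def fit(self, X: Any, y: Any = None) -> "DummyEstimator":
        self.fitted = True
        return self


def finalize_search(trials: List[Trial], metric: str,
                    refit: Union[bool, str], multimetric: bool = False,
                    X: Any = None, y: Any = None) -> Dict[str, Any]:
    """Set best_* attributes like sklearn: best_params_/best_score_ whenever
    a single metric is defined, best_estimator_ only when refit is truthy."""
    attrs: Dict[str, Any] = {}
    if refit or not multimetric:
        # mode="max", scope="last": compare each trial's last reported value.
        best = max(trials, key=lambda t: t.last_result[metric])
        attrs["best_params_"] = dict(best.config)
        attrs["best_score_"] = float(best.last_result[metric])
    if refit:
        # Refitting is the only step that still depends on `refit`.
        estimator = DummyEstimator().set_params(**attrs["best_params_"])
        attrs["best_estimator_"] = estimator.fit(X, y)
    return attrs
```

With `refit=False` and single-metric scoring, the result contains `best_params_` and `best_score_` but no `best_estimator_`, matching the sklearn excerpt.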

richardliaw commented 4 years ago

Yep that sounds good!

Yard1 commented 4 years ago

WIP here: https://github.com/ray-project/tune-sklearn/pull/114

Still need to fix CI and add some new tests.

richardliaw commented 3 years ago

This was closed by #114.