scikit-learn / scikit-learn

scikit-learn: machine learning in Python
https://scikit-learn.org
BSD 3-Clause "New" or "Revised" License

Better (__-free) ways to specify grid search hyperparameters #19045

Open jnothman opened 3 years ago

jnothman commented 3 years ago

Describe the workflow you want to enable

I developed searchgrid to deal with one of my pet peeves in scikit-learn: that when performing a hyperparam search you need to specify complicated __-delimited paths to parameters that reflect a composite estimator structure defined elsewhere.

I think searchgrid does the right thing by allowing you to specify the parameter grids locally with respect to the estimator in which each parameter is changed.

This benefits in a few ways:

Problems with searchgrid include:

I regularly, at least when I work with someone else's code, find the __-based parameter specification to be an ever-present usability frustration when building and modifying composite estimators.

Describe your proposed solution

Either:

  1. Adopt something very similar to searchgrid's interface, allowing a parameter grid or distribution to be specified on each constituent estimator. param_grid is then not needed when constructing GridSearchCV, since the grid is found on the base estimator.
      kbest = set_grid(SelectKBest(), k=[5, 10, 20])
      pca = set_grid(PCA(), n_components=[5, 10, 20])
      lr = set_grid(LogisticRegression(), C=[.1, 1, 10])
      pipe = set_grid(Pipeline([('reduce', None),
                                ('clf', lr)]),
                      reduce=[kbest, pca])
      gs = make_grid_search(pipe)
  2. Allow an alternative specification of parameter grid / distribution, as a mapping of {(est, param_name): [value1, value2, ...], ...}, which is equivalent to what's used in searchgrid, but doesn't set any attributes on the estimators themselves, and puts the grid specification into the grid search rather than when constructing the composite estimator.
      kbest = SelectKBest()
      pca = PCA()
      lr = LogisticRegression()
      pipe = Pipeline([('reduce', None),
                       ('clf', lr)])
      gs = GridSearchCV(pipe, param_spaces={
          (pipe, 'reduce'): [kbest, pca],
          (kbest, 'k'): [5, 10, 20],
          (pca, 'n_components'): [5, 10, 20],
          (lr, 'C'): [.1, 1, 10],
      })

Additional context

NicolasHug commented 3 years ago

Not a blocker, but a slight inconvenience of solution 2 is that it requires defining an instance for every estimator that needs to be tuned. For example:

Pipeline([('reduce', None),
          ('clf', LogisticRegression())])   # instead of lr

isn't possible since we need the lr instance in order to define the grid.

NicolasHug commented 3 years ago

I'm also wondering: since ultimately we need to call set_params(), and since it only understands the 'clf__C' syntax, how does GridSearchCV know how to map (lr, 'C') into 'clf__C'? It could scan the pipeline's steps and compare lr with the estimator at each step, but how does that work if lr is actually more deeply nested, e.g. as part of a meta-estimator?
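One way to picture the mapping asked about here is a lookup over get_params(deep=True), which already exposes nested estimators under their '__'-delimited names. The sketch below is only an illustration, not searchgrid's actual algorithm: Est is a hypothetical stand-in that mimics how scikit-learn estimators report nested parameters, and resolve_path is an invented helper name. It assumes the target estimator is currently reachable from the root, so it does not cover estimators that are only sometimes part of the Pipeline.

```python
class Est:
    """Minimal stand-in for a scikit-learn estimator (illustration only)."""

    def __init__(self, **params):
        self.params = dict(params)

    def get_params(self, deep=True):
        # Mimics BaseEstimator.get_params: nested estimators contribute
        # their own parameters under '<name>__<sub_name>' keys.
        out = dict(self.params)
        if deep:
            for name, value in self.params.items():
                if hasattr(value, "get_params"):
                    for sub, sub_value in value.get_params(deep=True).items():
                        out[f"{name}__{sub}"] = sub_value
        return out


def resolve_path(root, target, param_name):
    """Return the '__'-delimited name of `param_name` on `target`, seen from `root`."""
    if target is root:
        return param_name
    for key, value in root.get_params(deep=True).items():
        if value is target:  # compare by identity, not equality
            return f"{key}__{param_name}"
    raise ValueError("target estimator is not reachable from root")


lr = Est(C=1.0)
bagging = Est(estimator=lr, n_estimators=10)  # lr nested inside a meta-estimator
pipe = Est(scale=Est(with_std=True), clf=bagging)

path = resolve_path(pipe, lr, "C")
print(path)
```

Because the lookup relies only on get_params, the nesting depth does not matter: the meta-estimator case resolves the same way as a direct pipeline step.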

jnothman commented 3 years ago

See the algorithm in searchgrid... It works, even for parameters to estimators that are only sometimes part of the Pipeline.

glemaitre commented 3 years ago

Did we have a similar discussion where @amueller advocated for an estimator to have a parameter grid/distribution directly? I cannot find the issue anymore (probably this is where GitHub discussion would be great).

jnothman commented 3 years ago

Might have been a sprint discussion??

amueller commented 3 years ago

do you mean #5004 maybe?

glemaitre commented 3 years ago

Nope, it was an issue/PR with code :)

jnothman commented 3 years ago

I should note that the idea of an estimator having a default/sensible distribution grid is somewhat orthogonal. Certainly, it should be less controversial than trying to automatically determine a sensible distribution of parameters without respect to a particular task or dataset. The focus here is only on making it easier to work with parameter search specifications.

jnothman commented 2 years ago

I'm thinking that the right way to do this without storing the grid on the estimator itself (but rather storing a mapping from estimator to its local search grid) is with something like a GridFactory:

from sklearn.model_selection import GridFactory

grid = GridFactory()

kbest = grid.set(SelectKBest(), k=[5, 10, 20])
pca = grid.set(PCA(), n_components=[5, 10, 20])
lr = grid.set(LogisticRegression(), C=[.1, 1, 10])
pipe = grid.set(Pipeline([('reduce', None),
                          ('clf', lr)]),
                reduce=[kbest, pca])
gs = GridSearchCV(pipe, grid)

GridSearchCV might be extended to do something like:

if callable(param_grid):
    param_grid = param_grid(estimator)

to handle the case where param_grid is a GridFactory, which when called (or we can give it a method name) returns the grid for an estimator.

This has nicer OOP encapsulation than the "store the grid on the estimator" approach used in searchgrid, and might be a little easier to extend to search distributions, and to facilitate interpreting param columns of cv_results_.
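To make the GridFactory idea concrete, here is a rough sketch of how such an object could behave. Nothing here exists in scikit-learn: GridFactory, its set method and its call signature are taken from the proposal above, and the implementation is an assumption. Est is a hypothetical stand-in that mimics get_params on real estimators. Calling the factory on the root estimator returns a list of dicts with '__'-delimited keys, which is the format param_grid already accepts.

```python
class Est:
    """Minimal stand-in for a scikit-learn estimator (illustration only)."""

    def __init__(self, **params):
        self.params = dict(params)

    def get_params(self, deep=True):
        out = dict(self.params)
        if deep:
            for name, value in self.params.items():
                if hasattr(value, "get_params"):
                    for sub, sub_value in value.get_params(deep=True).items():
                        out[f"{name}__{sub}"] = sub_value
        return out


class GridFactory:
    """Hypothetical sketch: maps each estimator to its local search space."""

    def __init__(self):
        self._spaces = {}  # id(estimator) -> (estimator, {param: candidates})

    def set(self, estimator, **space):
        self._spaces[id(estimator)] = (estimator, space)
        return estimator  # returned unchanged, so it can be used inline

    def __call__(self, estimator):
        """Return a list of dicts keyed by '__'-delimited parameter names."""
        grids = [{}]
        _, local = self._spaces.get(id(estimator), (estimator, {}))
        for name, candidates in local.items():
            options = []
            plain = [c for c in candidates if not hasattr(c, "get_params")]
            if plain:
                options.append({name: plain})
            for cand in candidates:
                if hasattr(cand, "get_params"):
                    # An alternative sub-estimator brings its own grid along.
                    for sub in self(cand):
                        option = {name: [cand]}
                        option.update({f"{name}__{k}": v for k, v in sub.items()})
                        options.append(option)
            grids = [{**g, **o} for g in grids for o in options]
        # Fixed (non-searched) sub-estimators may still carry a local space.
        for name, value in estimator.get_params(deep=True).items():
            if "__" in name or name in local or not hasattr(value, "get_params"):
                continue
            subs = self(value)
            grids = [
                {**g, **{f"{name}__{k}": v for k, v in s.items()}}
                for g in grids
                for s in subs
            ]
        return grids


grid = GridFactory()
kbest = grid.set(Est(k=10), k=[5, 10, 20])
pca = grid.set(Est(n_components=2), n_components=[5, 10, 20])
lr = grid.set(Est(C=1.0), C=[0.1, 1, 10])
pipe = grid.set(Est(reduce=None, clf=lr), reduce=[kbest, pca])

param_grid = grid(pipe)  # what `param_grid(estimator)` would return in GridSearchCV
for g in param_grid:
    print(sorted(g))
```

Keeping the spaces in the factory rather than on the estimators means the estimators stay ordinary objects, at the cost of having to keep the factory around until the search is constructed.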

jnothman commented 2 years ago

Alternative to

gs = GridSearchCV(pipe, grid)

would be

gs = grid.make_search(pipe)

thomasjpfan commented 2 years ago

I like the idea of simplifying how we define search spaces. What would the workflow look like for building GridFactory when a pipeline is already defined? Would one need to rewrite the pipeline with the GridFactory API to define the search space?

jnothman commented 2 years ago

No, there's nothing special about the pipeline constructed. The GridFactory maps each estimator to its local parameter space. When the Pipeline's local space involves searching over alternative steps, then the parameter grid constructed by the factory includes the downstream parameter grids of those alternative steps.
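For reference, this is roughly what the expanded grid looks like in today's syntax for the SelectKBest/PCA example from the issue description: one dict per alternative step, each carrying that step's downstream parameters. The snippet uses only existing scikit-learn API; using "passthrough" as the placeholder step (instead of None) is a choice made for this sketch.

```python
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import ParameterGrid
from sklearn.pipeline import Pipeline

pipe = Pipeline([("reduce", "passthrough"), ("clf", LogisticRegression())])

# One sub-grid per alternative 'reduce' step; the parameters of each
# alternative only appear in the sub-grid where that step is active.
param_grid = [
    {"reduce": [SelectKBest()], "reduce__k": [5, 10, 20], "clf__C": [0.1, 1, 10]},
    {"reduce": [PCA()], "reduce__n_components": [5, 10, 20], "clf__C": [0.1, 1, 10]},
]

n_candidates = len(ParameterGrid(param_grid))
print(n_candidates)  # 18: 3 * 3 candidates for each of the two alternatives
```

This list of dicts can be passed as param_grid to GridSearchCV as it is; the proposal is about generating it rather than writing it by hand.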

searchgrid provides a shorthand to make Pipelines, but it's really only so the user doesn't need to name steps and set them to None by default. Maybe I can push some code later today.

thomasjpfan commented 2 years ago

I think a typical use case is to build up a pipeline first for a ML task. For example:

preprocessor = ColumnTransformer(
    transformers=[
        ("num", StandardScaler(), numeric_features),
        ("cat", OneHotEncoder(), categorical_features),
    ]
)
clf = Pipeline(
    steps=[("preprocessor", preprocessor), ("classifier", LogisticRegression())]
)

Now we want to use GridFactory + GridSearch. With the API proposed, we would need to rewrite the above Pipeline logic with GridFactory and then create a GridSearchCV from it. Something like this:

grid = GridFactory()

scaler = grid.set(StandardScaler(), with_std=[True, False])
encoder = grid.set(OneHotEncoder(), drop=["first", None])
preprocessor = ColumnTransformer([
    ("num", scaler, numeric_features),
    ("cat", encoder, categorical_features)])

lr = grid.set(LogisticRegression(), C=[0.1, 1.0, 10.0])
pipe = grid.set(Pipeline([("preprocessor", preprocessor),
                          ("classifier", lr)]))

gs = GridSearchCV(pipe, grid)

Is this correct?

jnothman commented 2 years ago

Yes, though there's obviously no need for the additional variable names, so this is fine too:

grid = GridFactory()

preprocessor = ColumnTransformer(
    transformers=[
        ("num", grid.set(StandardScaler(), with_std=[True, False]), numeric_features),
        ("cat", grid.set(OneHotEncoder(), drop=["first", None]), categorical_features),
    ]
)
clf = Pipeline(
    steps=[("preprocessor", preprocessor),
           ("classifier", grid.set(LogisticRegression(), C=[0.1, 1.0, 10.0]))]
)

The main point is to avoid having to write out:

param_grid = {
    "preprocessor__num__with_std": [True, False],
    "preprocessor__cat__drop": ["first", None],
    "classifier__C": [0.1, 1.0, 10.0],
}

which only gets messier if:

thomasjpfan commented 2 years ago

With the following pipeline definition:

clf = Pipeline(
    steps=[("preprocessor", preprocessor),
           ("classifier", grid.set(LogisticRegression(), C=[0.1, 1.0, 10.0]))]
)

I suspect grid.set will return LogisticRegression(), so clf.fit would work. Do you think clf.fit(X, y) should work or raise an error?

Edit: I just saw that https://github.com/scikit-learn/scikit-learn/pull/21784 is running in another direction, so maybe we are not going with the GridFactory API.

jnothman commented 2 years ago

yes, why shouldn't fit work?

thomasjpfan commented 2 years ago

It came down to what grid.set(LogisticRegression(), C=[0.1, 1.0, 10.0]) returns. It could be: LogisticRegression() or LogisticRegression(C=0.1). From a UX point of view, clf.fit should just work.

I think grid.set should return LogisticRegression() because that way the API extends more naturally to RandomizedSearchCV. For example, the following returns LogisticRegression() without sampling C:

grid.set(LogisticRegression(), C=uniform(loc=0, scale=4))

Extending this idea to https://github.com/scikit-learn/scikit-learn/pull/21784#issuecomment-979284821 :

LogisticRegression().with_search_space(C=uniform(loc=0, scale=4))

would return LogisticRegression() and no sampling is performed.
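The "no sampling on set" behaviour described above can be sketched as follows. Everything in this snippet is hypothetical: GridFactory follows the proposal in this thread, Est stands in for an estimator, and CountingUniform stands in for a scipy.stats frozen distribution (it only exposes rvs and counts how often it is called). The point is that registering a distribution leaves the estimator's own parameters untouched, and drawing values is left to whatever runs the search.

```python
import random


class CountingUniform:
    """Hypothetical stand-in for a frozen distribution exposing rvs()."""

    def __init__(self, loc, scale):
        self.loc, self.scale, self.calls = loc, scale, 0

    def rvs(self, random_state=None):
        self.calls += 1
        return self.loc + self.scale * random.Random(random_state).random()


class Est:
    """Minimal stand-in for an estimator (illustration only)."""

    def __init__(self, **params):
        self.params = dict(params)

    def get_params(self, deep=True):
        return dict(self.params)


class GridFactory:
    """Hypothetical: records search spaces without touching the estimator."""

    def __init__(self):
        self._spaces = {}  # id(estimator) -> (estimator, {param: space})

    def set(self, estimator, **space):
        self._spaces[id(estimator)] = (estimator, space)
        return estimator  # same object, parameters unchanged

    def space_for(self, estimator):
        return self._spaces[id(estimator)][1]


grid = GridFactory()
dist = CountingUniform(loc=0, scale=4)
lr = grid.set(Est(C=1.0), C=dist)

calls_after_set = dist.calls  # registering the space draws nothing
draw = grid.space_for(lr)["C"].rvs(random_state=0)  # sampling happens later
print(lr.get_params(), calls_after_set, dist.calls)
```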