ray-project / tune-sklearn

A drop-in replacement for Scikit-Learn’s GridSearchCV / RandomizedSearchCV, but with cutting-edge hyperparameter tuning techniques.
https://docs.ray.io/en/master/tune/api_docs/sklearn.html
Apache License 2.0
467 stars, 51 forks

Resuming from checkpoint? #260

Status: Open. jonathanmiller2 opened this issue 1 year ago.

jonathanmiller2 commented 1 year ago

There doesn't seem to be any documentation on how to resume training from the checkpoints written to local_dir. Using the following example provided with tune-sklearn:

"""Example using an sklearn Pipeline with TuneSearchCV.
Example taken and modified from
https://scikit-learn.org/stable/auto_examples/compose/
plot_compare_reduction.html
"""

from tune_sklearn import TuneSearchCV
from sklearn.datasets import load_digits
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from sklearn.decomposition import PCA, NMF
from sklearn.feature_selection import SelectKBest, chi2

pipe = Pipeline([
    # the reduce_dim stage is populated by the param_grid
    ("reduce_dim", "passthrough"),
    ("classify", LinearSVC(dual=False, max_iter=10000))
])

N_FEATURES_OPTIONS = [2, 4, 8]
C_OPTIONS = [1, 10]
param_grid = [
    {
        "reduce_dim": [PCA(iterated_power=7), NMF()],
        "reduce_dim__n_components": N_FEATURES_OPTIONS,
        "classify__C": C_OPTIONS
    },
    {
        "reduce_dim": [SelectKBest(chi2)],
        "reduce_dim__k": N_FEATURES_OPTIONS,
        "classify__C": C_OPTIONS
    },
]

# Ray Tune writes the logs and checkpoints for this run under local_dir
random = TuneSearchCV(pipe, param_grid, search_optimization="random", local_dir="checkpoints")
X, y = load_digits(return_X_y=True)
random.fit(X[:100], y[:100])
print(random.cv_results_)
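The fit() call above leaves a timestamped experiment directory under local_dir (the restore attempt in this issue points at checkpoints/_Trainable_2022-12-28_11-47-41). Rather than hard-coding that name, a small helper could pick the newest run. The function latest_experiment_dir() below is hypothetical, not a tune-sklearn or Ray API, and it assumes only the `_Trainable_<YYYY-MM-DD_HH-MM-SS>` naming seen in this issue, which sorts chronologically as plain text.

```python
from pathlib import Path
from typing import Optional


def latest_experiment_dir(local_dir: str, prefix: str = "_Trainable_") -> Optional[Path]:
    """Return the newest experiment directory under local_dir, or None.

    Hypothetical helper. Assumes each run creates a subdirectory named
    `<prefix><YYYY-MM-DD_HH-MM-SS>`, so sorting by name also sorts by start time.
    """
    root = Path(local_dir)
    if not root.is_dir():
        return None
    # Keep only directories (not files) whose names carry the expected prefix
    candidates = sorted(
        (p for p in root.iterdir() if p.is_dir() and p.name.startswith(prefix)),
        key=lambda p: p.name,
    )
    # The lexicographically last timestamp is the most recent run
    return candidates[-1] if candidates else None
```

For example, latest_experiment_dir("checkpoints") would return the directory to pass to whatever restore or analysis call is used next.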

Trying to load the generated checkpoint with standard Ray Tune methods, such as:

from ray.tune import Tuner
Tuner.restore('checkpoints/_Trainable_2022-12-28_11-47-41')

Yields:

RuntimeError: Could not find Tuner state in restore directory. Did you passthe correct path (including experiment directory?) Got: checkpoints/_Trainable_2022-12-28_11-47-41
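A likely explanation, offered as an assumption rather than confirmed behavior: Tuner.restore() looks for the tuner.pkl state file that Tuner.fit() writes into the experiment directory, while tune-sklearn appears to launch its trials through ray.tune.run(), which leaves experiment_state-*.json files but no tuner.pkl. The hypothetical helper describe_experiment_dir() below (not part of Ray or tune-sklearn) sketches a quick check of which kind of directory a path is, based only on those file-name assumptions.

```python
from pathlib import Path


def describe_experiment_dir(path: str) -> str:
    """Classify an experiment directory by the state files it contains.

    Hypothetical helper. Assumes Tuner.fit() writes `tuner.pkl` (required by
    Tuner.restore()), while tune.run() writes only `experiment_state-*.json`.
    """
    exp = Path(path)
    if not exp.is_dir():
        return "missing"
    # Check tuner.pkl first: a Tuner run may also contain experiment_state files
    if (exp / "tuner.pkl").exists():
        return "tuner"
    if any(exp.glob("experiment_state-*.json")):
        return "tune.run"
    return "unknown"
```

Calling describe_experiment_dir("checkpoints/_Trainable_2022-12-28_11-47-41") and getting "tune.run" would be consistent with the RuntimeError above. In that case ray.tune.ExperimentAnalysis may still be able to read the finished trials' results, but that only inspects results; it does not resume the search.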

What is the intended way of loading checkpoints?