ray-project / tune-sklearn

A drop-in replacement for Scikit-Learn’s GridSearchCV / RandomizedSearchCV -- but with cutting edge hyperparameter tuning techniques.
https://docs.ray.io/en/master/tune/api_docs/sklearn.html
Apache License 2.0

[feature] Save all model checkpoints in TuneSearchCV? #148

Open sungreong opened 3 years ago

sungreong commented 3 years ago

Hi, I have a question.

When using TuneSearchCV, is it possible to save, for each trial, the model checkpoint with the best performance?

richardliaw commented 3 years ago

Can you provide an example of what you mean?

sungreong commented 3 years ago

For example:

import numpy as np
from sklearn.datasets import make_classification
from torch import nn
import torch.nn.functional as F
from skorch import NeuralNetClassifier
from tune_sklearn import TuneGridSearchCV

# Toy binary classification dataset: 1000 samples, 20 features.
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=F.relu):
        super(MyModule, self).__init__()
        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = F.relu(self.dense1(X))
        X = F.softmax(self.output(X), dim=-1)
        return X

net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    lr=0.1,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
)

params = {
    "lr": [0.01, 0.02],
    "module__num_units": [10, 20],
}

gs = TuneGridSearchCV(net, params, scoring="accuracy")
gs.fit(X, y)
print(gs.best_score_, gs.best_params_)

I do not want to save only the best estimator found by gs; I want to save a model for every hyperparameter combination tried in the experiment. For example, over (lr, num_units) I want the best-performing model (and its score) for each of (0.01, 10), (0.01, 20), (0.02, 10), and (0.02, 20).
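To make the request concrete, the result I am after is roughly what this manual loop would produce, written with plain scikit-learn utilities outside tune-sklearn (a rough sketch only: ParameterGrid and cross_val_score stand in for Tune's trial scheduling, and the joblib file names are just illustrative):

from sklearn.base import clone
from sklearn.model_selection import ParameterGrid, cross_val_score
import joblib

results = {}
for combo in ParameterGrid(params):
    # Fresh, unfitted copy of the skorch net with this hyperparameter combination.
    model = clone(net).set_params(**combo)

    # Cross-validated accuracy for this combination (mirrors scoring="accuracy").
    score = cross_val_score(model, X, y, scoring="accuracy", cv=5).mean()

    # Refit on the full data and persist this model alongside its score.
    # skorch estimators are picklable, so joblib.dump works here;
    # model.save_params(f_params=...) would be another option.
    model.fit(X, y)
    path = f"model_lr{combo['lr']}_units{combo['module__num_units']}.joblib"
    joblib.dump(model, path)
    results[path] = score

print(results)

Doing this by hand repeats every fit, so what I am really asking for is a way to have TuneSearchCV/TuneGridSearchCV keep each trial's best checkpoint during the search itself.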

richardliaw commented 3 years ago

Hey @sungreong, we actually removed this feature, but I'm happy to reintroduce support. I'll push a PR soon!