Closed leonardtschora closed 1 year ago
Sorry, I guess this one slipped under the radar.
I've just skimmed your comment, but here's a quick reply, which hopefully addresses your point:
In general, because one is resampling to get performance estimates for each model (set of hyper-parameters), you can't make this "intelligent", except in the special case that the resampling is a `Holdout` with no randomisation, e.g. `resampling=Holdout()`, or, as in your case, that the resampling consists of a single train/test pair. However, `update` for the `Resampler` model wrapper (a private object) is only overloaded for `Holdout`, and not for your special case, and hence is slow. I guess one could overload it for your case as well, but this is probably not a big use case. A PR is welcome.
Try your benchmarks with `resampling=Holdout()` and see if you get an improvement.
Does that make sense?
Hi, thanks for your reply.
I have tried the `Holdout` resampling and it yields the expected result: 80 ms, showing that it performs intelligent refitting. Here is the code:
```julia
self_tuning_forest_model_holdout = TunedModel(model=forest_model,
                                              tuning=Grid(shuffle=false),
                                              resampling=Holdout(; fraction_train=0.99, shuffle=false),
                                              range=r,
                                              measure=rms);

self_tuning_forest_holdout = machine(self_tuning_forest_model_holdout, X, y);
fit!(self_tuning_forest_holdout, verbosity=1)

@btime begin
    self_tuning_forest_holdout = machine(self_tuning_forest_model_holdout, X, y);
    fit!(self_tuning_forest_holdout, verbosity=0)
end
```
I think the strategy here would be to iterate first over the different (train, test) datasets and then over the hyper-parameters (this is an example from one of my grid searches):
```julia
# Renamed the variable so it doesn't shadow the `train_test_pairs` function,
# and added the row argument the MLJBase signature expects.
pairs = train_test_pairs(my_sampler, 1:nrows(X), X, y)
Threads.@threads for k in 1:n_cv
    (train_indices, val_indices) = pairs[k]
    train_set = selectrows(X, train_indices)
    train_labels = selectrows(y, train_indices)
    val_set = selectrows(X, val_indices)
    val_labels = selectrows(y, val_indices)
    model = MyModel()
    mach = machine(model, train_set, train_labels)
    # for every hyper-parameter configuration to try:
    #     mutate the model's attributes
    #     update the machine `mach` (fit!(mach))
    #     compute the error on the validation set
    #     store the error
end
```
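As a self-contained illustration of the loop sketched above, here is a pure-Julia toy (a hypothetical `ToyEnsemble` and `ToyMachine` stand in for `MyModel` and MLJ's machine; this is not MLJ's actual API). Each "estimator" is just a constant fit on a bootstrap sample, so the warm-restart step only fits the estimators that are missing:

```julia
using Random, Statistics

# Hypothetical stand-ins for MyModel and MLJ's machine -- a toy, not MLJ's API.
mutable struct ToyEnsemble
    n_estimators::Int
end

mutable struct ToyMachine
    model::ToyEnsemble
    y::Vector{Float64}
    predictors::Vector{Float64}   # one fitted "estimator" (a constant) per entry
end

toy_machine(model, y) = ToyMachine(model, y, Float64[])

# Only fit the *missing* estimators -- the cheap `update` step.
function update!(mach::ToyMachine, rng)
    while length(mach.predictors) < mach.model.n_estimators
        idx = rand(rng, eachindex(mach.y), length(mach.y))   # bootstrap sample
        push!(mach.predictors, mean(mach.y[idx]))
    end
    return mach
end

toy_predict(mach, n_points) = fill(mean(mach.predictors), n_points)
rms_err(yhat, y) = sqrt(mean((yhat .- y) .^ 2))

rng = MersenneTwister(42)
y_train = randn(rng, 50) .+ 1.0
y_val = randn(rng, 20) .+ 1.0

mach = toy_machine(ToyEnsemble(0), y_train)
errors = Float64[]
for n in 1:10                         # the hyper-parameter grid
    mach.model.n_estimators = n       # mutate the model's attributes
    update!(mach, rng)                # update the machine: fits only 1 new estimator
    push!(errors, rms_err(toy_predict(mach, length(y_val)), y_val))
end
```

Each grid iteration costs one new estimator fit instead of `n`, which is exactly the speed-up being discussed.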
Then, all you have to do is compute the average error across all datasets. It worked well for my use case. Let me know if you have any updates on this subject; I will try to set aside some time to make a proper implementation of this.
Hi everyone,
While benchmarking some toy grid searches, I obtained odd results, and it seemed to me that performing a grid search using a `TunedModel` is slower than it should be. The idea is to run a grid search over a model that implements the `update` method, and to avoid re-fitting models from scratch for each sampled hyper-parameter set; more precisely, by arranging the grid search so that it only changes one hyper-parameter per iteration.

Here is some sample code on a toy problem, using `EnsembleModel` and `DecisionTree`. The idea is to play with the number of estimators of the ensemble and find the optimal one. While a naïve approach would be to restart training from scratch for each new number of estimators, a smarter approach would be to start at the lowest number and add one estimator at each iteration. The updating cost of the ensemble model is then very low (only one new estimator to fit), and we expect the grid search to be much faster. Solving this problem using the MLJ interface:
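The original snippet did not survive extraction; a minimal reconstruction, assuming MLJ's standard tuning API and synthetic `make_regression` data (the names `forest_model` and `r` match those used later in the thread), might look like:

```julia
using MLJ
DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0

X, y = make_regression(200, 5)   # synthetic stand-in for the issue's data

forest_model = EnsembleModel(model=DecisionTreeRegressor())
r = range(forest_model, :n, lower=1, upper=20)   # :n is the ensemble's estimator count

self_tuning_forest_model = TunedModel(model=forest_model,
                                      tuning=Grid(shuffle=false),
                                      resampling=CV(nfolds=3),
                                      range=r,
                                      measure=rms)

self_tuning_forest = machine(self_tuning_forest_model, X, y)
fit!(self_tuning_forest, verbosity=0)
```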
Then, I implemented two manual grid searches. The first is not intelligent and restarts from scratch; the second only mutates the `n_estimator` field of the `EnsembleModel` and updates the associated machine.

Given those results, it seems to me that the grid search using a `TunedModel` is just performing a naïve search, retraining every new model from scratch instead of re-fitting it. We can also see that the speed of the grid search can be improved by a factor of 10 on this toy example.

I started delving into the implementation details and found that the problem was not coming from the `Grid` implementation. The `Grid` creates a list of models to train by cloning and mutating them, but if we mutate the model field of a machine and set it to a new one, the machine should still update itself, as in this example:

Then I started looking at the `TunedModel` code, but things are becoming much more complicated, and I'm afraid I would not be able to understand it alone.

As always, thanks for the time and support you provide me.
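(The inline example referenced above, showing that mutating a machine's model triggers `update` rather than a cold refit, appears to have been lost in extraction. A minimal sketch of what it presumably showed, assuming MLJ's public API and synthetic `make_regression` data in place of the real dataset:)

```julia
using MLJ
DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0

X, y = make_regression(100, 3)   # stand-in data; the issue's dataset isn't shown

forest = EnsembleModel(model=DecisionTreeRegressor(), n=10)
mach = machine(forest, X, y)
fit!(mach, verbosity=0)          # fits 10 atomic models from scratch

forest.n = 12                    # mutate the model held by the machine
fit!(mach, verbosity=0)          # `update` is called: only 2 new atoms are fitted
```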