hi @thomasj02, I think your point is discussed in #1447.
I think #1447 is a different issue. That bug deals with problems around pruning jittery loss functions, whereas the issue I'm describing occurs even with perfectly smooth loss functions. Let me try to illustrate with an example. Suppose you're using the MedianPruner with n_startup_trials set to 5, and you have 5 trials completed. Due to early stopping, 4 of the trials end at epoch 10, and one trial (trial 5) ends at epoch 15.
Now you're running trial 6, and you're on epoch 11. Suppose trial 6 is just a tiny bit worse than trial 5 at epoch 11. Is trial 6 worse than the median at epoch 11? Yes, so it gets pruned. But this is overly aggressive, because trial 5 is the only completed trial that reported a value at epoch 11, so trial 6 is pruned for being worse than just one other run at one particular epoch.
What I'm suggesting is that when using the percentile pruner or median pruner, there should be at least n_startup_trials samples at a particular epoch before pruning is considered.
Thank you for your explanations. Could you give reproducible code for your bug?
In my understanding, you can avoid the issue by not performing early stopping for the first n_startup_trials trials.
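For illustration, a minimal sketch of that workaround (the objective, its loss curve, and the patience value are made up for this example): early stopping is simply skipped while trial.number is still below n_startup_trials, so the startup trials all run to the maximum number of epochs.

import optuna

N_STARTUP_TRIALS = 5
MAX_EPOCHS = 100


def objective(trial):
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
    best_loss = float("inf")
    epochs_without_improvement = 0
    loss = 5.0
    for epoch in range(MAX_EPOCHS):
        loss *= 1.0 - lr  # toy stand-in for one epoch of training
        trial.report(loss, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
        # Plain early stopping with a patience of 3 epochs ...
        if loss < best_loss - 1e-3:
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        # ... but only allow it once the startup trials are done, so that the
        # first N_STARTUP_TRIALS trials report values at every epoch
        # (assuming trials run sequentially).
        if trial.number >= N_STARTUP_TRIALS and epochs_without_improvement >= 3:
            break
    return loss


study = optuna.create_study(
    direction="minimize",
    pruner=optuna.pruners.MedianPruner(n_startup_trials=N_STARTUP_TRIALS),
)
study.optimize(objective, n_trials=20)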
I've put a repro here: https://gist.github.com/thomasj02/b2496c881aaa3731c6b5aaf7e5bc76fb
You're correct that the issue could initially be avoided by not performing early stopping for n_startup_trials, if all trials have the same maximum number of epochs. But a common use case is training a network until convergence (i.e., with no defined maximum number of epochs).
Edit: Also, the maximum number of epochs could just be very large, so running every startup trial to the end would take an extremely long time, and the researcher expects that most trials will terminate well before the max epochs due to early stopping.
Thank you for sharing your code and the clear explanation. I understand your feature request.
It can be implemented by adding a new argument, say n_samples_to_prune, to PercentilePruner; this covers both pruner classes, since MedianPruner inherits from PercentilePruner. The pruner's prune method would then return False at any step where the number of evaluated trials is less than n_samples_to_prune.
Modifying the logic of n_startup_trials would break backwards compatibility, so I think adding a new parameter avoids that issue.
Additional context: Ray Tune has such an argument, min_samples_required. See https://docs.ray.io/en/master/tune/api_docs/schedulers.html#median-stopping-rule-tune-schedulers-medianstoppingrule.
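To make the proposed behavior concrete, here is a rough sketch (not the actual branch implementation; the wrapper class name and the n_samples_to_prune argument are placeholders) of a pruner that only defers to an underlying pruner once enough completed trials have reported a value at the current step:

import optuna
from optuna.trial import TrialState


class MinSamplesPruner(optuna.pruners.BasePruner):
    # Wrap an existing pruner and suppress pruning at any step where fewer
    # than n_samples_to_prune completed trials have reported a value.
    def __init__(self, base_pruner: optuna.pruners.BasePruner, n_samples_to_prune: int):
        self._base_pruner = base_pruner
        self._n_samples_to_prune = n_samples_to_prune

    def prune(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> bool:
        step = trial.last_step
        if step is None:
            return False
        # Count completed trials that reported an intermediate value at this step.
        n_samples = sum(
            1
            for t in study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
            if step in t.intermediate_values
        )
        if n_samples < self._n_samples_to_prune:
            return False
        return self._base_pruner.prune(study, trial)


pruner = MinSamplesPruner(optuna.pruners.MedianPruner(n_startup_trials=5), n_samples_to_prune=5)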
Dear @HideakiImamura @hvy @toshihikoyanase, what do you think about this feature request?
@thomasj02 I implemented the feature as in https://github.com/nzw0301/optuna/tree/add-n-min-samples (more specifically, https://github.com/nzw0301/optuna/commit/0f94b1c9dbfdc43793e19d01ec5576638d176dc4). If I missed something for the feature request, please let me know!
import optuna
from typing import List


class ObjectiveStub:
    def __init__(self, trial_losses: List[List[float]]):
        self.trial_losses = trial_losses
        self.trial_num = 0

    def __call__(self, trial: optuna.Trial):
        for epoch, loss in enumerate(self.trial_losses[self.trial_num]):
            trial.report(loss, epoch)
            should_prune = trial.should_prune()
            print(f"Trial: {self.trial_num} epoch: {epoch} loss: {loss} should_prune: {should_prune}")
            if should_prune:
                print(f"Pruning triggered at trial {self.trial_num} epoch {epoch} with loss {loss}")
                self.trial_num += 1  # advance to the next loss list even when pruned
                raise optuna.TrialPruned()

        retval = self.trial_losses[self.trial_num][-1]
        self.trial_num += 1
        return retval


objective = ObjectiveStub(trial_losses=[
    [5, 4, 3],
    [5, 4, 3, 2],
    [5, 4, 3, 2.1],
    [5, 4, 3, 2.2, 1.0],  # add a new trial for this example
])

# `n_min_trials` is the argument that we need.
# `n_min_trials=1` behaves like the current optuna.
pruner = optuna.pruners.MedianPruner(n_startup_trials=2, n_min_trials=2)
study = optuna.create_study(direction="minimize", pruner=pruner)
study.optimize(objective, n_trials=4)
[I 2021-03-12 00:17:52,156] A new study created in memory with name: no-name-3faa476c-15cc-46a8-8928-94d6597b0a37
[I 2021-03-12 00:17:52,159] Trial 0 finished with value: 3.0 and parameters: {}. Best is trial 0 with value: 3.0.
[I 2021-03-12 00:17:52,160] Trial 1 finished with value: 2.0 and parameters: {}. Best is trial 1 with value: 2.0.
[I 2021-03-12 00:17:52,162] Trial 2 finished with value: 2.1 and parameters: {}. Best is trial 1 with value: 2.0.
[I 2021-03-12 00:17:52,164] Trial 3 pruned.
Trial: 0 epoch: 0 loss: 5 should_prune: False
Trial: 0 epoch: 1 loss: 4 should_prune: False
Trial: 0 epoch: 2 loss: 3 should_prune: False
Trial: 1 epoch: 0 loss: 5 should_prune: False
Trial: 1 epoch: 1 loss: 4 should_prune: False
Trial: 1 epoch: 2 loss: 3 should_prune: False
Trial: 1 epoch: 3 loss: 2 should_prune: False
Trial: 2 epoch: 0 loss: 5 should_prune: False
Trial: 2 epoch: 1 loss: 4 should_prune: False
Trial: 2 epoch: 2 loss: 3 should_prune: False
Trial: 2 epoch: 3 loss: 2.1 should_prune: False  # would have been pruned here with the current optuna (n_min_trials=1)
Trial: 3 epoch: 0 loss: 5 should_prune: False
Trial: 3 epoch: 1 loss: 4 should_prune: False
Trial: 3 epoch: 2 loss: 3 should_prune: False
Trial: 3 epoch: 3 loss: 2.2 should_prune: True
Pruning triggered at trial 3 epoch 3 with loss 2.2
I looked at the code and the example and both look great. Thanks for the quick implementation!
What is the release cycle like for Optuna, i.e. when do you expect this to be available in a release?
@thomasj02 Thank you for your quick check and feedback! According to this GitHub page, Optuna releases a new version about once a month. However, the implementation is only in my branch and I haven't sent a pull request to Optuna yet because I'm not sure this feature will be accepted by the dev team. Thus it can take more time...
Hi @thomasj02, the feature has been merged into the main branch of optuna. The next stable release will be within one month!
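Once that release is out, the option can be passed straight to the pruner; a minimal usage sketch, assuming the merged argument keeps the name n_min_trials used in the example above:

import optuna

# Require at least 5 completed trials to have reported a value at a step
# before the median comparison can prune at that step.
pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_min_trials=5)
study = optuna.create_study(direction="minimize", pruner=pruner)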
Let me close this issue. Feel free to reopen it if you still have the problem.
Thank you for this feature. I use early stopping when training and wasn't aware of the default behavior.
On a related note, if using train or validation loss, it seems like median pruning would favor models that converge more quickly, and results would be confounded by differences in convergence rate. That is, a model that converges more slowly may have a higher validation loss at an earlier step compared to other models but may ultimately produce a lower objective value. In other words, median pruning seems biased toward models that learn faster. Is that correct?
I noticed this effect when including a graph normalization step in my GNN. Graph normalization is used to speed up learning and potentially produce better models. However, the models without graph normalization were pruned more frequently. To try to fix this, I've added patience and switched to validation loss, as opposed to no patience and training loss. However, I think the issue remains, especially when changing parameters that affect convergence. I may be off base though; thoughts?
I suppose so, since the median pruning rule implicitly assumes what you described. The n_min_trials option might not resolve the problem perfectly. However, this sounds like an inherent property of the median pruning rule. Do you know of any implementation or algorithm that resolves this problem?
I'm not familiar with any, but I'll try to experiment with a check that looks at the step-to-step change in model weights and loss.
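For reference, one partial mitigation already available in Optuna is the n_warmup_steps argument, which disables pruning during the first steps of every trial so that slow-starting models get a grace period before any median comparison. A minimal sketch with a toy objective (the loss curve below is made up; whether a warmup window compensates enough for real convergence-rate differences depends on the problem):

import optuna


def objective(trial):
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
    loss = 5.0
    for epoch in range(50):
        loss *= 1.0 - lr  # toy stand-in for one epoch of training
        trial.report(loss, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
    return loss


pruner = optuna.pruners.MedianPruner(
    n_startup_trials=5,   # no pruning until 5 trials have completed
    n_warmup_steps=10,    # never prune a trial before step 10
    n_min_trials=5,       # require 5 reported values at a step (the feature from this thread)
)
study = optuna.create_study(direction="minimize", pruner=pruner)
study.optimize(objective, n_trials=30)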
I'm using the MedianPruner (although I think the PercentilePruner has the same issue). When there are more than n_startup_trials trials but only one of them is long-running (e.g., most trials end at epoch 10, but one trial ends at epoch 100), the MedianPruner will trigger pruning at epoch 11 for a new trial if the new trial is worse than the long-running trial at that epoch. This makes the MedianPruner overly aggressive in terms of triggering pruning.
It would be better (and more intuitive) if the pruner only triggered pruning if there were at least n_startup_trials samples at the epoch being evaluated (or introduce a new parameter / rename n_startup_trials if you think that's more appropriate).
Environment