optuna / optuna

A hyperparameter optimization framework
https://optuna.org
MIT License

Pruner is too aggressive when trials have different numbers of steps #2441

Closed: thomasj02 closed this issue 3 years ago

thomasj02 commented 3 years ago

I'm using the MedianPruner (although I think the PercentilePruner has the same issue). When there are more than n_startup_trials completed trials but only one of them is long-running (e.g., most trials end at epoch 10, but one trial ends at epoch 100), the MedianPruner will prune a new trial at epoch 11 if it is even slightly worse than that single long-running trial at that epoch. This makes the MedianPruner overly aggressive about triggering pruning.

It would be better (and more intuitive) if the pruner only triggered pruning if there were at least n_startup_trials samples at the epoch being evaluated (or introduce a new parameter / rename n_startup_trials if you think that's more appropriate).

Environment

nzw0301 commented 3 years ago

Hi @thomasj02, I think your point is discussed in #1447.

thomasj02 commented 3 years ago

I think #1447 is a different issue. That bug deals with problems around pruning jittery loss functions, whereas the issue I'm describing occurs even with perfectly smooth loss functions. Let me illustrate with an example. Suppose you're using the MedianPruner with n_startup_trials set to 5, and you have 5 trials completed. Due to early stopping, 4 of the trials end at epoch 10, and one trial (trial 5) ends at epoch 15.

Now you're running trial 6, and you're on epoch 11. Suppose trial 6 is just a tiny bit worse than trial 5 at epoch 11. At epoch 11, is trial 6 worse than the median? Yes, so it should be pruned. But this is super aggressive, because it's worse than just one other run (trial 5) at one particular epoch (epoch 11).

What I'm suggesting is that when using the percentile pruner or median pruner, there should be at least n_startup_trials samples at a particular epoch before pruning is considered.
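To make the arithmetic concrete, here is a toy illustration of that comparison (the loss values below are made up):

import statistics

# At epoch 11 only the long-running trial (trial 5) has reported a value,
# so the "median" over trials at that step is just that single value.
values_at_epoch_11 = [0.42]
median = statistics.median(values_at_epoch_11)

# A new trial that is only marginally worse than that one reference value
# is already above the median, so the MedianPruner would prune it here.
new_trial_value_at_epoch_11 = 0.43
print(new_trial_value_at_epoch_11 > median)  # True -> pruned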

nzw0301 commented 3 years ago

Thank you for the explanation. Could you share reproducible code for the bug?

nzw0301 commented 3 years ago

In my understanding, you can avoid the issue by not performing early stopping for the first n_startup_trials trials.
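For example, here is a rough sketch of that workaround; the objective and its early-stopping rule are illustrative stand-ins for real training code, and it assumes an Optuna version whose Study.get_trials accepts a states filter:

import optuna
from optuna.trial import TrialState

N_STARTUP_TRIALS = 5
MAX_EPOCHS = 100

def objective(trial):
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)

    # Disable early stopping until N_STARTUP_TRIALS trials have completed,
    # so the first trials all report values over the full range of epochs.
    completed = trial.study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
    allow_early_stopping = len(completed) >= N_STARTUP_TRIALS

    best = float("inf")
    for epoch in range(MAX_EPOCHS):
        loss = max(lr, 1.0 / (epoch + 1))  # stand-in loss: improves, then plateaus at lr
        trial.report(loss, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
        if allow_early_stopping and loss >= best:  # toy early-stopping rule
            break
        best = min(best, loss)
    return best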

thomasj02 commented 3 years ago

I've put a repro here: https://gist.github.com/thomasj02/b2496c881aaa3731c6b5aaf7e5bc76fb

You're correct that the issue could initially be avoided by not performing early stopping for the first n_startup_trials trials, if all trials have the same maximum number of epochs. But a common use case is training a network until convergence (i.e., with no defined maximum number of epochs).

Edit: Also the maximum number of epochs could just be very large and take an extremely long time, and the researcher expects that most trials will terminate well before the max epochs due to early stopping.

nzw0301 commented 3 years ago

Thank you for sharing your code and the clear explanation. I now understand your feature request.

That could be implemented by adding a new argument, e.g. n_samples_to_prune, to PercentilePruner (MedianPruner inherits from PercentilePruner, so both pruners would get it). The pruner's prune method would then return False at any step where fewer than n_samples_to_prune trials have reported a value.

Modifying the logic of n_startup_trials would break backward compatibility, so I think adding a new parameter avoids that issue.
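Roughly, the check could look like the following sketch (this is not Optuna's actual implementation; MinSamplesPruner and its arguments are illustrative names, and it wraps an existing pruner rather than modifying PercentilePruner itself):

import optuna
from optuna.trial import TrialState

class MinSamplesPruner(optuna.pruners.BasePruner):
    """Defer to a wrapped pruner only once enough completed trials have
    reported an intermediate value at the current step."""

    def __init__(self, wrapped_pruner, n_min_trials=5):
        self._wrapped_pruner = wrapped_pruner
        self._n_min_trials = n_min_trials

    def prune(self, study, trial):
        step = trial.last_step
        if step is None:
            return False
        completed = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
        n_at_step = sum(step in t.intermediate_values for t in completed)
        if n_at_step < self._n_min_trials:
            return False  # too few reference points at this step; never prune
        return self._wrapped_pruner.prune(study, trial)

# Usage sketch:
# pruner = MinSamplesPruner(optuna.pruners.MedianPruner(n_startup_trials=5), n_min_trials=5)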

nzw0301 commented 3 years ago

Additional context: Ray Tune has such an argument, min_samples_required. See https://docs.ray.io/en/master/tune/api_docs/schedulers.html#median-stopping-rule-tune-schedulers-medianstoppingrule.

nzw0301 commented 3 years ago

Dear @HideakiImamura @hvy @toshihikoyanase, what do you think about this feature request?

nzw0301 commented 3 years ago

@thomasj02 I implemented the feature as in https://github.com/nzw0301/optuna/tree/add-n-min-samples (more specifically, https://github.com/nzw0301/optuna/commit/0f94b1c9dbfdc43793e19d01ec5576638d176dc4). If I missed something for the feature request, please let me know!

Example

import optuna
from typing import List

class ObjectiveStub:
  """Objective that replays predefined loss curves, one list per trial."""

  def __init__(self, trial_losses: List[List[float]]):
    self.trial_losses = trial_losses
    self.trial_num = 0

  def __call__(self, trial: optuna.Trial):
    for epoch, loss in enumerate(self.trial_losses[self.trial_num]):
      trial.report(loss, epoch)
      should_prune = trial.should_prune()
      print(f"Trial: {self.trial_num} epoch: {epoch} loss: {loss} should_prune: {should_prune}")
      if should_prune:
        print(f"Pruning triggered at trial {self.trial_num} epoch {epoch} with loss {loss}")
        raise optuna.TrialPruned()

    retval = self.trial_losses[self.trial_num][-1]
    self.trial_num += 1
    return retval

objective = ObjectiveStub(trial_losses=[
    [5, 4, 3],
    [5, 4, 3, 2],
    [5, 4, 3, 2.1],
    [5, 4, 3, 2.2, 1.0],  # add a new trial for this example
])

# `n_min_trials` is an argument that we need.
# `n_min_trials=1` behaves like the current optuna.
pruner = optuna.pruners.MedianPruner(n_startup_trials=2, n_min_trials=2)
study = optuna.create_study(
        direction="minimize",
        pruner=pruner)

study.optimize(objective, n_trials=4)

Output

[I 2021-03-12 00:17:52,156] A new study created in memory with name: no-name-3faa476c-15cc-46a8-8928-94d6597b0a37
[I 2021-03-12 00:17:52,159] Trial 0 finished with value: 3.0 and parameters: {}. Best is trial 0 with value: 3.0.
[I 2021-03-12 00:17:52,160] Trial 1 finished with value: 2.0 and parameters: {}. Best is trial 1 with value: 2.0.
[I 2021-03-12 00:17:52,162] Trial 2 finished with value: 2.1 and parameters: {}. Best is trial 1 with value: 2.0.
[I 2021-03-12 00:17:52,164] Trial 3 pruned. 
Trial: 0 epoch: 0 loss: 5 should_prune: False
Trial: 0 epoch: 1 loss: 4 should_prune: False
Trial: 0 epoch: 2 loss: 3 should_prune: False
Trial: 1 epoch: 0 loss: 5 should_prune: False
Trial: 1 epoch: 1 loss: 4 should_prune: False
Trial: 1 epoch: 2 loss: 3 should_prune: False
Trial: 1 epoch: 3 loss: 2 should_prune: False
Trial: 2 epoch: 0 loss: 5 should_prune: False
Trial: 2 epoch: 1 loss: 4 should_prune: False
Trial: 2 epoch: 2 loss: 3 should_prune: False
Trial: 2 epoch: 3 loss: 2.1 should_prune: False  # would be pruned with current Optuna's behavior (n_min_trials=1)
Trial: 3 epoch: 0 loss: 5 should_prune: False
Trial: 3 epoch: 1 loss: 4 should_prune: False
Trial: 3 epoch: 2 loss: 3 should_prune: False
Trial: 3 epoch: 3 loss: 2.2 should_prune: True
Pruning triggered at trial 3 epoch 3 with loss 2.2
thomasj02 commented 3 years ago

I looked at the code and the example and both look great. Thanks for the quick implementation!

What is the release cycle like for Optuna, i.e. when do you expect this to be available in a release?

nzw0301 commented 3 years ago

@thomasj02 Thank you for your quick check and feedback! According to this GitHub page, Optuna releases a new version about once a month. However, the implementation is only in my branch and I haven't sent a pull request to Optuna yet, because I'm not sure this feature will be merged by the dev team, so it may take more time...

github-actions[bot] commented 3 years ago

This issue has not seen any recent activity.

nzw0301 commented 3 years ago

Hi @thomasj02, the feature has been merged into the main branch of Optuna. The next stable release will be out within one month!

github-actions[bot] commented 3 years ago

This issue has not seen any recent activity.

nzw0301 commented 3 years ago

Let me close this issue. Feel free to reopen it if you still have the problem.

bgeier commented 2 years ago

Thank you for this feature. I use early stopping when training and didn't realize this was the default behavior.

On a related note, if using train or validation loss, it seems like median pruning would favor models that converge quickly, and results would be confounded by differences in convergence rate. That is, a model that converges more slowly may have a higher validation loss at an earlier step than other models, yet ultimately produce a lower objective value. In other words, median pruning seems biased toward identifying models that learn faster. Is that correct?

I noticed this effect when including a graph normalization step in my GNN. Graph normalization is used to speed up learning and potentially produce better models, yet the models without graph normalization were pruned more frequently. To try to fix this I've added patience and switched to validation loss, as opposed to no patience and training loss. However, I think the issue remains, especially when changing parameters that affect convergence. I may be off base though; thoughts?
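As a rough sketch of that patience-based setup (parameter values are placeholders, not recommendations, and both PatientPruner and the n_min_trials argument only exist in recent Optuna releases):

import optuna

base = optuna.pruners.MedianPruner(
    n_startup_trials=5,   # never prune until 5 trials have completed
    n_warmup_steps=20,    # give slower-converging models 20 steps before pruning is considered
    n_min_trials=5,       # the feature from this issue: require 5 reports at a given step
)
# PatientPruner suppresses pruning while the reported value is still improving
# by at least min_delta within the last `patience` steps, and consults the
# wrapped pruner otherwise.
pruner = optuna.pruners.PatientPruner(base, patience=10, min_delta=0.0)
study = optuna.create_study(direction="minimize", pruner=pruner)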

nzw0301 commented 2 years ago

I suppose so, since the median pruning rule implicitly makes the assumption you describe. The n_min_trials option might not resolve the problem completely; this sounds like an inherent property of the median pruning rule. Do you know of any implementations or algorithms that address this problem?

bgeier commented 2 years ago

I'm not familiar with any, but I'll try experimenting with a check that looks at the step-to-step change in model weights and loss.