mlr-org / mlr3mbo

Flexible Bayesian Optimization in R
https://mlr3mbo.mlr-org.com

proposal for hook for loop function that allows plotting #139

Closed: larskotthoff closed this 9 months ago

larskotthoff commented 10 months ago

Initial proposal for some plotting functionality. Instead of providing this in the package itself, allow the user to specify a hook function in the MBO loop and give an example in the documentation. This avoids complex code that has to deal with all kinds of corner cases when plotting and allows the user to easily customize what they plot and how.
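To make the idea concrete, here is a purely illustrative sketch (not the code in this PR; the hook signature, argument names, and the args-based usage are assumptions) of what a user-supplied hook could look like:

# Hypothetical user-supplied hook: the loop function would call it once per
# iteration with the current state (signature and argument names are assumed).
log_hook = function(instance, surrogate, acq_function) {
  best = instance$archive$best()
  cat(sprintf("evals: %i, best y so far: %.4f\n", instance$archive$n_evals, best$y))
}

# Hypothetical usage, assuming the loop function forwards a hook argument:
# opt("mbo", loop_function = bayesopt_ego, ..., args = list(hook = log_hook))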

This is really just meant as an initial, not yet perfect, idea. Feedback welcome.

Would close #138.

sumny commented 9 months ago

Hi @larskotthoff and thanks for setting up this PR!

Since bbotk and mlr3tuning now also support callbacks, do you think we gain anything on top by providing an explicit hook function/callback in the loop function / optimizer here in mlr3mbo?

For example, currently it is already possible to write a callback like the following:

library(bbotk)
library(data.table)
library(ggplot2)
library(cowplot)  # for plot_grid()

callback_plot = callback_optimization("plot",
  on_optimizer_after_eval = function(callback, context) {
    instance = context$instance
    # only plot after the initial design, which is 4 points in the example below
    if (instance$archive$n_evals > 4) {
      objective = instance$objective
      surrogate = context$optimizer$surrogate
      acq_function = context$optimizer$acq_function
      # the most recently evaluated point
      xdt = instance$archive$data[instance$archive$n_evals, ]

      # evaluate objective, surrogate, and acquisition function on a dense grid
      plot_data = data.table(x = seq(objective$domain$lower, objective$domain$upper, length.out = 1001L))
      plot_data[, y := objective$eval_dt(plot_data[, "x", with = FALSE])]
      plot_data[, acq_ei := acq_function$eval_dt(plot_data[, "x", with = FALSE])]
      pred = surrogate$predict(plot_data)
      plot_data[, pred_mean := pred$mean]
      plot_data[, pred_se := pred$se]

      # top panel: objective (solid), surrogate mean (dashed) with +/- 1 se ribbon,
      # newest point in red, previously evaluated points in dark green
      p1 = ggplot(plot_data, aes(x = x, y = y)) +
        geom_ribbon(aes(ymin = pred_mean - pred_se, ymax = pred_mean + pred_se), fill = "lightgray") +
        geom_line() +
        geom_line(aes(y = pred_mean), linetype = "dashed") +
        geom_point(aes(x = x, y = surrogate$predict(data.table(x = x))$mean), data = xdt, color = "red") +
        geom_point(aes(x = x, y = y), data = instance$archive$data[- instance$archive$n_evals, ], color = "darkgreen") +
        theme_minimal() +
        theme(axis.title.x = element_blank(), axis.text.x = element_blank(), axis.ticks.x = element_blank())

      # bottom panel: expected improvement with the newest point in red
      p2 = ggplot(plot_data, aes(x = x, y = acq_ei)) +
        geom_line() +
        geom_point(aes(x = x, y = acq_ei), data = xdt, color = "red") +
        theme_minimal()

      p = plot_grid(p1, p2, ncol = 1, align = "v")
      ggsave(paste0("/tmp/mbo-", instance$archive$n_evals, ".pdf"), plot = p, width = 10, height = 5)
    }
  }
)

(similarly one could write a callback_tuning)
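For completeness, a minimal sketch of what such a tuning callback could look like, assuming mlr3tuning's callback_tuning() with the same on_optimizer_after_eval stage and a context that exposes the tuning instance:

library(mlr3tuning)

# Sketch of a tuning callback analogous to callback_plot above; it only logs
# progress instead of plotting.
callback_log = callback_tuning("log_progress",
  on_optimizer_after_eval = function(callback, context) {
    archive = context$instance$archive
    cat(sprintf("evaluated %i configurations so far\n", archive$n_evals))
  }
)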

We can then trigger this callback, for example, before or after the evaluation of a point/batch. Here we trigger it on_optimizer_after_eval (see the argument above):

library(paradox)
library(mlr3learners)  # provides the Gaussian process learner used by default_gp()
library(mlr3mbo)

# objective: f(x) = 2x * sin(14x) on [0, 1], to be minimized
obfun = ObjectiveRFun$new(
  fun = function(xs) list(y = 2 * xs$x * sin(14 * xs$x)),
  domain = ps(x = p_dbl(lower = 0, upper = 1)),
  codomain = ps(y = p_dbl(tags = "minimize")))

instance = OptimInstanceSingleCrit$new(
  objective = obfun,
  terminator = trm("evals", n_evals = 10),
  callbacks = list(callback_plot))

# Gaussian process surrogate, expected improvement, and a global acquisition optimizer
surrogate = srlrn(default_gp())
acqfun = acqf("ei")
acqopt = acqo(
  optimizer = opt("nloptr", algorithm = "NLOPT_GN_ORIG_DIRECT"),
  terminator = trm("stagnation", iters = 100, threshold = 1e-5))

optimizer = opt("mbo",
  loop_function = bayesopt_ego,
  surrogate = surrogate,
  acq_function = acqfun,
  acq_optimizer = acqopt)

optimizer$optimize(instance)

This then mimics the behavior of your hook function/callback above, i.e., we get the desired sequence of plots:

[Figure: per-iteration plots showing the objective, surrogate mean with standard error, evaluated points, and the expected improvement acquisition function]

I guess I should start writing a section regarding callbacks and mlr3mbo in the vignette soon and maybe a gallery post :)

larskotthoff commented 9 months ago

Thanks, I didn't know that you can have callbacks!

The only case in which this wouldn't work as well is an error condition (e.g. random interleaving), where the optimizer callback would "skip" an iteration. I don't think this is really an issue, in particular given that the callback approach is much easier to implement and maintain.
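For reference, random interleaving can also be requested deliberately in bayesopt_ego via its random_interleave_iter argument (a sketch reusing surrogate, acqfun, and acqopt from the setup above); in such an iteration the evaluated candidate is sampled at random rather than proposed via the surrogate and acquisition function, which is the kind of iteration a plotting callback would have to treat specially:

# Every second candidate (after the initial design) is sampled uniformly at
# random instead of being proposed via the acquisition function.
optimizer_interleave = opt("mbo",
  loop_function = bayesopt_ego,
  surrogate = surrogate,
  acq_function = acqfun,
  acq_optimizer = acqopt,
  args = list(random_interleave_iter = 2L))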

I'll close here and look forward to the improved documentation showing an example :)