patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

BestSoFarMinimiser behavior #33

Open ColCarroll opened 9 months ago

ColCarroll commented 9 months ago

Not sure if this is a bug, but BestSoFarMinimiser appears not to check the final step of the wrapped solver:

import optimistix

solver = optimistix.BestSoFarMinimiser(optimistix.BFGS(rtol=1e-5, atol=1e-5))
ret = optimistix.minimise(lambda x, _: (x - 3.)**2, solver, 0.)
print(ret.value, ret.state.state.y_eval)

0.0 3.0
patrick-kidger commented 9 months ago

Bother. This is actually doing what the BestSoFar* wrappers aim to do, but I can see that it's not really optimal. To be precise: during optimisation, we get a sequence of values y0, y1, y2, ..., y_{n-1}, y_n. Typically we evaluate f(y_i) and from its value determine how to choose y_{i+1}. The goal of the BestSoFar* wrappers is to return the y_i for which f(y_i) is the smallest value we ever see. However, we never actually evaluate f(y_n), the very last iterate! So the wrappers never look at that value.
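
For concreteness, here is a minimal sketch of that pattern in plain Python (not the actual optimistix internals; `f` and `step_fn` are hypothetical stand-ins for the loss function and the wrapped solver's update):

def minimise_best_so_far(f, y0, step_fn, num_steps):
    y = y0
    best_y, best_loss = y0, float("inf")
    for _ in range(num_steps):
        loss = f(y)  # evaluates f at y_0, ..., y_{n-1} ...
        if loss < best_loss:
            best_y, best_loss = y, loss
        y = step_fn(y, loss)  # ... and produces y_1, ..., y_n,
    return best_y  # ... but f(y_n) itself is never evaluated.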

I think the fix should be to perform that evaluation in postprocess:

https://github.com/patrick-kidger/optimistix/blob/b63582a60fd30d52455912338ca48c983b8aa79a/optimistix/_solver/best_so_far.py#L105-L116

so that its body looks something like:

# One extra evaluation, at the final iterate:
f, aux = fn(y, args)
loss = self._to_loss(y, f)
# Keep the final iterate only if it improves on the best seen so far.
pred = loss < state.best_loss
best_y = tree_where(pred, y, state.best_y)
best_aux = tree_where(pred, aux, state.best_aux)
return best_y, best_aux, {}
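
Assuming that sketch works as intended, the reproducer at the top of this issue should then return the final iterate, since f(3.0) < f(0.0):

ret = optimistix.minimise(lambda x, _: (x - 3.)**2, solver, 0.)
print(ret.value)  # expected: 3.0 once postprocess also checks the final iterate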

What do you think? If that looks reasonable then I'd be happy to take a PR on this.