Open ColCarroll opened 9 months ago
Bother. This is actually doing what the BestSoFar* wrappers aim to do, but I can see that it's not really optimal. To be precise: during optimisation, we get a sequence of values y_0, y_1, y_2, ..., y_{n-1}, y_n. Typically we evaluate f(y_i) and from its value determine how to choose y_{i+1}. The goal of the BestSoFar* wrappers is to return the y_i for which f(y_i) is the smallest value we ever see. However, we never actually evaluate f(y_n), the very last iterate! So the wrappers never look at that value.
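To make that concrete, here is a toy sketch in plain Python. It is not the library's code: `f`, `run`, and `check_final` are hypothetical names made up for illustration, and the update rule is just fixed-step gradient descent on a quadratic. It only shows how a best-so-far tracker that evaluates f(y_i) inside the loop ends up never looking at y_n.

```python
# Toy illustration only: every name here is hypothetical, not the library's API.


def f(y):
    """Quadratic with its minimum at y = 3."""
    return (y - 3.0) ** 2


def run(y0, num_steps, check_final):
    """Gradient descent on f, tracking the best iterate seen so far.

    Returns (final iterate y_n, best iterate recorded by the tracker).
    """
    y = y0
    best_y, best_loss = y0, float("inf")
    for _ in range(num_steps):
        loss = f(y)  # evaluate f(y_i) ...
        if loss < best_loss:
            best_y, best_loss = y, loss
        y = y - 0.25 * 2.0 * (y - 3.0)  # ... and use it to choose y_{i+1}
    if check_final:
        # The extra evaluation: also look at f(y_n) before returning.
        loss = f(y)
        if loss < best_loss:
            best_y, best_loss = y, loss
    return y, best_y


final_y, best_without = run(0.0, 3, check_final=False)
_, best_with = run(0.0, 3, check_final=True)
print(final_y, best_without, best_with)
```

Without the final check the tracker reports the second-to-last iterate, even though the last one is closer to the minimum; with the check, the two agree.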
I think the fix should be to perform that evaluation in postprocess, so that its body looks something like:
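# Evaluate the function at the final iterate, which was never checked during the solve.
f, aux = fn(y, args)
loss = self._to_loss(y, f)
# Keep the final iterate only if it beats the best loss recorded so far.
pred = loss < state.best_loss
best_y = tree_where(pred, y, state.best_y)
best_aux = tree_where(pred, aux, state.best_aux)
return best_y, best_aux, {}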
What do you think? If that looks reasonable then I'd be happy to take a PR on this.
Not sure if this is a bug or not, but BestSoFarMinimiser appears to not check the last step of the wrapped solver: