patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Optimization with state #83

Open f0uriest opened 6 days ago

f0uriest commented 6 days ago

Suppose I have an optimization problem where evaluating the objective requires solving an expensive iterative sub-problem. It would be more efficient if I could reuse the sub-problem's solution from the previous optimizer step to warm-start the next step. Something like:

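```python
def myfun(x, y0):
    # Warm-start the expensive inner solve from the previous solution y0.
    y = expensive_subproblem(x, initial_guess=y0)
    # Return the objective value, plus y as auxiliary output for the next step.
    return some_other_func(x, y), y
```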

And a sketch of what I'm doing now:

```python
import jax

for _ in range(maxiter):
    # has_aux=True returns ((f, y), df); the new y warm-starts the next call.
    (f, y), df = jax.value_and_grad(myfun, argnums=0, has_aux=True)(x, y)
    x = optimizer_step(x, f, df)
```
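For reference, here is a runnable version of that loop. It is a minimal sketch under stated assumptions: `expensive_subproblem`, `some_other_func` and `optimizer_step` are hypothetical stand-ins (a few fixed-point iterations of `y = 0.5 * cos(y) + x`, a toy quadratic objective, and plain gradient descent), not anything from the original question or from Optimistix.

```python
import jax
import jax.numpy as jnp


def expensive_subproblem(x, initial_guess, num_iters=5):
    # Hypothetical stand-in for the expensive inner solve: a fixed number of
    # fixed-point iterations of y = 0.5 * cos(y) + x, started from initial_guess.
    def body(_, y):
        return 0.5 * jnp.cos(y) + x

    return jax.lax.fori_loop(0, num_iters, body, initial_guess)


def some_other_func(x, y):
    # Hypothetical outer objective that depends on x and on the inner solution y.
    return (y - 1.0) ** 2 + 0.1 * x**2


def myfun(x, y0):
    y = expensive_subproblem(x, initial_guess=y0)
    return some_other_func(x, y), y


def optimizer_step(x, f, df, learning_rate=0.1):
    # Hypothetical optimizer: plain gradient descent (f is unused here).
    return x - learning_rate * df


# has_aux=True: the call returns ((f, y), df), with df taken w.r.t. x only.
value_and_grad_fn = jax.jit(jax.value_and_grad(myfun, argnums=0, has_aux=True))

maxiter = 200
x = jnp.array(0.0, dtype=jnp.float32)
y = jnp.array(0.0, dtype=jnp.float32)  # cold start for the very first inner solve
history = []
for _ in range(maxiter):
    # The y returned on this step warm-starts the inner solve on the next step.
    (f, y), df = value_and_grad_fn(x, y)
    history.append(float(f))
    x = optimizer_step(x, f, df)
```

Gradients here flow through the unrolled inner iterations; `argnums=0` means no gradient is taken with respect to the warm-start value `y0` itself.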

Is something like this possible in optimistix?

patrick-kidger commented 5 days ago

This can be done by defining a custom solver, which wraps your current solver of choice. Solvers are the API level at which we pass state between steps.

For an example, take a look at the best-so-far solvers, which do something very similar: they wrap an existing solver, call it on every step, and additionally pass around some extra state.
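To make the wrapping idea concrete, here is a minimal sketch of the pattern in plain JAX. It assumes a deliberately simplified, hypothetical `init`/`step` solver interface, which is not Optimistix's actual solver API; the names `GradientDescent`, `WarmStart` and `toy_fn` are illustrative only. The wrapper calls the inner solver on every step and carries the warm-start `y` in its own state, next to the inner solver's state.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class GDState(NamedTuple):
    num_steps: jax.Array  # how many steps the inner solver has taken


class GradientDescent:
    """Hypothetical inner solver with a simplified init/step interface.

    This is NOT Optimistix's solver API; it only illustrates the wrapping pattern.
    """

    def __init__(self, learning_rate: float):
        self.learning_rate = learning_rate

    def init(self, x):
        return GDState(num_steps=jnp.array(0))

    def step(self, x, grad, state):
        return x - self.learning_rate * grad, GDState(state.num_steps + 1)


class WarmStartState(NamedTuple):
    inner_state: Any  # whatever state the wrapped solver carries
    y: jax.Array  # extra state: the latest inner-problem solution


class WarmStart:
    """Wraps another solver and threads the aux output of `fn` between steps."""

    def __init__(self, inner, fn):
        # `fn(x, y0)` must return `(value, y)`, like `myfun` in the question.
        self.inner = inner
        self.value_and_grad = jax.jit(jax.value_and_grad(fn, has_aux=True))

    def init(self, x, y0):
        return WarmStartState(self.inner.init(x), y0)

    def step(self, x, state):
        # Warm-start the inner problem from the previous step's solution.
        (f, y), grad = self.value_and_grad(x, state.y)
        x, inner_state = self.inner.step(x, grad, state.inner_state)
        return x, WarmStartState(inner_state, y), f


def toy_fn(x, y0):
    # Hypothetical: three fixed-point iterations of y = 0.5 * cos(y) + x from y0.
    y = y0
    for _ in range(3):
        y = 0.5 * jnp.cos(y) + x
    return (y - 1.0) ** 2 + 0.1 * x**2, y


solver = WarmStart(GradientDescent(learning_rate=0.1), toy_fn)
x = jnp.array(0.0, dtype=jnp.float32)
state = solver.init(x, y0=jnp.array(0.0, dtype=jnp.float32))
for _ in range(200):
    x, state, f = solver.step(x, state)
```

The outer loop only sees `x` and an opaque state, which is the point of the design. A real Optimistix version would instead subclass one of the library's abstract solver classes, as the best-so-far solvers do; check their source for the exact method signatures rather than relying on this sketch.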