patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Extracting intermediate function values/losses from the solve #52

Open itk22 opened 3 months ago

itk22 commented 3 months ago

Dear optimistix team,

First of all, thank you for your effort in developing optimistix. I have recently transitioned from JAXOpt, and I love it!

I was wondering if it is possible to extract the loss/function value history from an optimistix solve? In the code example below, it is easy to evaluate the intermediate losses when using the multi_step_solve method, but it is much less efficient than the single_step_solve approach. Using a jax.lax.scan would definitely improve the performance over using a Python for loop, but I was wondering if there is a simpler way to extract this information in optimistix.

import jax
import jax.numpy as jnp
import optimistix as optx

def rastrigin(x, args):
    # Rastrigin test function: A * n + sum(x_i**2 - A * cos(2 * pi * x_i)).
    A = 10.0
    y = A * x.shape[0] + jnp.sum(x**2 - A * jnp.cos(2 * jnp.pi * x), axis=0)
    return y

# How can we extract the losses for a single_step_solve?
def single_step_solve(solver, y0):
    sol = optx.minimise(rastrigin, solver, max_steps=2_000, y0=y0, throw=False)
    return sol.value

def multi_step_solve(solver, y0):
    # Much less efficient (2,000 separate solves), but the loss is easy to
    # evaluate at each intermediate iterate.
    current_sol = y0
    for _ in range(2_000):
        current_sol = optx.minimise(rastrigin, solver, max_steps=1, y0=current_sol, throw=False).value
    return current_sol
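
For reference, the jax.lax.scan variant mentioned above might look roughly like this (a sketch only, reusing rastrigin and the imports above; scan_step_solve is a hypothetical name, and the loss history is recorded by re-evaluating rastrigin at each iterate):

def scan_step_solve(solver, y0, num_steps=2_000):
    # One single-step solve per scan iteration; scan keeps the loop on-device.
    def step(y, _):
        y_next = optx.minimise(rastrigin, solver, max_steps=1, y0=y, throw=False).value
        # Re-evaluate the objective to record the loss at this iterate.
        return y_next, rastrigin(y_next, None)

    y_final, losses = jax.lax.scan(step, y0, xs=None, length=num_steps)
    return y_final, losses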
patrick-kidger commented 3 months ago

Hmm. I don't think we offer that at the moment!

FWIW if this is just for debugging purposes then you could add jax.debug.print statements to the input or output of your function.
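
For instance (a minimal sketch, reusing the rastrigin objective above; rastrigin_with_print is just an illustrative wrapper):

import jax

def rastrigin_with_print(x, args):
    y = rastrigin(x, args)
    # Prints on every evaluation, including under jit.
    jax.debug.print("loss: {l}", l=y)
    return y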

If you really want to interact with the history programmatically then (a) I'm quite curious what the use-case is, but also (b) we could probably add an additional optx.minimise(..., saveat=...) argument without too much difficulty.

itk22 commented 3 months ago

Hi, thanks for the quick response. I am currently trying to use optimistix to implement Model-Agnostic Meta-Learning (MAML) and its implicit version (https://arxiv.org/abs/1909.04630). I was considering using the single_step_solve approach from above for the outer meta-learning loop, because it makes training very quick. However, I need to be able to monitor the meta-losses, which is why I was looking into the less efficient multi_step_solve. The saveat option would be very useful!

On a side note, I had a look at the interactive stepping example, and I thought that it could be useful for solvers to have an update method for performing a single optimisation step, similar to JAXOpt. What do you think?

patrick-kidger commented 3 months ago

Makes sense. I'll mark this as a feature request for a saveat option. (And I'm sure we'd be happy to take a PR on this.)

As for performing a single optimisation step, I think we already have this, as the step method on the solver?
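
For illustration, a condensed sketch of manual stepping, based on the "Interactively step through a solve" example in the docs (exact signatures may differ between versions; at this level the objective is assumed to return a (value, aux) pair):

import jax
import jax.numpy as jnp
import optimistix as optx

def fn(y, args):
    # At the solver level the objective returns (value, aux).
    return jnp.sum((y - 1.0) ** 2), None

solver = optx.BFGS(rtol=1e-6, atol=1e-6)
y = jnp.zeros(3)
args = None
options = None
f_struct = jax.ShapeDtypeStruct((), jnp.float32)  # structure of fn's value output
aux_struct = None
tags = frozenset()

state = solver.init(fn, y, args, options, f_struct, aux_struct, tags)
done, result = solver.terminate(fn, y, args, options, state, tags)
while not done:
    y, state, aux = solver.step(fn, y, args, options, state, tags)
    done, result = solver.terminate(fn, y, args, options, state, tags)
    # The intermediate loss is available at every step.
    print("loss:", fn(y, args)[0])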

itk22 commented 3 months ago

Thanks, Patrick. I will try to make a PR for this :)