google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Timeline for JaxOpt migration #977

Open · Joshuaalbert opened this issue 1 month ago

Joshuaalbert commented 1 month ago

Hello, what's the current roadmap for the jaxopt migration into optax? Will the scope of jaxopt be maintained, or will features be trimmed or expanded?

jlperla commented 2 weeks ago

Another question I had on the jaxopt migration: to what extent will the differentiable optimization features be ported?

That is, will there be easy ways to get the JVP and VJP linearizations around the argmin or min of the optimization process? (e.g. https://jaxopt.github.io/stable/implicit_diff.html#implicit-diff) None of this is specific to the optimization algorithm itself, but with training-loop-style optimization it isn't clear where the custom JVP/VJP rules should be placed: you don't want to differentiate by recursing through the training loop itself.

For reference, in jaxopt this is done with a wrapper around the solver itself that registers a @jax.custom_vjp, as in https://github.com/google/jaxopt/blob/4a50198711e53ed0c25f2be5394825092c2427db/jaxopt/_src/implicit_diff.py#L204

Joshuaalbert commented 2 weeks ago

I hope there will be an RFC around this, because there are numerous groups interested in JAX-based optimization, from science (e.g. the MODE consortium, https://indico.cern.ch/event/1380163/) to entry-level ML students and everything in between.

mblondel commented 1 week ago

So far, we've migrated all loss functions, all tree utilities, the perturbation utilities, some projections, and an optax-compatible L-BFGS (with backtracking or zoom line search).
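For reference, driving the migrated L-BFGS currently looks roughly like this (a minimal sketch following the optax documentation; the exact keyword arguments may still evolve):

import jax.numpy as jnp
import optax

def fun(w):
    # toy objective; in practice this is the user's loss
    return jnp.sum((w - 1.0) ** 2)

opt = optax.lbfgs()  # zoom line search by default
params = jnp.zeros(3)
state = opt.init(params)
value_and_grad = optax.value_and_grad_from_state(fun)

for _ in range(20):
    value, grad = value_and_grad(params, state=state)
    updates, state = opt.update(
        grad, state, params, value=value, grad=grad, value_fn=fun)
    params = optax.apply_updates(params, updates)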

The next big item we want to migrate is implicit differentiation. We'll start by migrating the custom_ decorators from JAXopt. This will require users to write a small wrapper on their side, but this is fairly simple; see this example.
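For illustration, the wrapper pattern looks roughly like the following (a sketch built on the existing jaxopt.implicit_diff.custom_root decorator around a plain optax loop; the decorator names after migration may differ):

import jax
import jax.numpy as jnp
import optax
from jaxopt import implicit_diff  # until the decorators land in optax

def ridge_objective(params, lam, X, y):
    residual = X @ params - y
    return jnp.mean(residual ** 2) + lam * jnp.sum(params ** 2)

# optimality condition: the inner gradient vanishes at the solution
optimality_fun = jax.grad(ridge_objective)

@implicit_diff.custom_root(optimality_fun)
def ridge_solver(init_params, lam, X, y):
    # any inner solver works; a plain optax training loop for illustration
    opt = optax.sgd(learning_rate=0.05)
    state = opt.init(init_params)
    params = init_params
    for _ in range(1000):
        grads = jax.grad(ridge_objective)(params, lam, X, y)
        updates, state = opt.update(grads, state)
        params = optax.apply_updates(params, updates)
    return params

# gradients w.r.t. lam come from the implicit function theorem at the
# returned solution, not from unrolling the 1000-step loop
X, y = jnp.ones((5, 3)), jnp.ones(5)
dsol_dlam = jax.jacobian(ridge_solver, argnums=1)(jnp.zeros(3), 0.1, X, y)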

We've also been brainstorming a solver API similar to JAXopt's, but this will take more time, as we want to integrate it well with the current optax API. Once we've figured out a good solver API, it would be great to also migrate the SciPy wrappers.

Joshuaalbert commented 1 week ago

Good to hear an update. This sounds like it will drop some of the classic solvers from support, e.g. LM and BFGS? Will the main focus still be mini-batched optimisation?

Re: API. Could I suggest that once you have some candidate API structures you open an RFC (request for community comment) issue for a week or so just to see if there is any useful feedback?

krzysztofrusek commented 1 week ago

@Joshuaalbert there was one PR (#777) introducing a solver API, but it was reverted.

Regarding BFGS and similar, they are covered by optimistix.
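For reference, a BFGS call there looks roughly like this (a sketch based on the Optimistix documentation; check the current API for details):

import jax.numpy as jnp
import optimistix as optx

def fun(y, args):
    # Rosenbrock-style objective; Optimistix minimisers expect fn(y, args)
    return jnp.sum(100.0 * (y[1:] - y[:-1] ** 2) ** 2 + (1.0 - y[:-1]) ** 2)

solver = optx.BFGS(rtol=1e-6, atol=1e-6)
sol = optx.minimise(fun, solver, jnp.zeros(3))
print(sol.value)  # the minimiser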

jlperla commented 1 week ago

@mblondel Having a wrapper, and even a coding pattern where you register your own VJP/JVP, is fine, but I am a little confused: it sounds like you are decoupling it from the solver API. Or maybe I misunderstood?

FWIW, I think it would be completely fine if you guys gave an example along the lines of

@jax.custom_vjp
def my_solver(params):
    # some closure over data/etc.
    # run the optimizer training loop in whatever minibatch fashion you wish.
    # return the min and argmin values
    return min_value, argmin

def my_solver_fwd(params):
    # forward pass: run the solver and stash the residuals needed by the
    # backward rule, assuming my_solver has reached the optimum to the
    # precision required for implicit differentiation
    min_value, argmin = my_solver(params)
    return (min_value, argmin), (params, argmin)

def my_solver_bwd(residuals, cotangents):
    # implement the VJP via the implicit differentiation rules here
    # (solve a linear system at the optimum instead of unrolling the loop)
    return (params_cotangent,)

my_solver.defvjp(my_solver_fwd, my_solver_bwd)

solve = jax.jit(my_solver)  # jit the wrapped solver; the fwd/bwd rules need no separate jit

Or whatever is correct, and providing the correct template code that can be copy/pasted for the AD rules (which are independent of the optimization method).

Having a solver interface that automatically does that registration/code would be nice, but a hand-coded example is good enough to start?

mblondel commented 1 week ago

This sounds like it will drop some of the classic solvers from support, e.g. LM and BFGS?

BFGS should be easy to migrate using the same optax API as L-BFGS. For LM, we want to implement the algorithm from scratch, but it's not super high on our priority list.

Could I suggest that once you have some candidate API structures you open an RFC (request for community comment) issue for a week or so just to see if there is any useful feedback?

Good idea

Having a solver interface that automatically does that registration/code would be nice, but a hand-coded example is good enough to start?

Yep, that's the plan. We'll document how to use the custom decorators. Automatic registration will be possible when we figure out a solver API.

diegoferigo commented 1 week ago

Is there any interest in migrating jaxopt.BoxOSQP? There is some recent interest within the robotics community in the development of JAX-based controllers to be used in closed-loop simulations that can be executed at scale on hardware accelerators (end-to-end differentiable).

As far as I know, this OSQP implementation is the only quadratic programming framework currently available for the JAX ecosystem.
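For context, the interface in question looks roughly like this (a sketch based on the JAXopt QP documentation; the exact layout of the returned KKT solution may differ):

import jax.numpy as jnp
from jaxopt import BoxOSQP

# solve: minimize 0.5 x^T Q x + c^T x  subject to  l <= A x <= u
Q = 2.0 * jnp.eye(2)
c = jnp.array([1.0, -1.0])
A = jnp.eye(2)
l = jnp.zeros(2)
u = jnp.ones(2)

qp = BoxOSQP()
sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params
x = sol.primal[0]  # primal solution; gradients w.r.t. (Q, c, A, l, u) flow via implicit diff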