google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Feature Request: Second-Order Optimization Methods in Optax #1014

Closed lvjonok closed 1 month ago

lvjonok commented 1 month ago

Description

I am interested in JAX-based optimization and would like to inquire about the potential implementation of second-order optimization methods such as Sequential Quadratic Programming (SQP) in Optax.

Challenges

Implementing SQP requires Hessian information, which is currently not part of the optax.GradientTransformation interface.
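For context, here is a minimal sketch of the current first-order interface: the `update` function only receives gradients, so there is no natural slot for Hessian information. The `scale_by_constant` transformation below is purely illustrative and not an optax API.

```python
import jax
import optax


def scale_by_constant(factor: float) -> optax.GradientTransformation:
  """Toy transformation: multiplies incoming gradients by a constant."""

  def init_fn(params):
    del params
    return optax.EmptyState()

  def update_fn(updates, state, params=None):
    del params
    # Only the gradients ("updates") flow through here; the interface has
    # no argument through which Hessian information could be supplied.
    new_updates = jax.tree_util.tree_map(lambda g: factor * g, updates)
    return new_updates, state

  return optax.GradientTransformation(init_fn, update_fn)
```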

Request

Could you provide insights or plans regarding the support for second-order optimization methods in Optax? Specifically, how would the integration of Hessians be approached within the existing framework?

Thank you for your consideration.

vroulet commented 1 month ago

Hello @lvjonok,

Great question. We discussed this some time ago. A PR (#817) implemented second-order oracles such as hvp and gnvp (Hessian-vector product and Gauss-Newton-vector product, respectively). It stalled over a debate about the signatures of these functions: jvp and vjp closely follow the vocabulary of differential geometry, and it was not clear how to design signatures for hvp/gnvp that remain consistent with that vocabulary.

Coming back to your question: as long as we only need hvps (no gnvps), one can build second-order optimization schemes by passing the hvp to a GradientTransformationExtraArgs as an extra argument. The transformation can then compute, e.g., the Newton direction by solving the corresponding linear system using only calls to the hvp, as in the sketch below.
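A minimal sketch of this idea, assuming a scalar `loss_fn` (the helpers `make_hvp` and `newton_direction` are hypothetical, not optax APIs): the hvp is built with `jax.jvp`/`jax.grad`, and the Newton direction is obtained with conjugate gradients, without ever materializing the Hessian.

```python
import jax
import jax.numpy as jnp


def make_hvp(loss_fn, params):
  # hvp(v) = H(params) @ v, via forward-over-reverse differentiation.
  def hvp(v):
    return jax.jvp(jax.grad(loss_fn), (params,), (v,))[1]
  return hvp


def newton_direction(loss_fn, params):
  grads = jax.grad(loss_fn)(params)
  hvp = make_hvp(loss_fn, params)
  # Solve H d = g using only Hessian-vector products (Hessian-free).
  direction, _ = jax.scipy.sparse.linalg.cg(hvp, grads)
  return direction


# Toy usage on a convex quadratic, where one Newton step reaches the minimizer.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
loss = lambda x: 0.5 * x @ A @ x
x0 = jnp.array([1.0, -1.0])
x1 = x0 - newton_direction(loss, x0)  # == the origin, the minimizer
```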

Note that it would generally be wasteful to materialize and invert the full Hessian rather than relying on hvp calls (see the classical literature on Hessian-free optimization [1]).

As a first step, you could try writing a simple Newton solver as a GradientTransformation; a rough sketch is given below.
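A hedged sketch of what such a transformation might look like, assuming the hvp is supplied as an extra argument at each step. The `scale_by_newton` name and the exact extra-args plumbing are assumptions for illustration, not settled optax API.

```python
import jax
import optax


def scale_by_newton() -> optax.GradientTransformationExtraArgs:
  """Illustrative Newton-step transformation driven by an hvp extra argument."""

  def init_fn(params):
    del params
    return optax.EmptyState()

  def update_fn(updates, state, params=None, *, hvp, **extra_args):
    del params, extra_args
    # Solve H d = g with CG, where g = updates, using only hvp calls.
    direction, _ = jax.scipy.sparse.linalg.cg(hvp, updates)
    # Negate so the result can be applied directly with optax.apply_updates.
    direction = jax.tree_util.tree_map(lambda d: -d, direction)
    return direction, state

  return optax.GradientTransformationExtraArgs(init_fn, update_fn)
```

One would then build an hvp closure at the current parameters (as in the earlier sketch) and pass it alongside the gradients, e.g. `update_fn(grads, state, params, hvp=hvp)`, before applying the result with `optax.apply_updates`.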

Finally, if you are dealing with deterministic objectives, you may want to take a look at optimistix (https://docs.kidger.site/optimistix/).

[1] J. Martens, Deep learning via Hessian-free optimization, ICML 2010

lvjonok commented 1 month ago

Thank you, @vroulet!

That fully answers my question. I will look into the literature and explore the suggested approaches. If I come across additional methods or insights that could benefit the repository, I will follow up and contribute.