google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Gauss-Newton and Levenberg-Marquardt #920

Open gbruno16 opened 3 months ago

gbruno16 commented 3 months ago

In this fork I'm trying to implement the Gauss-Newton and Levenberg-Marquardt methods for the Optax library. The primary objective is to provide a flexible Gauss-Newton transformation that offers options for selecting the damping parameter and the solver, and for choosing whether to work with the normal equations. Additionally, this transformation solves least-squares problems given only the JVP of the residuals function, and it can handle compositional problems when the HVP of the outer function is also specified.
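
For context, the step such a transformation produces is the classical Gauss-Newton step: with residuals r = f(x) and Jacobian J, it solves the normal equations (J^T J) d = -J^T r. Below is only an illustrative, matrix-free sketch of that step built from the JVP of the residuals and a CG solve; it is not the code in the fork, and the helper name gauss_newton_step is made up:

import jax
import jax.numpy as jnp

def gauss_newton_step(residuals_fn, params):
  # Hypothetical helper, for illustration only.
  # r = residuals at params; jvp(v) computes J @ v without materializing J.
  r, jvp = jax.linearize(residuals_fn, params)
  # The transpose of the linearized map gives u -> J.T @ u.
  vjp = jax.linear_transpose(jvp, params)

  def normal_matvec(v):
    # v -> J.T @ (J @ v), the matrix-vector product of the normal equations.
    (jtjv,) = vjp(jvp(v))
    return jtjv

  (rhs,) = vjp(-r)  # -J.T @ r
  # Solve (J.T J) d = -J.T r matrix-free with conjugate gradient.
  delta, _ = jax.scipy.sparse.linalg.cg(normal_matvec, rhs)
  return delta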

A simple usage example for the Gauss-Newton optimizer:

import jax
import optax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(x):
  # Residuals of the Rosenbrock test problem: 0.5 * sum(f(x)**2) is the Rosenbrock function.
  return jnp.sqrt(2) * jnp.array([10 * (x[1] - x[0]**2), (1 - x[0])])

params = jnp.array([-1.2, 1])
print('Initial objective function: ', 0.5*jnp.sum(f(params)**2))

solver = optax.gauss_newton()
opt_state = solver.init(params)

for _ in range(5):
    # Linearize the residuals at the current parameters to get both their value and the JVP.
    residuals, inner_jvp = jax.linearize(f, params)
    updates, opt_state = solver.update(residuals, opt_state, params, inner_jvp=inner_jvp)
    params = optax.apply_updates(params, updates)
    print('Objective function: {:.2E}'.format(0.5*jnp.sum(f(params)**2)))

The Gauss-Newton transformation could serve as a building block for constructing more sophisticated optimization solvers. As an illustration, I have incorporated the trust-region algorithm implemented in JAXopt (algorithm 6.18 in “Introduction to Optimization and Data Fitting”, K. Madsen & H. B. Nielsen) into the scale_by_madsen_trust_region transformation. As a consequence, we can obtain the Levenberg-Marquardt method simply by composing it with the Gauss-Newton transformation described earlier.
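
For reference, the damping update in that trust-region scheme is the usual gain-ratio rule from Madsen & Nielsen: the gain ratio rho compares the actual decrease of 0.5*||f(x)||^2 with the decrease predicted by the linearized model, the step is accepted when rho > 0, and the damping parameter mu is shrunk or grown accordingly. A small illustrative sketch of just that rule (not the fork's code; the function and argument names are hypothetical):

import jax.numpy as jnp

def madsen_nielsen_damping(rho, mu, nu):
  # rho: gain ratio of the last step (actual reduction / predicted reduction).
  # mu: current damping parameter; nu: current growth factor.
  accepted = rho > 0
  # Accepted step: shrink mu, but never by more than a factor of 3.
  mu_accept = mu * jnp.maximum(1.0 / 3.0, 1.0 - (2.0 * rho - 1.0) ** 3)
  # Rejected step: grow mu by nu and double nu for the next rejection.
  mu_new = jnp.where(accepted, mu_accept, mu * nu)
  nu_new = jnp.where(accepted, 2.0, 2.0 * nu)
  return accepted, mu_new, nu_new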

The previous example becomes:

import jax
import optax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

def f(x):
  return jnp.sqrt(2) * jnp.array([10 * (x[1] - x[0]**2), (1 - x[0])])

params = jnp.array([-1.2, 1])
print('Initial objective function: ', 0.5*jnp.sum(f(params)**2))

solver = optax.levenberg_marquardt(init_damping_parameter=1.0)
opt_state = solver.init(params)

for _ in range(15):
    updates, opt_state = solver.update(opt_state, params, residuals_fn=f)
    params = optax.apply_updates(params, updates)
    print('Objective function: {:.2E}'.format(0.5*jnp.sum(f(params)**2)))

This is still a draft and will require more time, but feedback and suggestions for improvement are greatly appreciated. Please feel free to share your thoughts on the implementation and to suggest any enhancements or modifications.