google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Reusing forward computation for implicit diff #7

Open mblondel opened 3 years ago

mblondel commented 3 years ago

With some solvers (e.g. Newton's method), part of the forward computation can be reused to solve the implicit differentiation linear system more efficiently. Since @custom_root and @custom_fixed_point accept a solve argument for specifying a linear solver, the output of the forward-pass solver could be passed as an extra argument to solve.
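A minimal sketch of the idea outside of jaxopt. It assumes a toy root-finding problem F(x, theta) = A x + tanh(x) - theta, and the names F, newton and root are hypothetical. As far as I know, jaxopt's solve argument only receives a matvec and a right-hand side. The sketch therefore uses jax.custom_vjp directly, so that the forward pass can hand its LU factors of dF/dx to the backward pass. The backward pass then solves the transposed implicit-differentiation system with jax.scipy.linalg.lu_solve instead of refactoring.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

A = jnp.array([[3.0, 1.0], [1.0, 2.0]])


def F(x, theta):
  # Optimality condition: the solver returns x*(theta) with F(x*, theta) = 0.
  return A @ x + jnp.tanh(x) - theta


def newton(theta, num_iters=20):
  x = jnp.zeros_like(theta)
  for _ in range(num_iters):
    # Each Newton step factors the Jacobian dF/dx at the current iterate.
    lu_piv = lu_factor(jax.jacobian(F, argnums=0)(x, theta))
    x = x - lu_solve(lu_piv, F(x, theta))
  # Keep the last factorization: at convergence it approximates dF/dx at x*.
  return x, lu_piv


@jax.custom_vjp
def root(theta):
  x, _ = newton(theta)
  return x


def root_fwd(theta):
  x, lu_piv = newton(theta)
  # Save the factorization as a residual for the backward pass.
  return x, (x, theta, lu_piv)


def root_bwd(res, g):
  x, theta, lu_piv = res
  # Implicit diff: solve (dF/dx)^T u = g with the reused LU factors
  # (trans=1 selects the transposed system) instead of refactoring.
  u = lu_solve(lu_piv, g, trans=1)
  # The VJP of x*(theta) is -(dF/dtheta)^T u.
  _, vjp_theta = jax.vjp(lambda t: F(x, t), theta)
  (grad_theta,) = vjp_theta(-u)
  return (grad_theta,)


root.defvjp(root_fwd, root_bwd)
```

The saved factors come from the last Newton iterate rather than from x* itself. The backward solve is therefore exact only up to the forward solver's tolerance, which is the same trade-off an approximate factorization would introduce.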

shoyer commented 3 years ago

A good example of this might be (approximate) matrix factorizations and/or preconditioners, which we need both for iterations of Newton's method and for backward solves.
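As an illustration of sharing a preconditioner, here is a sketch that uses a simple Jacobi (inverse-diagonal) preconditioner as a stand-in for an approximate factorization. The helper make_jacobi_solve and the matrix H are hypothetical. The preconditioner is built once and closed over by a solve(matvec, b) callable built on jax.scipy.sparse.linalg.cg, which receives it through its M argument. The same callable then serves both the Newton step and the backward solve.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def make_jacobi_solve(diag):
  # Build the (approximate) preconditioner once: here just the inverse diagonal.
  inv_diag = 1.0 / diag

  def solve(matvec, b):
    # Linear solver of the form solve(matvec, b). The preconditioner is
    # closed over, so every call reuses it instead of recomputing it.
    x, _ = cg(matvec, b, M=lambda r: inv_diag * r)
    return x

  return solve


# Symmetric positive definite Hessian shared by the Newton step and the backward solve.
H = jnp.array([[4.0, 1.0, 0.0], [1.0, 3.0, 0.5], [0.0, 0.5, 2.0]])


def matvec(v):
  return H @ v


solve = make_jacobi_solve(jnp.diag(H))

grad_f = jnp.array([1.0, -2.0, 0.5])
newton_step = solve(matvec, -grad_f)  # forward pass: Newton direction

cotangent = jnp.ones(3)
u = solve(matvec, cotangent)  # backward pass: H is symmetric, so H^T u = cotangent
```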

GeoffNN commented 2 years ago

I've been thinking about this for equality-constrained QP solving, where preconditioners might speed things up considerably. Ideally, we'd compute the preconditioner once and 1) share it between the forward and backward passes, and 2) share it across k outer-loop iterations, e.g. when the jaxopt solver solves an inner optimization problem.
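A sketch of points 1) and 2) for the equality-constrained QP. It uses the simplest possible preconditioner, an exact LU factorization of the KKT matrix, and assumes Q and A stay fixed across outer iterations while only c changes. It relies on jax.lax.custom_linear_solve, whose transpose_solve is what reverse-mode differentiation calls, so the backward pass reuses the same factors. The names kkt_matrix, make_factored_solves and eq_qp_primal, and the outer problem, are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve


def kkt_matrix(Q, A):
  # KKT matrix of  min_x 0.5 x^T Q x + c^T x  s.t.  A x = b.
  p = A.shape[0]
  return jnp.block([[Q, A.T], [A, jnp.zeros((p, p))]])


def make_factored_solves(K):
  # Factor once; both closures below reuse the same LU factors.
  lu_piv = lu_factor(K)

  def solve(matvec, rhs):
    del matvec  # K is already encoded in lu_piv
    return lu_solve(lu_piv, rhs)

  def transpose_solve(vecmat, rhs):
    del vecmat
    # K is symmetric here, but trans=1 keeps this correct in general.
    return lu_solve(lu_piv, rhs, trans=1)

  return solve, transpose_solve


def eq_qp_primal(c, b, K, solve, transpose_solve):
  # Forward pass: solve K [x; y] = [-c; b] with the precomputed factors.
  # Reverse mode goes through transpose_solve, i.e. the same factors.
  n = c.shape[0]
  rhs = jnp.concatenate([-c, b])
  sol = jax.lax.custom_linear_solve(lambda v: K @ v, rhs, solve, transpose_solve)
  return sol[:n]


# Hypothetical outer problem: tune c so that the QP solution approaches x_target.
Q = jnp.array([[2.0, 0.5], [0.5, 1.0]])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
x_target = jnp.array([0.2, 0.8])

K = kkt_matrix(Q, A)
solve, transpose_solve = make_factored_solves(K)  # computed once, outside the loop


def outer_loss(c):
  x = eq_qp_primal(c, b, K, solve, transpose_solve)
  return jnp.sum((x - x_target) ** 2)


c = jnp.zeros(2)
for _ in range(50):
  c = c - 0.5 * jax.grad(outer_loss)(c)
```

The factorization cost is paid once outside the loop. Each outer gradient step then only needs triangular solves with the saved factors, for both the forward and the backward pass.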

I'm looking into preconditioning through Ruiz equilibration, described in Section 5.1 of the OSQP paper. I think we could wrap the EqualityConstrainedQP solver in a PreconditionedSolver class. The PreconditionedSolver class would have an API like the following (shown here for the equality-constrained case):

```python
from dataclasses import dataclass

from jaxopt import EqualityConstrainedQP


@dataclass
class PreconditionedEqualityConstrainedQP:

  qp_solver: EqualityConstrainedQP

  def init_precond_params(self, params_obj, params_eq):
    # Ruiz equilibration code.
    # This returns (c, D, E) in the paper's notation, i.e. the items needed
    # for the back-and-forth preconditioning transformation.
    raise NotImplementedError

  def precondition_problem(self, params_obj, params_eq, precond_params):
    # Applies preconditioning; returns (precond_params_obj, precond_params_eq).
    raise NotImplementedError

  def recover_solution(self, precond_params, precond_solution):
    # Returns the solution of the original, unpreconditioned problem.
    raise NotImplementedError

  def run(self, params_obj, params_eq, precond_params, **kwargs):
    precond_params_obj, precond_params_eq = self.precondition_problem(
        params_obj, params_eq, precond_params)
    # Pass the problem data by keyword: the first positional argument of
    # EqualityConstrainedQP.run is init_params.
    precond_step = self.qp_solver.run(
        params_obj=precond_params_obj, params_eq=precond_params_eq, **kwargs)
    return self.recover_solution(precond_params, precond_step.params)
```
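The following is a sketch of what init_precond_params, precondition_problem and recover_solution could compute, written as free functions. It assumes a simplified Ruiz equilibration of the KKT matrix [[Q, A^T], [A, 0]] along the lines of OSQP Section 5.1. The cost-scaling step in particular is a simplification and may differ from the paper. The linear cost is called q here to avoid a clash with the cost scaling c, and all function names are hypothetical. The scaled problem uses x = D x_bar, and the duals are recovered as y = E y_bar / c.

```python
import jax.numpy as jnp


def ruiz_equilibrate(Q, A, q, num_iters=10, eps=1e-8):
  # Simplified Ruiz equilibration of the KKT matrix [[Q, A^T], [A, 0]].
  # Returns (c, d, e): the cost scaling c and the diagonals of D and E.
  d = jnp.ones(Q.shape[0])
  e = jnp.ones(A.shape[0])
  Qs, As = Q, A
  for _ in range(num_iters):
    # Infinity norms of the KKT columns belonging to the primal variables...
    col_norm_primal = jnp.maximum(jnp.max(jnp.abs(Qs), axis=0),
                                  jnp.max(jnp.abs(As), axis=0))
    # ...and to the dual variables (rows of As, padded with zeros).
    col_norm_dual = jnp.max(jnp.abs(As), axis=1)
    delta_d = 1.0 / jnp.sqrt(jnp.maximum(col_norm_primal, eps))
    delta_e = 1.0 / jnp.sqrt(jnp.maximum(col_norm_dual, eps))
    Qs = delta_d[:, None] * Qs * delta_d[None, :]
    As = delta_e[:, None] * As * delta_d[None, :]
    d, e = d * delta_d, e * delta_e
  # Simplified cost scaling, in the spirit of the paper's c.
  qs = d * q
  cost_norm = jnp.maximum(jnp.mean(jnp.max(jnp.abs(Qs), axis=0)),
                          jnp.max(jnp.abs(qs)))
  c = 1.0 / jnp.maximum(cost_norm, eps)
  return c, d, e


def precondition_problem(Q, q, A, b, c, d, e):
  # min 0.5 xb^T (c D Q D) xb + (c D q)^T xb  s.t.  (E A D) xb = E b
  Qs = c * d[:, None] * Q * d[None, :]
  qs = c * d * q
  As = e[:, None] * A * d[None, :]
  bs = e * b
  return (Qs, qs), (As, bs)


def recover_solution(x_bar, y_bar, c, d, e):
  # Undo the scaling: x = D xb, y = E yb / c.
  return d * x_bar, e * y_bar / c


def solve_kkt(Q, q, A, b):
  # Dense reference solver for the equality-constrained QP's KKT system.
  n, p = Q.shape[0], A.shape[0]
  K = jnp.block([[Q, A.T], [A, jnp.zeros((p, p))]])
  sol = jnp.linalg.solve(K, jnp.concatenate([-q, b]))
  return sol[:n], sol[n:]
```

Comparing the output of recover_solution with solve_kkt on the unscaled data is a quick check that the transformation and its inverse are consistent.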

Pros:

Cons:

Let me know what you think! I can work on this next week if you think it's a good idea.