google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

OSQP fails to solve feasible QP #342

Open deasmhumhna opened 2 years ago

deasmhumhna commented 2 years ago

I'm trying to use jaxopt.OSQP as part of the projection step of another training algorithm. The QP is

\begin{equation}
\begin{split}
\min_{\theta} & \quad \|\theta-\theta_0\|^2 \\
\text{s.t.} & \quad \lim_{x \to\infty}f(x, y, \theta)=0 \\
    & \quad f(x, 1, \theta)=0 \\
    & \quad y \frac{\partial f}{\partial y}(x,y,\theta) \geq -1
\end{split}
\end{equation}

where $f(x,y,\theta)=\sum_{m,n}\theta_{mn}T_n(x^{-1})T_m(y)$ and $T_k$ is the k-th Chebyshev polynomial. By linearly mapping the feasible domains of $\rho=x^{-1}$ and $y$ onto $[-1, 1]$, the two equality constraints become direct linear constraints on $\theta$, and the inequality constraint can be made linear by discretizing it over a grid. The program is necessarily feasible since the zero function is feasible. However, jaxopt.OSQP/jaxopt.BoxOSQP doesn't even get close to a feasible solution. Moreover, when started at a feasible solution, it returns an infeasible one. I'll share a cleaned-up version of my code shortly.
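To make the linearity claim concrete, here is a minimal NumPy sketch (not the code from the notebook). It assumes that θ is stored as an (M, N) array indexed as `theta[m, n]`, that ρ and y have already been mapped onto [-1, 1], and that y = 1 maps to the point 1. The helper names `f` and `equality_row` are hypothetical. Because T_m(1) = 1 for every m, the constraint f(x, 1, θ) = 0 then reduces to "every column of θ sums to zero".

```python
import numpy as np
from numpy.polynomial import chebyshev as C

def f(rho, y, theta):
    """f = sum_{m,n} theta[m, n] * T_m(y) * T_n(rho), evaluated pointwise."""
    rho, y = np.broadcast_arrays(rho, y)
    M, N = theta.shape
    Ty = C.chebvander(y, M - 1)      # [..., m] = T_m(y)
    Tr = C.chebvander(rho, N - 1)    # [..., n] = T_n(rho)
    return np.einsum('...m,mn,...n->...', Ty, theta, Tr)

def equality_row(y1_mapped, M):
    """Vector t with t[m] = T_m(y1_mapped).

    f(rho, y1, theta) = 0 for every rho  <=>  t @ theta == 0, one equation
    per column of theta, because the T_n(rho) are linearly independent.
    """
    return C.chebvander(np.atleast_1d(y1_mapped), M - 1)[0]

rng = np.random.default_rng(0)
M, N = 4, 5
theta_a, theta_b = rng.normal(size=(M, N)), rng.normal(size=(M, N))

# Linearity in theta: f(a + 2b) == f(a) + 2 f(b) up to round-off.
lhs = f(0.3, -0.7, theta_a + 2.0 * theta_b)
rhs = f(0.3, -0.7, theta_a) + 2.0 * f(0.3, -0.7, theta_b)

# Project theta_a onto {theta : t @ theta = 0}, then evaluate f(rho, 1, theta).
t = equality_row(1.0, M)             # all ones under the stated assumption
theta_c = theta_a - np.outer(t, t @ theta_a) / (t @ t)
residual = f(np.linspace(-1.0, 1.0, 7), 1.0, theta_c)
```

If y = 1 maps to some other point of [-1, 1], the same construction applies with that point passed to `equality_row`.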

mblondel commented 2 years ago

Indeed, a code snippet to reproduce the issue would be great.

deasmhumhna commented 2 years ago

> Indeed, a code snippet to reproduce the issue would be great.

https://colab.research.google.com/drive/1c9kkaAwrysPFq5bEPBhJrzQDC6sGT0sF?usp=sharing

Algue-Rythme commented 2 years ago

Thank you; I will take a look at the end of the week.

Algue-Rythme commented 1 year ago

I took a look; I will be honest: I couldn't understand the error. I don't know if it's purely numerical, a failure of our implementation, or a failure of yours.

The code you provided is not very jit-friendly, which makes iterative experiments hard because they run slowly.

Can you clarify what you mean by "fails to solve feasible QP"? Is it returning an error code in state.status (wrongly detecting termination with BoxOSQP.PRIMAL_INFEASIBLE or BoxOSQP.DUAL_INFEASIBLE), or is it failing to converge within the iteration budget (returning BoxOSQP.UNSOLVED)?
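Independently of the status code, feasibility of a returned point can be measured directly. The sketch below assumes the constraints are written in the box form lower <= A x <= upper with a dense `A`; the function name `box_violation` and the tiny example are hypothetical and only illustrate the check.

```python
import numpy as np

def box_violation(A, x, lower, upper):
    """Largest violation of lower <= A @ x <= upper (0.0 means feasible)."""
    z = A @ x
    below = np.maximum(lower - z, 0.0)
    above = np.maximum(z - upper, 0.0)
    return float(np.max(np.maximum(below, above)))

A = np.array([[1.0, 1.0],    # equality row: x0 + x1 = 1 (lower == upper)
              [1.0, 0.0]])   # inequality row: x0 >= 0 (upper = +inf)
lower = np.array([1.0, 0.0])
upper = np.array([1.0, np.inf])

feasible = box_violation(A, np.array([0.25, 0.75]), lower, upper)
infeasible = box_violation(A, np.array([-0.5, 1.0]), lower, upper)
```

Reporting this number for both the starting point and the returned point would distinguish "the solver stopped early" from "the solver moved away from the feasible set".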

I am no expert in Chebyshev polynomial interpolation. However, I can say that BoxOSQP will not like operators that are not linear. Can you promise that the operator is linear, and that the corresponding linear operator A remains constant during training? From the maths alone this seems to be the case. But with recursive definitions (like Chebyshev polynomials) numerical errors can be amplified. I am not talking about the approximation quality and convergence speed of the interpolation, but about the numerical errors occurring in float32 arithmetic. I tried with float64 precision and couldn't observe convergence within a reasonable time.
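One cheap way to test the linearity requirement on a matrix-free operator is to compare matvec(a*u + b*v) with a*matvec(u) + b*matvec(v) on random inputs. The sketch below is a NumPy illustration with hypothetical names (`linearity_defect`, `linear_op`, `affine_op`); it is not part of jaxopt.

```python
import numpy as np

def linearity_defect(matvec, n, rng, trials=5):
    """Max relative error of matvec(a*u + b*v) vs a*matvec(u) + b*matvec(v)."""
    worst = 0.0
    for _ in range(trials):
        u, v = rng.normal(size=n), rng.normal(size=n)
        a, b = rng.normal(size=2)
        lhs = matvec(a * u + b * v)
        rhs = a * matvec(u) + b * matvec(v)
        scale = max(np.linalg.norm(rhs), 1e-30)
        worst = max(worst, np.linalg.norm(lhs - rhs) / scale)
    return worst

rng = np.random.default_rng(0)
A = rng.normal(size=(6, 4))
linear_op = lambda x: A @ x
affine_op = lambda x: A @ x + 1.0   # looks similar but is not linear

defect_linear = linearity_defect(linear_op, 4, rng)
defect_affine = linearity_defect(affine_op, 4, rng)
```

A defect near machine precision is consistent with a linear operator. A large defect means the operator has an affine offset or a genuinely nonlinear term, and the offset would have to be moved into the bounds instead.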

Can you certify that your code is numerically stable and implements a valid linear operator that remains constant during training? Not on toy examples, but specifically on the big instance you are testing (i.e. x of shape 100, theta_0 a random matrix of size 30x30 with Gaussian weights). Random Gaussian matrices have notoriously high condition numbers (they were a frequent cause of unit test failures). BoxOSQP is resilient to ill-posed problems to some extent, but it may fail unexpectedly when some linear systems are hard to invert. The sigma parameter controls that. Also, do not set the tolerance tol too small: below some threshold it is effectively zero.

Also, from your maths it looks like f is a real-valued function, but your code is written as if it takes values in a space of dimension 100. As a consequence, the derivative is actually a Jacobian of size 100x100. I guess that in your notation you replace the function by an array of 100 values that corresponds to its discretization on a grid. Numerically, a lot of things can go wrong, including in float64 arithmetic.
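As a rough illustration of why sigma matters for hard linear systems: the indirect formulation in the OSQP papers solves systems with the matrix Q + sigma*I + rho*A^T A. Whether jaxopt's solver forms exactly this matrix is an assumption here, and the sizes and values below are arbitrary. The sketch only shows that a larger sigma lowers the condition number when Q and A^T A are rank deficient.

```python
import numpy as np

def reduced_kkt_cond(Q, A, sigma, rho):
    """Condition number of Q + sigma*I + rho*A^T A (symmetric positive definite)."""
    K = Q + sigma * np.eye(Q.shape[0]) + rho * A.T @ A
    eig = np.linalg.eigvalsh(K)      # eigenvalues in ascending order
    return eig[-1] / eig[0]

rng = np.random.default_rng(0)
n, m = 8, 5                          # fewer constraints than variables: A^T A is singular
A = rng.normal(size=(m, n))
Q = np.zeros((n, n))                 # linear objective, so Q adds no curvature

conds = [reduced_kkt_cond(Q, A, sigma, rho=0.1) for sigma in (1e-8, 1e-6, 1e-3)]
```

In the problem discussed here the objective is a squared distance, so Q is already a multiple of the identity and the conditioning is more likely dominated by A.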

Do you have a way to ensure your computation is numerically stable on instances this big? Or, ideally, a toy example on which the failure happens?

One quick way to sort all this out would be to use the official OSQP implementation. You can retrieve the true matrix A from your code by taking the Jacobian of the matvec_A operation. Even if it works in this setting, that does not guarantee that computing A on the fly and pre-computing it give the same results (numerical errors are hell).
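For a linear operator, the Jacobian is the same at every point and equals the matrix whose columns are the images of the basis vectors. The NumPy sketch below builds it that way; in JAX, `jax.jacobian(matvec_A)` evaluated at any input should give the same matrix. The first-difference operator is a hypothetical stand-in for the real `matvec_A`.

```python
import numpy as np

def materialize(matvec, n):
    """Dense matrix of a linear map R^n -> R^m, built column by column."""
    eye = np.eye(n)
    return np.stack([matvec(eye[:, j]) for j in range(n)], axis=1)

# Hypothetical matrix-free operator: first differences of the input.
def matvec_A(x):
    return x[1:] - x[:-1]

A = materialize(matvec_A, 4)

# Sanity check: the dense matrix and the operator agree on a test vector.
x_test = np.array([1.0, 4.0, 9.0, 16.0])
agrees = np.allclose(A @ x_test, matvec_A(x_test))
```

The dense A can then be passed to a reference solver, and its rank and condition number can be inspected before blaming the optimizer.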

Algue-Rythme commented 1 year ago

Your benchmark has 900 primal variables and more than 10,000 constraints, without an obvious sparsity pattern to exploit. Our unit tests do not cover instances this big. We should have conducted proper large-scale benchmarks from the original papers [1, 2], but we did not.

[1] Stellato, B., Banjac, G., Goulart, P., Bemporad, A. and Boyd, S., 2020. OSQP: An operator splitting solver for quadratic programs. Mathematical Programming Computation, 12(4), pp.637-672.
[2] Schubiger, M., Banjac, G. and Lygeros, J., 2020. GPU acceleration of ADMM for large-scale quadratic programming. Journal of Parallel and Distributed Computing, 144, pp.55-67.

@mblondel do you have any thoughts on the question? Is it common to see BoxOSQP fail on instances this big? The current implementation explicitly prevents manual tuning of the tolerance (and maxiter) of the internal linear solvers, with the following code:

if self.eq_qp_solve.lower() == 'cg':
  self._eq_qp_solve_impl = OSQPIndirectSolver(self.matvec_Q, self.matvec_A,
                                              tol=1e-7 * self.tol)
elif self.eq_qp_solve.lower() == 'cg+jacobi':
  self._eq_qp_solve_impl = OSQPIndirectSolver(self.matvec_Q, self.matvec_A,
                                              tol=1e-7 * self.tol,
                                              jacobi_preconditioner=True)
elif self.eq_qp_solve.lower() == 'lu':
  self._eq_qp_solve_impl = OSQPLUSolver()
else:
  raise ValueError(f"Unknown solver '{self.eq_qp_solve}'.")

I wonder whether this is sufficient to solve the OP's issue.
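For readers unfamiliar with what the inner tolerance and iteration cap control, here is a textbook conjugate gradient loop in NumPy with both exposed as parameters. It is a generic sketch, not jaxopt's `OSQPIndirectSolver`; the diagonal test matrix is an arbitrary example.

```python
import numpy as np

def conjugate_gradient(matvec, b, tol=1e-8, maxiter=100):
    """Solve matvec(x) = b for a symmetric positive definite operator.

    Stops when ||residual|| <= tol * ||b|| or after maxiter iterations.
    Returns the iterate and the number of iterations used.
    """
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    b_norm = np.linalg.norm(b)
    for k in range(maxiter):
        if np.sqrt(rs) <= tol * b_norm:
            return x, k
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x, maxiter

K = np.diag(np.linspace(1.0, 1e3, 50))   # condition number 1e3
b = np.ones(50)
x_loose, it_loose = conjugate_gradient(lambda v: K @ v, b, tol=1e-2, maxiter=500)
x_tight, it_tight = conjugate_gradient(lambda v: K @ v, b, tol=1e-10, maxiter=500)
```

A tighter inner tolerance costs more iterations per outer ADMM step, while a looser one returns an inexact solve; with a fixed inner budget, an ill-conditioned system may never reach the requested tolerance at all.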

deasmhumhna commented 1 year ago

Thanks for your response.

I refactored the code for chebval and chebder to be more jit-friendly; the original functions were direct copies of the current NumPy implementations with minor changes to work with JAX. I also added a certificate of sorts: a test that jvp/vjp at two different random inputs produce the same result given the same tangents.

There might be a numerical issue, though chebval is based on the Clenshaw algorithm, which is backward stable. I don't know about the stability of chebder. Using jax.jacobian, we can see that the Jacobian is upper triangular with integer values. jacrev seems to recover these exactly, while jacfwd starts showing errors once the input size gets large enough.
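The structure of the chebder Jacobian can be inspected without automatic differentiation by applying NumPy's `chebder` to each unit coefficient vector. The helper name `chebder_matrix` is hypothetical; the sketch only illustrates the upper-triangular, integer-valued pattern mentioned above.

```python
import numpy as np
from numpy.polynomial import chebyshev as C

def chebder_matrix(n):
    """Matrix D with chebder(c) == D @ c for coefficient vectors of length n."""
    eye = np.eye(n)
    return np.stack([C.chebder(eye[:, j]) for j in range(n)], axis=1)

D = chebder_matrix(5)                # shape (4, 5)

# Entry D[k, j] is the coefficient of T_k in the derivative of T_j, so it
# can only be nonzero for k < j.
is_upper = not np.tril(D).any()
is_integer = np.array_equal(D, np.round(D))
```

Since the matrix depends only on the number of coefficients, one option is to build it once in float64 and reuse it as a constant, instead of differentiating through the recursion on every call.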

I could also restate the problem using the barycentric formulation of Chebyshev interpolation, which uses the values of $f$ at the Chebyshev nodes directly and is also stable. This formulation is also linear in the function values, and I might be able to replace the grid constraint with O(params.size) constraints based on the first/second derivative (still linear in $f$). I'll have to look into the mathematics tomorrow.
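For reference, a one-dimensional version of the barycentric formulation can be written as an explicit matrix acting on the node values, which makes the linearity in $f$ visible. The sketch uses the standard barycentric weights for Chebyshev points of the second kind; the function names are hypothetical and the two-dimensional problem in this issue would need a tensor-product version.

```python
import numpy as np

def cheb_points(n):
    """Chebyshev points of the second kind, x_j = cos(j*pi/n), j = 0..n."""
    return np.cos(np.pi * np.arange(n + 1) / n)

def bary_matrix(x_eval, n):
    """Matrix B such that the interpolant satisfies p(x_eval) = B @ f(nodes)."""
    nodes = cheb_points(n)
    w = (-1.0) ** np.arange(n + 1)   # barycentric weights for these nodes
    w[0] *= 0.5
    w[-1] *= 0.5
    diff = x_eval[:, None] - nodes[None, :]
    exact = diff == 0.0              # evaluation points that coincide with a node
    diff[exact] = 1.0                # placeholder to avoid division by zero
    terms = w / diff
    B = terms / terms.sum(axis=1, keepdims=True)
    hit = exact.any(axis=1)
    B[hit] = exact[hit].astype(float)   # at a node, return that node's value
    return B

g = lambda x: 3.0 * x**3 - x + 2.0   # degree 3, reproduced exactly for n >= 3
n = 5
x_eval = np.linspace(-1.0, 1.0, 11)
B = bary_matrix(x_eval, n)
p_eval = B @ g(cheb_points(n))
```

Any constraint that is linear in the interpolant's values then becomes a linear constraint on the node values through `B`.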

It could also be that ADMM is just slow! When implementing a similar QP solver in Go, I used the linesearch-based acceleration described here. I found a prototype I made in NumPy for debugging and am converting it to JAX. BoxOSQP seems to reduce the infeasibility, so it's probably just reaching its budget.

Still, it does seem to take an optimal solution, a feasible starting point, and return a much worse solution. This is not ideal. Using init_params = qp.init_params(params, ...) helps so I feel this is just an issue with the default (zero) initialization. It would be nice to be able to give an initialization in a simple form rather than a KKTState.

Algue-Rythme commented 1 year ago

> Using init_params = qp.init_params(params, ...) helps so I feel this is just an issue with the default (zero) initialization. It would be nice to be able to give an initialization in a simple form rather than a KKTState.

That is good to know! Indeed, when you only have access to primal variables, the preferred workflow is:

kkt_sol = qp.init_params(primal_variables, ...)
init_state = qp.init_state(kkt_sol, ...)
qp.run(kkt_sol, ...)

If you know the values of the dual variables, you can set them yourself in kkt_sol. If you don't, I wonder if a heuristic can do the trick. Maybe we can "detect" active constraints (equality constraints at least, and inequality constraints within some tolerance), and then apply solution polishing (see also page 11 of the paper). This is a major missing ingredient compared to the original OSQP implementation. It was originally used to refine the accuracy at a faster rate than ADMM once the active constraints have been found. I realize it can be used to initialize the dual variables too.
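One possible form of such a heuristic, sketched in NumPy under stated assumptions: treat rows of A whose value is within a tolerance of a bound as active, then pick the duals that best satisfy the stationarity condition Q x + c + A^T y = 0 in the least-squares sense. This covers only the dual-estimation idea, not OSQP's polishing step, and it is not part of jaxopt. The function name, the sign convention and the tolerance are assumptions.

```python
import numpy as np

def estimate_duals(Q, c, A, lower, upper, x, active_tol=1e-6):
    """Least-squares dual estimate from stationarity Q x + c + A^T y = 0.

    Rows of A whose value A @ x is within active_tol of a bound are treated
    as active; the duals of all other rows are set to zero.
    """
    z = A @ x
    active = (np.abs(z - lower) <= active_tol) | (np.abs(z - upper) <= active_tol)
    y = np.zeros(A.shape[0])
    if active.any():
        grad = Q @ x + c
        y_active, *_ = np.linalg.lstsq(A[active].T, -grad, rcond=None)
        y[active] = y_active
    return y, active

# Toy projection: minimize 0.5 * ||x - x0||^2  s.t.  x0 + x1 = 1, -10 <= x0 <= 10.
x0 = np.array([1.0, 1.0])
Q, c = np.eye(2), -x0
A = np.array([[1.0, 1.0], [1.0, 0.0]])
lower = np.array([1.0, -10.0])
upper = np.array([1.0, 10.0])
x_opt = np.array([0.5, 0.5])         # known solution of the toy problem

y, active = estimate_duals(Q, c, A, lower, upper, x_opt)
```

The estimate is only as good as the primal point and the active-set guess; at a feasible but non-optimal point it gives a starting value for the duals, not a certificate.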