google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

OSQP crashing on unexpected params #570

Open Illviljan opened 5 months ago

Illviljan commented 5 months ago

I'm trying to port some QP code to jaxopt, but I'm struggling to understand the cryptic errors, which appear only in the jaxopt implementation. I've tried other packages, and these params work with their implementations.

Here's a minimal example:

import numpy as np
import jax.numpy as jnp
from jaxopt import OSQP

from qpsolvers import solve_qp

def to_numpy(*args):
    return tuple(np.asarray(v) for v in args)

P_ = jnp.array([[576.0]])
q_ = jnp.array([-216.0])
G_ = jnp.array([[-1.0]])
h_ = jnp.array([2.0])
A_ = jnp.array([[]], dtype=float).T
b_ = jnp.array([], dtype=float)

x = solve_qp(*to_numpy(P_, q_, G_, h_, A_, b_), solver="osqp")  # works

qp = OSQP()
deltas = qp.run(
    params_obj=(P_, q_),
    params_eq=(A_, b_),
    params_ineq=(G_, h_),
).params.primal  # Crashes with cryptic error.
# TypeError: dot_general requires contracting dimensions to have the same shape, got (1,) and (2,).

# jax-0.4.23 jaxlib-0.4.23 jaxopt-0.8.3 ml-dtypes-0.3.2 opt-einsum-3.3.0

Algue-Rythme commented 5 months ago

Hi Illviljan

Sorry for the cryptic error message. The error comes from the fact that the matrix A_ = jnp.array([[]], dtype=float).T is not a valid linear operator. If you don't need equality constraints, just pass None to params_eq:

P_ = jnp.array([[576.0]])
q_ = jnp.array([-216.0])
G_ = jnp.array([[-1.0]])
h_ = jnp.array([2.0])
# A_ = jnp.array([[]], dtype=float).T
# b_ = jnp.array([], dtype=float)

qp = OSQP()
deltas = qp.run(
    params_obj=(P_, q_),
    params_eq=None,   # CHANGE HERE.
    params_ineq=(G_, h_),
).params.primal

Similarly, if you don't need inequality constraints, just pass None to params_ineq. Thank you for your message; I just realized that I forgot to document this functionality.
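When constraint matrices are built programmatically, they can end up with zero rows, as A_ does in the original report. A small guard can map such empty blocks to None before calling OSQP. This is a minimal sketch: `constraint_params` is a hypothetical helper, not part of jaxopt, and it only inspects `.shape`, so it works the same on NumPy and jax.numpy arrays.

```python
import numpy as np


def constraint_params(lhs, rhs):
    """Return (lhs, rhs), or None when the constraint block has zero rows.

    Hypothetical helper: it assumes, as described in this thread, that
    jaxopt's OSQP wants None rather than an empty 0 x n matrix when a
    constraint type is absent.
    """
    if lhs is None or lhs.shape[0] == 0:
        return None
    return (lhs, rhs)


# An empty equality block like A_ in the report has shape (0, 1).
A_empty = np.array([[]], dtype=float).T
b_empty = np.array([], dtype=float)
print(A_empty.shape)                         # (0, 1)
print(constraint_params(A_empty, b_empty))   # None
```

It would be used as params_eq=constraint_params(A_, b_) and params_ineq=constraint_params(G_, h_). If both come back None, the thread shows that OSQP still rejects the call, so that case needs a different solver.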

Illviljan commented 5 months ago

Thank you, that's quite a simple fix. Maybe I should just keep all constraints active in my larger project.

I was surprised, because jaxopt seems to be the odd one out: A_ is valid in other QP packages.

Using None is fine, I guess; the annoying part is that jaxopt doesn't allow both constraints to be None. Other packages allow that, and I think it fits better with how I build a new solution: start simple without any constraints, make sure it works, then slowly add constraints until the solution makes sense.

import numpy as np
import jax.numpy as jnp
from jaxopt import OSQP

from qpsolvers import solve_qp

def to_numpy(*args):
    return tuple(np.asarray(v) for v in args)

P_ = jnp.array([[576.0]])
q_ = jnp.array([-216.0])
G_ = jnp.array([[]], dtype=float).T
h_ = jnp.array([], dtype=float)
A_ = jnp.array([[]], dtype=float).T
b_ = jnp.array([], dtype=float)

x = solve_qp(*to_numpy(P_, q_, G_, h_, A_, b_), solver="osqp")  # works
print(x)

qp = OSQP()
x = qp.run(
    params_obj=(P_, q_),
    params_eq=None,
    params_ineq=None,
).params.primal  # Unnecessarily strict crash

Algue-Rythme commented 5 months ago

That's true, but using OSQP when you don't have any constraints is overkill. In that case the OSQP algorithm degenerates into an inefficient way of solving a linear system.
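To see why, write the standard first-order optimality condition of the unconstrained QP, using the same objective convention as qpsolvers' solve_qp (1/2 x^T P x + q^T x) and assuming P is positive definite. With no constraints, minimizing the objective is exactly solving a linear system. Plugging in the thread's data (P_ = [[576]], q_ = [-216]) gives the closed-form answer:

$$
\min_x \; \tfrac{1}{2} x^\top P x + q^\top x
\quad\Longrightarrow\quad
P x + q = 0
\quad\Longleftrightarrow\quad
P x = -q,
\qquad
x^\star = -\frac{q}{P} = \frac{216}{576} = 0.375 .
$$

The inequality from the first example, G x <= h with G = -1 and h = 2, only requires x >= -2, so it is inactive at x* = 0.375, and the constrained and unconstrained solutions coincide here.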

As argued in the documentation, you should switch to conjugate gradient in this case.
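A minimal sketch of that fallback for the unconstrained case follows. It is written in plain NumPy so that it runs without jaxopt. The functions `conjugate_gradient` and `solve_unconstrained_qp` are illustrative names, not part of jaxopt's API. The sketch assumes P is symmetric positive definite and solves P x = -q, the stationarity condition of 1/2 x^T P x + q^T x. Inside JAX, the built-in `jax.scipy.sparse.linalg.cg` implements the same method and would be the natural choice.

```python
import numpy as np


def conjugate_gradient(P, rhs, tol=1e-10, max_iter=None):
    """Solve P x = rhs for a symmetric positive-definite matrix P."""
    n = rhs.shape[0]
    # In exact arithmetic CG converges in at most n steps; allow slack
    # for floating-point round-off.
    max_iter = 10 * n if max_iter is None else max_iter
    x = np.zeros_like(rhs, dtype=float)
    r = rhs - P @ x          # residual
    p = r.copy()             # search direction
    rs_old = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs_old) < tol:
            break
        Pp = P @ p
        alpha = rs_old / (p @ Pp)   # exact line search along p
        x = x + alpha * p
        r = r - alpha * Pp
        rs_new = r @ r
        # New direction, P-conjugate to the previous ones.
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


def solve_unconstrained_qp(P, q):
    """Minimize 0.5 x^T P x + q^T x by solving P x = -q."""
    return conjugate_gradient(P, -q)


x = solve_unconstrained_qp(np.array([[576.0]]), np.array([-216.0]))
print(x)  # approximately [0.375], i.e. 216 / 576
```

For a dense 1 x 1 problem a direct solve would do just as well. CG pays off when P is large or only available as a matrix-vector product, which is why it is the usual recommendation here.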