google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Constrained optimization using `projection_polyhedron` claims the polyhedron is empty when it is not. #469

Status: Open. Opened by ghost 1 year ago.

ghost commented 1 year ago

A minimal example:

import jaxopt
import jax.numpy as jnp

def objective(x):
    # Maximize 2*x[0] + x[1] by minimizing its negation.
    return -jnp.dot(jnp.array([2, 1], jnp.float32), x)

def main():
    # Equality constraint A @ x == b reads 0 == 0: true for every x.
    a = jnp.array([[0, 0]], jnp.float32)
    b = jnp.array([0], jnp.float32)
    # Inequality constraint G @ x <= h reads 0 <= 0: true for every x.
    g = jnp.array([[0, 0], [0, 0]], jnp.float32)
    h = jnp.array([0, 0], jnp.float32)

    pg = jaxopt.ProjectedGradient(objective, jaxopt.projection.projection_polyhedron, jit=False)
    result = pg.run(jnp.array([0, 0], jnp.float32), hyperparams_proj=(a, b, g, h))
    return result

main()

Both constraints are tautological (0 · x = 0 and 0 · x ≤ 0 hold for every x, so the polyhedron is all of R²), yet running main() still raises an error:

ValueError: The polyhedron is empty.

Other variations of the equality and inequality constraints that clearly define non-empty polyhedra raise the same error. Any guidance or advice would be appreciated. Thank you.
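As a sanity check that does not involve jaxopt at all, the sketch below confirms that the constraints from the report are satisfied by arbitrary points. `is_feasible` is a hypothetical helper written for this illustration (not a jaxopt function), and the tolerance `tol` is an arbitrary choice.

```python
import jax.numpy as jnp


def is_feasible(x, A, b, G, h, tol=1e-6):
    """Return True if x satisfies A @ x == b and G @ x <= h up to `tol`.

    Hypothetical helper for checking polyhedron hyperparameters by hand;
    it is not part of the jaxopt API.
    """
    eq_ok = jnp.all(jnp.abs(A @ x - b) <= tol)
    ineq_ok = jnp.all(G @ x <= h + tol)
    return bool(eq_ok & ineq_ok)


# Constraints from the report: 0 @ x == 0 and 0 @ x <= 0 hold for every x.
A = jnp.zeros((1, 2), jnp.float32)
b = jnp.zeros((1,), jnp.float32)
G = jnp.zeros((2, 2), jnp.float32)
h = jnp.zeros((2,), jnp.float32)

for point in ([0.0, 0.0], [3.0, -7.0]):
    print(point, is_feasible(jnp.array(point, jnp.float32), A, b, G, h))
```

Both points print `True`, so the feasible set is indeed non-empty.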

vroulet commented 1 year ago

Thank you @noam-delfina for the bug report! For the record, here is a non-trivial example that fails:

import jaxopt
import jax.numpy as jnp

A = jnp.array([[1., 0.]])
b = jnp.array([-1.])
G = jnp.array([[1., 0.],
               [0., 1.]])
h = jnp.array([0., 0.])
x = jnp.array([0., 0.])
jaxopt.projection.projection_polyhedron(x, hyperparams=(A, b, G, h), check_feasible=True)
# Raises ValueError: The polyhedron is empty.

Here the feasible set is {x : x[0] = -1, x[1] ≤ 0}. It is clearly non-empty, has a non-empty relative interior, and contains a whole half-line, {x : x[0] = -1, x[1] < 0}, on which both inequalities hold strictly. Weirdly enough, changing x (which should have no impact on the properties of the polyhedron) circumvents the issue:
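Because this polyhedron is the product {-1} × (-∞, 0], its Euclidean projection has a closed form: set x[0] to -1 and clip x[1] at 0. The sketch below uses that to compute reference outputs to compare against, assuming `projection_polyhedron` returns the Euclidean projection as its name suggests. It also checks that (-1, -1) is strictly feasible. `project_halfline` is a hypothetical helper that is valid only for this particular polyhedron.

```python
import jax.numpy as jnp

# Polyhedron from the example: {x : A @ x == b, G @ x <= h}.
# With these values it reduces to {x : x[0] == -1, x[1] <= 0}.
A = jnp.array([[1.0, 0.0]])
b = jnp.array([-1.0])
G = jnp.eye(2)
h = jnp.zeros(2)


def project_halfline(x):
    """Euclidean projection onto {x : x[0] == -1, x[1] <= 0}.

    The set is a Cartesian product, so the projection acts coordinate-wise:
    x[0] is set to -1 and x[1] is clipped to at most 0.
    Hypothetical helper, valid only for this specific polyhedron.
    """
    return x.at[0].set(-1.0).at[1].min(0.0)


# (-1, -1) satisfies the equality and both inequalities strictly.
strict_point = jnp.array([-1.0, -1.0])
print(bool(jnp.allclose(A @ strict_point, b)), bool(jnp.all(G @ strict_point < h)))  # True True

# Reference values for the two starting points used in this thread.
for x in (jnp.array([0.0, 0.0]), jnp.array([-1.0, 0.0])):
    print(x, "->", project_halfline(x))
```

Both starting points should project to (-1, 0), so a correct projection routine has a well-defined answer for x = (0, 0).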

import jaxopt
import jax.numpy as jnp

A = jnp.array([[1., 0.]])
b = jnp.array([-1.])
G = jnp.array([[1., 0.],
               [0., 1.]])
h = jnp.array([0., 0.])
x = jnp.array([-1., 0.])
jaxopt.projection.projection_polyhedron(x, hyperparams=(A, b, G, h), check_feasible=True)
# Does not raise any error.
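The two calls differ in one way that may or may not be related to the bug. The failing x = [0, 0] violates the equality constraint (A @ x = 0, while b = -1), but the working x = [-1, 0] already lies in the polyhedron. The hypothetical diagnostic helper `constraint_violation` below (not part of jaxopt) makes this explicit by reporting the equality residual and the inequality violation for each point.

```python
import jax.numpy as jnp

# Same polyhedron as in the two snippets above.
A = jnp.array([[1.0, 0.0]])
b = jnp.array([-1.0])
G = jnp.eye(2)
h = jnp.zeros(2)


def constraint_violation(x):
    """Return (|A @ x - b|, max(G @ x - h, 0)) for a candidate point x.

    Both outputs are all zeros exactly when x lies in the polyhedron.
    Hypothetical diagnostic helper, not part of jaxopt.
    """
    eq_residual = jnp.abs(A @ x - b)
    ineq_violation = jnp.maximum(G @ x - h, 0.0)
    return eq_residual, ineq_violation


for x in (jnp.array([0.0, 0.0]), jnp.array([-1.0, 0.0])):
    eq_res, ineq_viol = constraint_violation(x)
    print(x, "equality residual:", eq_res, "inequality violation:", ineq_viol)
```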

@Algue-Rythme, since you worked on OSQP, do you have an idea where this bug comes from? Otherwise, I can look into it later. Thank you again, @noam-delfina!
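For reference, the Euclidean projection onto the polyhedron is, by definition, the solution of the quadratic program below. The point x appears only in the objective, while the feasible set P depends solely on (A, b, G, h). Whether the problem is feasible, and hence whether the polyhedron is empty, therefore cannot legitimately depend on x.

$$
\operatorname{proj}_P(x) = \arg\min_{y \in \mathbb{R}^n} \tfrac{1}{2}\lVert y - x \rVert_2^2
\quad \text{s.t.} \quad A y = b,\; G y \le h,
\qquad P = \{\, y \in \mathbb{R}^n : A y = b,\; G y \le h \,\}.
$$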