mechmotum / cyipopt

Cython interface for the interior point optimizer IPOPT
Eclipse Public License 2.0

New user: equil and inequil returning vectors #150

Closed a-jp closed 1 year ago

a-jp commented 2 years ago

Hi,

I'm new to the code, though I use IPOPT via Pyomo a lot. Following the issues, I got the JAX example working by using hessvp instead of hess. In my example, the objective returns a scalar, but the equality and inequality constraints return NumPy vectors rather than scalars.

Is there any chance of a working example showing how to use cyipopt, like the JAX example, for that case?

I get broadcast errors when I simply swap in my own functions as a first attempt. For example, I have no idea how to adapt the v[0] part of the hessvp to constraints that return vectors.

Any help is much appreciated, as I'm keen to try the library.

Many thanks, Andy
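For context (an editorial note, not part of the original post): the matrix IPOPT asks for is the Hessian of the Lagrangian, where σ is the objective factor and v holds one multiplier per constraint component:

$$
\nabla^2_{xx} \mathcal{L}(x, v) = \sigma \, \nabla^2 f(x) + \sum_{i=1}^{m} v_i \, \nabla^2 g_i(x)
$$

When a constraint function returns a scalar, m = 1 and the constraint term is just v_1 times one n × n Hessian, which is what `v[0]` picks out of a 0-indexed array. When it returns a vector of length m, the constraint term is a weighted sum of m Hessians. Multiplying the stacked (m, n, n) Hessians by the single scalar `v[0]` leaves a 3-D array that cannot be added to the n × n objective Hessian, which is likely the source of the broadcast error described above.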

brocksam commented 2 years ago

Can you please provide a short example script that you're working with? Then we can hopefully help debug and point out where you might be going wrong (or identify a bug!)

jhelgert commented 2 years ago

@a-jp Here's an example of how you can use JAX to build the constraint part of the Lagrangian's Hessian when your constraint functions don't return scalars:

```python
import jax
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

from jax import jit, jacrev, jacfwd
import jax.numpy as jnp

def con(x):
    # Vector-valued constraint: returns m = 2 values
    return jnp.array((jnp.prod(x) - 25, jnp.prod(x) - 20))

# jit the function
con_jit = jit(con)

# con_hess returns the stacked constraint Hessians, shape (m, n, n)
con_hess = jacrev(jacfwd(con_jit))

# constraint part of the Lagrangian function's Hessian, i.e.
# v_1 * H_g1(x) + ... + v_m * H_gm(x),
# where H_gi is the Hessian of the constraint function g_i
def con_hess_vp(x, v):
    con_hess_eval = con_hess(x)
    num_cons = con_hess_eval.shape[0]
    # weight each (n, n) Hessian by its multiplier and sum over constraints
    return jnp.sum(con_hess_eval * v[:num_cons, None, None], axis=0)

con_hessvp_jit = jit(con_hess_vp)
```
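A possible variant (an editorial sketch, not from the original answer): because the Hessian is linear, the weighted sum v_1 H_g1(x) + ... + v_m H_gm(x) equals the Hessian of the scalar function x ↦ v · g(x). Differentiating that scalar directly avoids materializing the (m, n, n) stack. The function names `con_hess_vp_stacked` and `con_hess_vp_scalarized`, and the test point below, are hypothetical. As far as I understand the JAX example in the cyipopt documentation, either callable would be passed as the `'hess'` entry of a constraint dict, which cyipopt calls as `hess(x, v)` with that constraint's slice of multipliers.

```python
import jax
import jax.numpy as jnp

# Enable float64 before any arrays are created
jax.config.update("jax_enable_x64", True)


def con(x):
    # Same vector-valued constraint as above: m = 2 outputs
    return jnp.array((jnp.prod(x) - 25, jnp.prod(x) - 20))


@jax.jit
def con_hess_vp_stacked(x, v):
    # Approach from the answer above: build all m Hessians, shape (m, n, n),
    # then contract the leading axis against the multipliers v, shape (m,)
    H = jax.jacrev(jax.jacfwd(con))(x)
    return jnp.tensordot(v, H, axes=1)


@jax.jit
def con_hess_vp_scalarized(x, v):
    # Hypothetical alternative: sum_i v_i * H_gi(x) is the Hessian of the
    # scalar function y -> v . g(y), so only one (n, n) Hessian is formed
    return jax.hessian(lambda y: jnp.dot(v, con(y)))(x)


# Both approaches agree at an arbitrary test point
x0 = jnp.array([1.0, 5.0, 5.0, 1.0])
v = jnp.array([0.5, -2.0])
print(bool(jnp.allclose(con_hess_vp_stacked(x0, v), con_hess_vp_scalarized(x0, v))))  # True
```

The scalarized form keeps peak memory at O(n²) instead of O(m·n²), which matters once the number of constraint components grows; the result is still a dense n × n matrix.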
a-jp commented 2 years ago

Hi. That's extremely helpful, and it works! Thanks so much.

patrickocal commented 2 years ago

Is there any way to get sparse Hessians? I tried JAX's experimental sparsify function, but to no effect. Otherwise, this approach is of limited use even for medium-sized problems.
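One common way around dense Hessians, sketched here under the assumption that the sparsity pattern is known in advance (this is not from the thread): if the Hessian of the Lagrangian is diagonal, as in the hypothetical separable problem below, multiplying it by an all-ones vector returns exactly its diagonal. So a single forward-over-reverse Hessian-vector product yields every nonzero without forming the n × n matrix. With cyipopt's lower-level `Problem` interface, the pattern and the values would then be returned from the `hessianstructure()` and `hessian(x, lagrange, obj_factor)` callbacks; the diagonal lies in the lower triangle that IPOPT expects. For general patterns the same idea extends by grouping structurally independent columns into shared seed vectors (graph coloring), which this sketch does not implement. `lagrangian`, `hvp`, `diagonal_hessian_structure`, and `diagonal_hessian_values` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


def lagrangian(x, obj_factor, v):
    # Hypothetical separable problem: objective sum(x_i^4) and one constraint
    # sum(exp(x_i)) - 10. Each term depends on a single x_i, so the Hessian
    # of the Lagrangian is diagonal.
    obj = jnp.sum(x**4)
    g = jnp.array([jnp.sum(jnp.exp(x)) - 10.0])
    return obj_factor * obj + jnp.dot(v, g)


def hvp(f, x, u):
    # Forward-over-reverse Hessian-vector product: directional derivative
    # of grad f at x along u, i.e. H(x) @ u
    return jax.jvp(jax.grad(f), (x,), (u,))[1]


def diagonal_hessian_structure(n):
    # Coordinate (row, col) indices of the nonzeros: only the diagonal
    idx = np.arange(n)
    return idx, idx


@jax.jit
def diagonal_hessian_values(x, obj_factor, v):
    # For a diagonal H, H @ ones == diag(H), so one HVP with an all-ones
    # seed recovers every nonzero entry
    f = lambda y: lagrangian(y, obj_factor, v)
    return hvp(f, x, jnp.ones_like(x))


# Example: n = 5 variables, sigma = 1.0, one multiplier v = 0.3
x = jnp.linspace(0.1, 1.0, 5)
rows, cols = diagonal_hessian_structure(x.size)
vals = diagonal_hessian_values(x, 1.0, jnp.array([0.3]))
```

Only n values are computed and stored here instead of n², at the cost of having to know (or detect) the sparsity pattern separately.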

moorepants commented 1 year ago

The original issue seems resolved.