a-jp closed this issue 1 year ago
Can you please provide a short example script that you're working with? Then we can hopefully help debug and point out where you might be going wrong (or identify a bug!)
@a-jp Here's an example of how you can use JAX to build the constraint part of the Lagrangian's Hessian in case your constraint functions don't return scalars:
from jax.config import config
config.update("jax_enable_x64", True)
config.update('jax_platform_name', 'cpu')

from jax import jit, grad, jacrev, jacfwd
import jax.numpy as np

def con(x):
    return np.array((np.prod(x) - 25, np.prod(x) - 20))

# jit the function
con_jit = jit(con)

# con_hess returns the stack of constraint hessians
con_hess = jacrev(jacfwd(con_jit))

# constraint part of the Lagrangian function's hessian, i.e.
# v_1 * H_g1(x) + ... + v_m * H_gm(x),
# where H_gi is the hessian of the constraint function gi
def con_hess_vp(x, v):
    con_hess_eval = con_hess(x)
    num_cons = con_hess_eval.shape[0]
    return np.sum(con_hess_eval * v[:num_cons, None, None], axis=0)

con_hessvp_jit = jit(con_hess_vp)
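To illustrate what the broadcasting expression in con_hess_vp computes, here is a minimal sketch in plain NumPy (so it runs without JAX; the Hessian stack H and multipliers v are made-up values for the example). It checks that the one-liner equals the explicit weighted sum v_1 * H_g1 + v_2 * H_g2:

```python
import numpy as onp  # plain NumPy stand-in; jax.numpy follows the same broadcasting rules

# Hypothetical example: m = 2 constraints, n = 3 variables.
# H is the stack of per-constraint Hessians, shape (m, n, n).
H = onp.arange(2 * 3 * 3, dtype=float).reshape(2, 3, 3)
v = onp.array([0.5, -2.0])  # Lagrange multipliers, one per constraint

# Weighted sum via broadcasting: v[:, None, None] scales each (n, n)
# Hessian by its multiplier, then the sum collapses the constraint axis.
hessvp = onp.sum(H * v[:, None, None], axis=0)

# Equivalent explicit sum, term by term, for comparison.
expected = v[0] * H[0] + v[1] * H[1]
assert hessvp.shape == (3, 3)
assert onp.allclose(hessvp, expected)
```

Since jax.numpy mirrors NumPy's broadcasting, con_hess_vp above returns exactly this multiplier-weighted sum of the per-constraint Hessians, which is the form Ipopt expects for the constraint part of the Lagrangian Hessian.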
Hi. That's extremely helpful and works! Thanks so much.
Is there any way to get sparse Hessians? I tried JAX's experimental sparsify function, but to no avail. Otherwise, this is of limited use for even medium-sized problems.
The original issue seems resolved.
Hi,
I'm new to the code; I use Ipopt via Pyomo a lot. I followed the issues to get the JAX example working by using a Hessian-vector product (hessvp) instead of the full Hessian. In my example the objective returns a scalar, but my equality and inequality constraints return NumPy vectors, not scalars.
Is there any chance of a working example showing how to use cyipopt, like the JAX example, for that case?
I get broadcast errors when I simply swap in my own functions as a first attempt. For example, I've no idea how to adapt the v[0] part of the hessvp to the case of vector-valued constraint returns.
Any help much appreciated as I'm keen to try the library.
Many thanks, Andy