KKT conditions when the primal solution is a pytree #20

3 years ago

FerranAlet commented 3 years ago

Hi, congrats on the great tool! Inspired by the QuadraticProgramming example, I wrote code that differentiates through KKT conditions. It works whenever the primal solution variable is a jnp array, but not when it's a generic pytree. In that case I get the following error:

TypeError: Tree structure of cotangent input PyTreeDef(([(*, *), (), (*, *)], *, None)), does not match structure of primal output PyTreeDef(([(*, *), (), (*, *), (*, *), (), (*, *)], *, None))

where I'm pretty sure [(*, *), (), (*, *)] represents the primal solution and PyTreeDef(([(*, *), (), (*, *)], *, None)) could represent the optimality function.

I was able to make it work by storing the primal solution in a single jnp array and reshaping it into the appropriate pytree whenever needed, but that's neither clean nor efficient. I was wondering whether there's a bug in the current codebase (I only found tests for single jnp arrays) or whether I'm misusing the interface (I'm not a JAX expert).
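The flatten-and-reshape workaround described here can be written compactly with `jax.flatten_util.ravel_pytree`, which returns the concatenated 1-D array together with an `unravel` function that restores the original structure. A minimal sketch; `primal_tree` is a hypothetical value, not the code from this thread:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical pytree-shaped primal variable: a list of arrays of different shapes.
primal_tree = [jnp.array([1.0, 2.0]), jnp.array([[3.0, 4.0], [5.0, 6.0]])]

# ravel_pytree concatenates all leaves into one 1-D array and returns a function
# that rebuilds a pytree with the original structure from such an array.
flat, unravel = ravel_pytree(primal_tree)
print(flat.shape)  # (6,)

# A solver that only accepts flat arrays can work on `flat`; map back when needed.
restored = unravel(flat)
```

This still copies data on every flatten/unflatten, so it addresses the cleanliness concern more than the efficiency one.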

To make it easier to reproduce, I modified the file so that the model returns a list containing one array instead of an array for the primal variable (leaving both dual variables unchanged). Then I modified obj_fun, eq_fun and ineq_fun to use primal_var[0] instead of primal_var. If I understand correctly, this should still work. However, it doesn't: this test line raises an assertion error because an array that should be all zeros is instead:

([DeviceArray([ 0.43999994, -1.3199999 ], dtype=float32), DeviceArray([-0.44000003, 1.32 ], dtype=float32)], DeviceArray([2.9802322e-08], dtype=float32), None)

Looking at the numbers of the problem, I believe [0.44, -1.32] is the gradient of obj_fun w.r.t. the primal and [-0.44, 1.32] is the gradient of the equality constraint w.r.t. the primal times the dual. They should have been added together to give [0, 0], as expected. I suspect this is fundamentally the same problem I hit in my own research code, since there one of the values also had the shape of the primal variable twice instead of once.
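This hypothesis can be reproduced in isolation: for list-valued primal variables, Python's `+` concatenates the two terms instead of adding them, which yields a two-element output like the one printed by the failing test and a doubled tree structure like the one in the `TypeError`. A minimal sketch using the printed numbers (rounded); it illustrates the suspected mechanism and is not JAXopt's code:

```python
import jax
import jax.numpy as jnp

# Rounded values from the failing test: gradient of obj_fun w.r.t. the primal,
# and the equality-constraint term (A^T times the dual variable), both as
# one-element lists because the primal variable is now a list.
grad_obj = [jnp.array([0.44, -1.32])]
eq_term = [jnp.array([-0.44, 1.32])]

# Python `+` on lists concatenates: two leaves instead of one, so the tree
# structure no longer matches the primal variable.
concatenated = grad_obj + eq_term
print(len(concatenated))  # 2
print(jax.tree_util.tree_structure(concatenated)
      == jax.tree_util.tree_structure(grad_obj))  # False

# Leaf-by-leaf addition keeps the primal structure and gives the expected zeros.
stationarity = jax.tree_util.tree_map(jnp.add, grad_obj, eq_term)
```

For a plain array primal, `+` is elementwise addition, which is why the array case works.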

Notice also that the test on the line just above (checking that the primal solution is correct) still passes, provided we check sol[0][0] instead of sol[0] (since sol[0] is now a one-element list).

Is differentiation through KKT supposed to work for general pytrees? If so, what should I have done to make it work in the example?


mblondel commented 3 years ago

Thanks a lot for your interest in JAXopt! We support pytrees for equality-constrained QPs but not for general QPs yet. For the former, we do so by using linear operators / matvecs. For example, the quadratic form 0.5 x^T Q x can be written as 0.5 * tree_vdot(x, matvec_Q(params_Q, x)). This reduces to the array case with matvec_Q(Q, x) = jnp.dot(Q, x). There's a test using matvecs here (albeit not using pytrees). For general QPs, we will be able to support pytrees once we implement our own general solvers. In the meantime, you need to flatten your pytrees.
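A minimal sketch of the matvec formulation with a pytree-valued `x`. `tree_vdot` is a local helper written for this example (JAXopt has its own in a private module), and the block-diagonal `params_Q` stored as a dict is a hypothetical choice; the last lines check the result against a dense `Q` applied to the flattened `x`:

```python
import jax
import jax.numpy as jnp

def tree_vdot(x, y):
  """Sum of jnp.vdot over matching leaves of two pytrees with the same structure."""
  per_leaf = jax.tree_util.tree_map(jnp.vdot, x, y)
  return sum(jax.tree_util.tree_leaves(per_leaf))

# Hypothetical block-diagonal Q, stored as one block per leaf of x.
params_Q = {"a": jnp.array([[2.0, 0.5], [0.5, 1.0]]), "b": jnp.array([[3.0]])}

def matvec_Q(params_Q, x):
  # Applies each block of Q to the matching leaf of x, so the output has the
  # same pytree structure as x.
  return jax.tree_util.tree_map(jnp.dot, params_Q, x)

x = {"a": jnp.array([1.0, -1.0]), "b": jnp.array([2.0])}
quad = 0.5 * tree_vdot(x, matvec_Q(params_Q, x))

# Same quadratic form with a dense Q acting on the flattened x.
Q_dense = jnp.array([[2.0, 0.5, 0.0], [0.5, 1.0, 0.0], [0.0, 0.0, 3.0]])
x_flat = jnp.concatenate([x["a"], x["b"]])
quad_dense = 0.5 * jnp.dot(x_flat, jnp.dot(Q_dense, x_flat))
```

The matvec never materializes the full `Q`, which is what lets the structure of `x` be arbitrary as long as `matvec_Q` returns a pytree of the same shape.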

FerranAlet commented 3 years ago

Thanks for the fast reply!

My bad, I hadn't seen the matvec commit from yesterday and was reading stale code from a few days ago. I re-did the changes on the new code and showed my point on the test you pointed me to. I've updated my original question with the appropriate details for the matvec test.

I also wasn't very clear; adding more details:

mblondel commented 3 years ago

Thanks for the clarifications. The ideal would be a minimal script to reproduce the issue. Without one, I can only speculate about the potential issues.

Quoting: "However, it doesn't, this test line raises an assert for an array that should be all zeros and instead is:"

Have you compared your primal and dual solutions to another solver (for instance QuadraticProgramming) by flattening your pytrees? Note that l2_optimality_error needs to be passed a tuple containing both the primal and dual solutions.
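Because the KKT residual is itself a pytree (one entry per condition, with `None` when there are no inequality constraints), reducing it to a single number means summing squares over every leaf. A minimal sketch of such a norm; `tree_l2_norm` here is a local helper, and `residuals` is a made-up example, not output from this thread:

```python
import jax
import jax.numpy as jnp

def tree_l2_norm(tree):
  """L2 norm over all leaves of a pytree; None entries are empty subtrees."""
  leaves = jax.tree_util.tree_leaves(tree)
  return jnp.sqrt(sum(jnp.sum(leaf ** 2) for leaf in leaves))

# Hypothetical KKT residuals laid out as (stationarity, primal feasibility,
# complementary slackness), with no inequality constraints.
residuals = ([jnp.array([3.0, 0.0])], jnp.array([4.0]), None)
print(tree_l2_norm(residuals))  # 5.0
```

A zero norm therefore requires every leaf of every condition to vanish, which is why a stationarity term with a wrong structure shows up as a nonzero error rather than as a shape error.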

FerranAlet commented 3 years ago

Yes, I have compared my solutions. The solver returns the same tuple except that the first DeviceArray (primal solution) is within a list, as expected. That's why the test on the line above (that checks the solution is correct) passes.

I attach the code; you'll see it's a very small change from the original, and I've marked the differences with #CHANGED comments. Some changes were only needed to run it in Google Colab and aren't fundamental to my point. I've also marked the test that passes with #PASSES and the one that fails with #FAILS.

"""Quadratic programming in JAX."""

from typing import Any
from typing import Callable
from typing import Optional
from typing import Tuple

from dataclasses import dataclass

import jax
import jax.numpy as jnp

from jaxopt._src import base
from jaxopt._src import implicit_diff as idf
from jaxopt._src import linear_solve
from jaxopt._src import tree_util

ArrayPair = Tuple[jnp.ndarray, jnp.ndarray]

def _check_params(params_obj, params_eq=None, params_ineq=None):
  if params_obj is None:
    raise ValueError("params_obj should be a tuple (Q, c)")
  Q, c = params_obj
  if Q.shape[0] != Q.shape[1]:
    raise ValueError("Q must be a square matrix.")
  if Q.shape[1] != c.shape[0]:
    raise ValueError("Q.shape[1] != c.shape[0]")

  if params_eq is not None:
    A, b = params_eq
    if A.shape[0] != b.shape[0]:
      raise ValueError("A.shape[0] != b.shape[0]")
    if A.shape[1] != Q.shape[1]:
      raise ValueError("Q.shape[1] != A.shape[1]")

  if params_ineq is not None:
    G, h = params_ineq
    if G.shape[0] != h.shape[0]:
      raise ValueError("G.shape[0] != h.shape[0]")
    if G.shape[1] != Q.shape[1]:
      raise ValueError("G.shape[1] != Q.shape[1]")

def _matvec_and_rmatvec(matvec, x, y):
  """Returns both matvec(x) = dot(A, x) and rmatvec(y) = dot(A.T, y)."""
  matvec_x, vjp = jax.vjp(matvec, x)
  rmatvec_y, = vjp(y)
  return matvec_x, rmatvec_y

def _solve_eq_constrained_qp(init_params, matvec_Q, c, matvec_A, b, maxiter):
  """Solves 0.5 * x^T Q x + c^T x subject to Ax = b.

  This solver returns both the primal solution (x) and the dual solution.
  """

  def matvec(u):
    primal_u, dual_u = u
    mv_A, rmv_A = _matvec_and_rmatvec(matvec_A, primal_u, dual_u)
    return (tree_util.tree_add(matvec_Q(primal_u), rmv_A), mv_A)

  minus_c = tree_util.tree_scalar_mul(-1, c)

  # Solves the following linear system:
  # [[Q A^T]  [primal_var = [-c
  #  [A 0  ]]  dual_var  ]    b]
  return linear_solve.solve_cg(matvec, (minus_c, b), init=init_params,
                               maxiter=maxiter)

def _solve_constrained_qp_cvxpy(params_obj, params_eq, params_ineq):
  """Solve 0.5 * x^T Q x + c^T x subject to Gx <= h, Ax = b."""

  # CVXPY runs on CPU. Hopefully, we can implement our own pure JAX solvers
  # and remove this dependency in the future.
  # TODO(frostig,mblondel): experiment with `jax.experimental.host_callback`
  # to "support" other devices (GPU/TPU) in the interim, by calling into the
  # host CPU and running cvxpy there.
  import cvxpy as cp

  Q, c = params_obj
  A, b = params_eq
  G, h = params_ineq

  x = cp.Variable(len(c))
  objective = 0.5 * cp.quad_form(x, Q) + c.T @ x
  constraints = [A @ x == b, G @ x <= h]
  pb = cp.Problem(cp.Minimize(objective), constraints)
  pb.solve()
  print("Primal:", [jnp.array(x.value)])
  return ([jnp.array(x.value)], jnp.array(pb.constraints[0].dual_value), #CHANGED
          jnp.array(pb.constraints[1].dual_value))

def _create_matvec(matvec, M):
  if matvec is not None:
    # M = params_M
    return lambda u: matvec(M, u)
  else:
    return lambda u: jnp.dot(M, u)

def _make_quadratic_prog_optimality_fun(matvec_Q, matvec_A):
  """Makes the optimality function for quadratic programming.

  Returns:
    optimality_fun(params, params_obj, params_eq, params_ineq) where
      params = (primal_var, eq_dual_var, ineq_dual_var)
      params_obj = (Q, c)
      params_eq = (A, b)
      params_ineq = (G, h) or None
  """
  def obj_fun(primal_var, params_obj):
    Q, c = params_obj
    _matvec_Q = _create_matvec(matvec_Q, Q)
    return (0.5 * tree_util.tree_vdot(primal_var[0], _matvec_Q(primal_var[0])) + #CHANGED
            tree_util.tree_vdot(primal_var[0], c)) #CHANGED

  def eq_fun(primal_var, params_eq):
    A, b = params_eq
    _matvec_A = _create_matvec(matvec_A, A)
    return tree_util.tree_sub(_matvec_A(primal_var[0]), b) #CHANGED

  def ineq_fun(primal_var, params_ineq):
    G, h = params_ineq
    return jnp.dot(G, primal_var[0]) - h #CHANGED

  return idf.make_kkt_optimality_fun(obj_fun, eq_fun, ineq_fun)

@dataclass
class QuadraticProgramming:
  """Quadratic programming solver.

  The objective function is::

    0.5 * x^T Q x + c^T x subject to Gx <= h, Ax = b.

  Attributes:
    matvec_Q: a Callable matvec_Q(params_Q, u).
      By default, matvec_Q(Q, u) = dot(Q, u), where Q = params_Q.
    matvec_A: a Callable matvec_A(params_A, u).
      By default, matvec_A(A, u) = dot(A, u), where A = params_A.
    maxiter: maximum number of iterations.
  """

  # TODO(mblondel): add matvec_G when we implement our own QP solvers.
  matvec_Q: Optional[Callable] = None
  matvec_A: Optional[Callable] = None
  maxiter: int = 1000

  def run(self,
          init_params: Optional[Tuple] = None,
          params_obj: Optional[ArrayPair] = None,
          params_eq: Optional[ArrayPair] = None,
          params_ineq: Optional[ArrayPair] = None) -> base.OptStep:
    """Runs the quadratic programming solver in CVXPY.

    The returned params contains both the primal and dual solutions.

    Args:
      init_params: ignored.
      params_obj: (Q, c) or (params_Q, c) if matvec_Q is provided.
      params_eq: (A, b) or (params_A, b) if matvec_A is provided.
      params_ineq: (G, h) or None if no inequality constraints.
    Return type:
      (params, state), ``params = (primal_var, dual_var_eq, dual_var_ineq)``
    """
    if self.matvec_Q is None and self.matvec_A is None:
      _check_params(params_obj, params_eq, params_ineq)

    Q, c = params_obj
    A, b = params_eq

    matvec_Q = _create_matvec(self.matvec_Q, Q)
    matvec_A = _create_matvec(self.matvec_A, A)

    if params_ineq is None:
      primal, dual_eq = _solve_eq_constrained_qp(init_params,
                                                 matvec_Q, c,
                                                 matvec_A, b,
                                                 self.maxiter)
      print("Primal:", [primal]) #CHANGED
      params = ([primal], dual_eq, None) #CHANGED
    else:
      params = _solve_constrained_qp_cvxpy(params_obj, params_eq, params_ineq)

    # No state needed currently as we use CVXPY.
    return base.OptStep(params=params, state=None)

  def l2_optimality_error(
      self,
      params: Any,
      params_obj: ArrayPair,
      params_eq: ArrayPair,
      params_ineq: Optional[ArrayPair] = None) -> base.OptStep:
    """Computes the L2 norm of the KKT residuals."""
    pytree = self.optimality_fun(params, params_obj, params_eq, params_ineq)
    print("Pytree:", pytree) #CHANGED
    return tree_util.tree_l2_norm(pytree)

  def __post_init__(self):
    self.optimality_fun = _make_quadratic_prog_optimality_fun(self.matvec_Q,
                                                              self.matvec_A)

    # Set up implicit diff.
    decorator = idf.custom_root(self.optimality_fun, has_aux=True)
    # pylint: disable=g-missing-from-attributes
    self.run = decorator(self.run)

import jax
from jax import test_util as jtu
import jax.numpy as jnp

from jaxopt import projection
# CHANGED: removed some imports so that it uses the modified QuadraticProgramming from above, not the original one
import numpy as onp

class QuadraticProgTest(jtu.JaxTestCase):

  def test_matvec_and_rmatvec(self):
    rng = onp.random.RandomState(0)
    A = rng.randn(5, 4)
    matvec = lambda x: jnp.dot(A, x)
    x = rng.randn(4)
    y = rng.randn(5)
    mv_A, rmv_A = _matvec_and_rmatvec(matvec, x, y)
    self.assertArraysAllClose(mv_A, jnp.dot(A, x))
    self.assertArraysAllClose(rmv_A, jnp.dot(A.T, y))

  def _check_derivative_A_and_b(self, solver, params, A, b):
    def fun(A, b):
      # reduce the primal variables to a scalar value for test purpose.
      hyperparams = dict(params_obj=params["params_obj"],
                         params_eq=(A, b),
                         params_ineq=params["params_ineq"])
      return jnp.sum(solver.run(**hyperparams).params[0])

    # Derivative w.r.t. A.
    rng = onp.random.RandomState(0)
    V = rng.rand(*A.shape)
    V /= onp.sqrt(onp.sum(V ** 2))
    eps = 1e-4
    deriv_jax = jnp.vdot(V, jax.grad(fun)(A, b))
    deriv_num = (fun(A + eps * V, b) - fun(A - eps * V, b)) / (2 * eps)
    self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)

    # Derivative w.r.t. b.
    v = rng.rand(*b.shape)
    v /= onp.sqrt(onp.sum(v ** 2))
    eps = 1e-4
    deriv_jax = jnp.vdot(v, jax.grad(fun, argnums=1)(A, b))
    deriv_num = (fun(A, b + eps * v) - fun(A, b - eps * v)) / (2 * eps)
    self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)

  def test_qp_eq_and_ineq(self):
    Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
    c = jnp.array([1.0, 1.0])
    A = jnp.array([[1.0, 1.0]])
    b = jnp.array([1.0])
    G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
    h = jnp.array([0.0, 0.0])
    qp = QuadraticProgramming()
    hyperparams = dict(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h))
    sol = qp.run(**hyperparams).params
    self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
    self._check_derivative_A_and_b(qp, hyperparams, A, b)

  def test_qp_eq_only(self):
    Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
    c = jnp.array([1.0, 1.0])
    A = jnp.array([[1.0, 1.0]])
    b = jnp.array([1.0])
    qp = QuadraticProgramming()
    hyperparams = dict(params_obj=(Q, c), params_eq=(A, b), params_ineq=None)
    sol = qp.run(**hyperparams).params
    self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
    self._check_derivative_A_and_b(qp, hyperparams, A, b)

  def test_projection_hyperplane(self):
    x = jnp.array([1.0, 2.0])
    a = jnp.array([-0.5, 1.5])
    b = 0.3
    # Find y minimizing ||y-x||^2 such that jnp.dot(y, a) = b.
    expected = projection.projection_hyperplane(x, (a, b))

    matvec_Q = lambda params_Q, u: u
    matvec_A = lambda params_A, u: jnp.dot(a, u).reshape(1)
    qp = QuadraticProgramming(matvec_Q=matvec_Q, matvec_A=matvec_A)
    # In this example, params_Q = params_A = None.
    hyperparams = dict(params_obj=(None, -x),
                       params_eq=(None, jnp.array([b])))
    sol = qp.run(**hyperparams).params
    primal_sol = sol[0][0] #CHANGED
    self.assertArraysAllClose(primal_sol, expected) #PASSES
    self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0) #FAILS
    print("Test passed")

  def test_projection_simplex(self):
    def _projection_simplex_qp(x, s=1.0):
      Q = jnp.eye(len(x))
      A = jnp.array([jnp.ones_like(x)])
      b = jnp.array([s])
      G = -jnp.eye(len(x))
      h = jnp.zeros_like(x)
      hyperparams = dict(params_obj=(Q, -x), params_eq=(A, b),
                         params_ineq=(G, h))

      qp = QuadraticProgramming()
      # Returns the primal solution only.
      return qp.run(**hyperparams).params[0][0]

    rng = onp.random.RandomState(0)
    x = jnp.array(rng.randn(10).astype(onp.float32))
    p = projection.projection_simplex(x)
    p2 = _projection_simplex_qp(x)
    self.assertArraysAllClose(p, p2, atol=1e-4)

    J = jax.jacrev(projection.projection_simplex)(x)
    J2 = jax.jacrev(_projection_simplex_qp)(x)
    self.assertArraysAllClose(J, J2, atol=1e-5)

QPT = QuadraticProgTest() #CHANGED to run in Colab
QPT.test_projection_hyperplane() #CHANGED to run in Colab

mblondel commented 3 years ago

I think I fixed the issue in #21. There was a + sign somewhere where we should have used a tree_add instead. I also added a KKTSolution named tuple to make user code more readable. Now we can do this:

sol = qp.run(**hyperparams).params

Let me know if this fixes the issue on your side as well.
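A named tuple is itself a pytree, so swapping a plain tuple for one improves readability without affecting tree utilities. A minimal sketch with a hypothetical `KKTSolutionSketch` class whose fields mirror the `(primal_var, dual_var_eq, dual_var_ineq)` layout from the `run` docstring; JAXopt's actual `KKTSolution` definition may differ, and the values are illustrative:

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp

class KKTSolutionSketch(NamedTuple):
  """Hypothetical stand-in for a (primal, dual_eq, dual_ineq) named tuple."""
  primal: Any
  dual_eq: Any
  dual_ineq: Any

sol = KKTSolutionSketch(primal=[jnp.array([1.44, 0.68])],
                        dual_eq=jnp.array([0.88]),
                        dual_ineq=None)

# Fields are read by name instead of sol[0], sol[1], sol[2] ...
primal = sol.primal
# ... and tree utilities still apply; None stays None and the type is preserved.
scaled = jax.tree_util.tree_map(lambda v: 2 * v, sol)
```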

FerranAlet commented 3 years ago

Yes, now it works; thanks!

FerranAlet commented 3 years ago

Actually, I just realized that the same bug is probably also happening in line 228 (and possibly 229, but I'm less sure) on that pull request. It just didn't affect our tests because they're equality-only.

mblondel commented 3 years ago

I was planning to do it once we have pytree support for general QPs, so that we can test it properly, but we can try to be future-proof, I guess (just force-pushed).

FerranAlet commented 3 years ago

Awesome. I think fixing the bug preemptively is useful, since people (including me) may use make_kkt_optimality_fun with inequalities outside of QPs. Thanks for fixing it!
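For inequality constraints outside of QPs, every combination of pytrees must go leaf by leaf: the sums that build the stationarity term and the elementwise product in complementary slackness. A minimal sketch under that assumption; `make_kkt_residual_fun` is a hypothetical name, not JAXopt's `make_kkt_optimality_fun`, and the check reuses the `test_qp_eq_and_ineq` problem from the script above with the primal wrapped in a list. The candidate solution x = (0.25, 0.75), nu = -2.75, lambda = 0 was derived by hand for this check (both inequalities are inactive):

```python
import jax
import jax.numpy as jnp

tree_map = jax.tree_util.tree_map

def make_kkt_residual_fun(obj_fun, eq_fun, ineq_fun=None):
  """Hypothetical helper: KKT residuals that keep the primal pytree structure.

  params = (primal, dual_eq, dual_ineq); each *_fun takes (primal, hyperparams).
  """
  grad_fun = jax.grad(obj_fun)

  def residual_fun(params, params_obj, params_eq, params_ineq=None):
    primal, dual_eq, dual_ineq = params
    # Stationarity: grad f(x) + A^T nu (+ G^T lambda), summed leaf by leaf.
    _, eq_vjp = jax.vjp(lambda x: eq_fun(x, params_eq), primal)
    stationarity = tree_map(jnp.add, grad_fun(primal, params_obj),
                            eq_vjp(dual_eq)[0])
    primal_feasibility = eq_fun(primal, params_eq)
    if params_ineq is None:
      return stationarity, primal_feasibility, None
    ineq_val, ineq_vjp = jax.vjp(lambda x: ineq_fun(x, params_ineq), primal)
    stationarity = tree_map(jnp.add, stationarity, ineq_vjp(dual_ineq)[0])
    # Complementary slackness: g_i(x) * lambda_i, multiplied leaf by leaf.
    comp_slackness = tree_map(jnp.multiply, ineq_val, dual_ineq)
    return stationarity, primal_feasibility, comp_slackness

  return residual_fun

# The QP from test_qp_eq_and_ineq, with a list-valued primal variable.
Q = jnp.array([[4.0, 1.0], [1.0, 2.0]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
G = -jnp.eye(2)
h = jnp.zeros(2)

def obj_fun(primal, params_obj):
  Q, c = params_obj
  x = primal[0]
  return 0.5 * jnp.dot(x, jnp.dot(Q, x)) + jnp.dot(c, x)

def eq_fun(primal, params_eq):
  A, b = params_eq
  return jnp.dot(A, primal[0]) - b

def ineq_fun(primal, params_ineq):
  G, h = params_ineq
  return jnp.dot(G, primal[0]) - h

kkt = make_kkt_residual_fun(obj_fun, eq_fun, ineq_fun)
params = ([jnp.array([0.25, 0.75])], jnp.array([-2.75]), jnp.zeros(2))
residuals = kkt(params, (Q, c), (A, b), (G, h))
```

Every residual leaf should be zero here, and the stationarity term keeps the one-element-list structure of the primal variable.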

mblondel commented 3 years ago

BTW, make_kkt_optimality_fun is not public API at the moment (it's in _src, which is supposed to be private stuff). I'm guessing you would like us to expose make_kkt_optimality_fun?

FerranAlet commented 3 years ago

Yes, that would be great; thanks!!