google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Twice differentiability of OSQP #419

Open nomoriel opened 1 year ago

nomoriel commented 1 year ago

Hi, thank you for this very cool repo, and sorry, I'm new to "Issues" and to JAX.

I'm getting a JaxStackTraceBeforeTransformation error when trying to take a derivative through a ProjectedGradient solver. The solver itself runs flawlessly, and derivatives work fine without the solver. I later want to use outer_with_proj (with a slightly different objective) for implicit differentiation with respect to theta.

Any advice?

The code:

import numpy as np
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt

np.random.seed(0)  # seed NumPy's RNG (np.random.rand is used below)

# parameters
m = 3
n = 25
nm = n*m
d = 2
d0 = 5
d1 = 5
sshape = (d0,d1,m)

# polyhedron constraints
# Az = b
A = np.zeros((n, nm))
for i in range(n):
    A[i, i*m:(i+1)*m] = 1

b = np.ones(n)

# Jz <= h
J = np.zeros((2*nm, nm))
J[:nm, :] = -np.eye(nm)
J[nm:, :] = np.eye(nm)
h = np.zeros(2*nm)
h[:nm] = 0
h[nm:] = 1

# initial state
g0 = np.random.rand(n, m)
g0 = (g0 / np.sum(g0, 1, keepdims=True)).flatten()

# auxiliary params for objective
L = np.stack([np.repeat(range(d0), d1), np.tile(range(d1), d0)]).T

x = L[...,0]
y = L[...,1]

library = np.stack([x**0, x, y,], axis=-1)
library = library / library.max(axis=0, keepdims=True) 
l = library.shape[1]

# defining a gradient
params = np.zeros((l, m))
params[1,0] = -1
params[0,0] = 1
params[2,1] = -1
params[0,1] = 1
params[0,2] = 1

grads = (library @ params)

plt.imshow(grads.reshape(sshape))

# convert to jax
A = jnp.asarray(A)
b = jnp.asarray(b)
J = jnp.asarray(J)
h = jnp.asarray(h)
g0 = jnp.asarray(g0)

library = jnp.asarray(library)
params = jnp.asarray(params)
grads = jnp.asarray(grads.flatten())

# problem setup
from jaxopt import projection
from jaxopt import ProjectedGradient

lr = 0.001

def obj(g, theta):
    """
    Maximize the sum of the log of the weighted sum of the gradients
    :param g: flattened G nxm
    :param theta: flattened params lxm
    """

    G = jnp.reshape(g, (n,m))
    w = jnp.einsum('nl,lm->nm', jnp.reshape(grads, (n,l)), jnp.reshape(theta, (l,m)))
    return -jnp.sum(jnp.log(jnp.sum(jnp.multiply(w, G), 0)))

def outer_with_proj(g0, theta):
    """
    Given theta, compute error from gradients where each sample is within the simplex
    :param g0: flattened G nxm
    :param theta: flattened params lxm
    """
    def proj(p, C):
        return projection.projection_polyhedron(p, C, check_feasible=False)

    solver = ProjectedGradient(fun=obj,
                               projection=proj,
                               maxiter=10,
                               implicit_diff=True)

    g_fit = solver.run(g0, (A, b, J, h), theta).params
    return jnp.mean((g_fit - grads) ** 2)

def outer_no_proj(g0, theta):
    """
    Given theta, compute error from gradients
    :param g0: flattened G nxm
    :param theta: flattened params lxm
    """
    g_fit = g0 - lr * jax.grad(obj, argnums=0)(g0, theta)
    return jnp.mean((g_fit - grads) ** 2)

eps = 1e-2
theta = params + eps
print(outer_no_proj(g0, theta))
print(outer_with_proj(g0, theta))

print(jax.grad(outer_no_proj, argnums=1)(g0, theta))
print(jax.grad(outer_with_proj, argnums=1)(g0, theta))

The error:

---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File ~/opt/miniconda3/envs/pareto_dyn/lib/python3.10/runpy.py:196, in _run_module_as_main(***failed resolving arguments***)
    195     sys.argv[0] = mod_spec.origin
--> 196 return _run_code(code, main_globals, None,
    197                  "__main__", mod_spec)

File ~/opt/miniconda3/envs/pareto_dyn/lib/python3.10/runpy.py:86, in _run_code(***failed resolving arguments***)
     79 run_globals.update(__name__ = mod_name,
     80                    __file__ = fname,
     81                    __cached__ = cached,
   (...)
     84                    __package__ = pkg_name,
     85                    __spec__ = mod_spec)
---> 86 exec(code, run_globals)
     87 return run_globals

File ~/opt/miniconda3/envs/pareto_dyn/lib/python3.10/site-packages/ipykernel_launcher.py:17
     15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()

File ~/opt/miniconda3/envs/pareto_dyn/lib/python3.10/site-packages/traitlets/config/application.py:1043, in Application.launch_instance(***failed resolving arguments***)
   1042 app.initialize(argv)
-> 1043 app.start()
...
    372     *(_flatten(params.transpose()) + x_cotangent),
    373     const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose())
    374 # drop aux values in cotangent computation

AssertionError:

Thanks again!

mblondel commented 1 year ago

Thanks for the bug report. I don't see any mistake in your code a priori. I tried running your code and got a linear-solve transpose rule error:

Traceback (most recent call last):
  File "/Users/mblondel/Desktop/playground/pg_bug.py", line 122, in <module>
    print(jax.grad(outer_with_proj, argnums=1)(g0, theta))
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/api.py", line 659, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/api.py", line 741, in value_and_grad_f
    g = vjp_py(lax_internal._one(ans))
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/tree_util.py", line 303, in __call__
    return self.fun(*args, **kw)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/api.py", line 2183, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/tree_util.py", line 303, in __call__
    return self.fun(*args, **kw)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 146, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 253, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 751, in _custom_lin_transpose
    cts_in = bwd(*res, *cts_out)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/custom_derivatives.py", line 683, in <lambda>
    bwd_ = lambda *args: bwd(*args)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/mblondel/Desktop/projects/jaxopt/jaxopt/_src/implicit_diff.py", line 236, in solver_fun_bwd
    vjps = root_vjp(optimality_fun=optimality_fun, sol=sol,
  File "/Users/mblondel/Desktop/projects/jaxopt/jaxopt/_src/implicit_diff.py", line 69, in root_vjp
    u = solve(matvec, v)
  File "/Users/mblondel/Desktop/projects/jaxopt/jaxopt/_src/linear_solve.py", line 193, in solve_normal_cg
    Ab = rmatvec(b)  # A.T b
  File "/Users/mblondel/Desktop/projects/jaxopt/jaxopt/_src/linear_solve.py", line 145, in <lambda>
    return lambda y: transpose(y)[0]
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/api.py", line 2353, in transposed_fun
    in_cts = ad.backward_pass(jaxpr, reduce_axes, True, const, dummies, out_cts)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 253, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/lax/control_flow/solves.py", line 370, in _linear_solve_transpose_rule

Actually, you're trying something a bit advanced. The issue is that projection_polyhedron calls our OSQP solver, so computing the Jacobian of projection_polyhedron itself requires implicit differentiation. That means there is a nested implicit differentiation inside the one you're asking ProjectedGradient to do, and it seems the two don't compose correctly at the moment. In other words, second-order derivatives of OSQP probably don't work correctly yet.
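
For anyone trying to isolate this, here is a minimal, hypothetical sketch of the nesting being described (toy 2D polyhedron, all names invented for illustration; the second-order failure is my reading of this thread, not a confirmed repro):

import jax
import jax.numpy as jnp
from jaxopt import projection

# toy polyhedron: the 2-simplex {x : sum(x) = 1, x >= 0}
A = jnp.ones((1, 2))   # equality Ax = b
b = jnp.ones(1)
G = -jnp.eye(2)        # inequality Gx <= h
h = jnp.zeros(2)

def loss(y):
    # projection_polyhedron solves a QP with OSQP under the hood, so its
    # VJP already performs one implicit differentiation
    p = projection.projection_polyhedron(y, (A, b, G, h), check_feasible=False)
    return jnp.sum(p ** 2)

y0 = jnp.array([0.3, 0.9])
print(jax.grad(loss)(y0))     # one level of implicit diff: works
print(jax.hessian(loss)(y0))  # implicit diff nested inside implicit diff:
                              # expected to hit the same assertion per this thread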

CC: @Algue-Rythme @froystig

mblondel commented 1 year ago

As a temporary workaround, you can set implicit_diff=False in ProjectedGradient (see the sketch below). This way, only one implicit differentiation is involved: the one inside projection_polyhedron.
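
Concretely, this is the one-line change applied to the snippet above (same variable names as the original code; only the implicit_diff flag differs):

    solver = ProjectedGradient(fun=obj,
                               projection=proj,
                               maxiter=10,
                               implicit_diff=False)  # unroll the solver loop instead

    # the only implicit differentiation left is the one inside
    # projection_polyhedron (OSQP), so jax.grad composes as expected
    g_fit = solver.run(g0, (A, b, J, h), theta).params

Note the usual trade-off: with implicit_diff=False the gradient is obtained by backpropagating through all maxiter iterations, which can cost more time and memory than implicit differentiation.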

nomoriel commented 1 year ago

Gotcha! The workaround works :) Thank you very much!