google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Implicitly differentiate the KKT conditions #539

Open amdee opened 1 year ago

amdee commented 1 year ago

Hi,

I am currently learning how to use Jaxopt and am trying to adapt the code below to utilize its features. This code originally comes from a research paper and was ported from PyTorch to JAX. In the forward method, a root-finding problem is solved using the bracketing method. Meanwhile, the backward method relies on implicit differentiation through the KKT conditions.

Currently, I am only making use of Jaxopt's root-finding capabilities. I have a specific question: How can I employ Jaxopt to perform differentiation through the KKT conditions?

Please find my current implementation below:

import jax
import jax.numpy as jnp
import flax.linen as nn
from jaxopt import Bisection
import numpy as np

@jax.custom_vjp
def LML_jax(x, N, eps, n_iter, branch=None, verbose=0):
    y, res = lml_forward(x, N, eps, n_iter, branch, verbose)
    return y, res

def f(nu, x, N):
    return jnp.sum(jax.nn.sigmoid(x + nu)) - N

def lml_forward(x, N, eps, n_iter, branch, verbose):
    branch = branch if branch is not None else 10 if jax.devices()[0].platform == 'cpu' else 100
    nx = x.shape[0]
    if nx <= N:
        return jnp.ones(nx, dtype=x.dtype), None

    x_sorted = jnp.sort(x)[::-1]
    nu_lower = -x_sorted[N-1] - 7.
    nu_upper = -x_sorted[N] + 7.

    # Using Bisection from jaxopt
    bisection = Bisection(optimality_fun=f, lower=nu_lower, upper=nu_upper, tol=eps, check_bracket=False)
    sol = bisection.run(x=x, N=N)
    nu = sol.params

    y = jax.nn.sigmoid(x + nu)
    return y, (y, nu, x, N)

def lml_backward(res, grad_output):
    y, nu, x, N = res
    if y is None:
        return (jnp.zeros_like(x), None, None, None, None, None)

    Hinv = 1. / (1. / y + 1. / (1. - y))
    dnu = jnp.sum(Hinv * grad_output) / jnp.sum(Hinv)
    dx = -Hinv * (-grad_output + dnu)
    return (dx, None, None, None, None, None)

LML_jax.defvjp(lml_forward, lml_backward)

class LML(nn.Module):
    N: int = 1
    eps: float = 1e-4
    n_iter: int = 100
    branch: int = None
    verbose: int = 0

    @nn.compact
    def __call__(self, x):
        return LML_jax(x, N=self.N, eps=self.eps, n_iter=self.n_iter, branch=self.branch, verbose=self.verbose)

if __name__ == '__main__':
    m = 10
    n = 2
    np.random.seed(0)
    x = np.random.random(m)
    x_jax_unbatched = jnp.array(x)
    x_jax_batched = jnp.stack([x_jax_unbatched, x_jax_unbatched])
    x = jnp.stack([x, x])
    model = LML(N=n)
    key1, key2 = jax.random.split(jax.random.PRNGKey(1))
    dummy_input = jax.random.normal(key1, (n, m))
    params = model.init(jax.random.PRNGKey(0), dummy_input)
    LML_state = model.bind(params)
    lml = lambda x_input: LML_state(x_input)[0]

    y_unbatched = lml(x_jax_unbatched)
    y_batched = jax.vmap(lml)(x_jax_batched)
    y_unbatched_check = np.array(y_unbatched, copy=False)
    y_batched_check = np.array(y_batched, copy=False)

    vyo_unbatched, dyo_unbatched = jax.value_and_grad(lml)(x_jax_unbatched)
    vyo_batched, dyo_batched = jax.vmap(jax.value_and_grad(lml))(x_jax_batched)
    print(f"value of y vyo_unbatched: {vyo_unbatched}\ngradient of y dy0 vyo_unbatched: {dyo_unbatched}")
    print(f"\nvalue of y vyo_batched: {vyo_batched}\ngradient of y dy0 vyo_batched: {dyo_batched}") 
zaccharieramzi commented 1 year ago

Hi @amdee ,

When you use solvers directly from jaxopt, like you are doing with Bisection, you don't need to worry about how to differentiate through the KKT conditions: you can use the solver as is and think of it as a differentiable operation. What you can play with is whether to use implicit differentiation or not (in your case you probably want to), and how the linear system involving the Jacobian is solved (but you can also keep the defaults for this). In other words, you shouldn't need to define a custom VJP. I think the problem in your case is that you are using non-differentiable operations inside lml_forward, like the sorting, before even creating the optimizer. You also cannot differentiate w.r.t. the parameters of the optimizer, as you are doing here with lower and upper: basically, your differentiation should be agnostic to the optimizer's parameters.
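
For concreteness, here is a minimal sketch of that idea (my own illustration with made-up function names, not code from this thread; N is fixed to 2 inside the optimality function for simplicity). The bracket is given as plain constants so that nothing x-dependent enters the solver's hyperparameters, no custom VJP is defined, and jax.grad differentiates through the root via jaxopt's implicit differentiation (implicit_diff=True is the default):

import jax
import jax.numpy as jnp
from jaxopt import Bisection

def optimality_fun(nu, x):
    # Root in nu of: sum_i sigmoid(x_i + nu) = 2.
    return jnp.sum(jax.nn.sigmoid(x + nu)) - 2.0

def nu_star(x):
    # Constant bracket: the solver's hyperparameters do not depend on x.
    bisec = Bisection(optimality_fun=optimality_fun, lower=-10.0, upper=10.0,
                      check_bracket=False)
    return bisec.run(x=x).params

x = jnp.array([0.1, 0.5, 0.9])
dnu_dx = jax.grad(nu_star)(x)  # gradient of the root w.r.t. x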

Hope this helps, Cheers

amdee commented 1 year ago

@zaccharieramzi and maybe @mblondel, thanks for answering this question. That's what I thought: if I use a Jaxopt solver, I should not have to worry about manually implementing differentiation through the KKT conditions, since Jaxopt solvers are differentiable out of the box. The algorithm I am trying to implement from the above-mentioned paper is summarized in the attached image. I removed all the Flax dependencies from the code above and left bare-bones JAX/Jaxopt code. See the code below; I am running into the following issue:

  1. CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try passing the closed-over value into the custom_vjp function as an argument, and adapting the custom_vjp fwd and bwd rules.
  2. I am not sure what I am doing wrong. I have also read issues #285 and #31. Any help would be appreciated.
    
import jax
from jax.config import config
import jax.numpy as jnp
from jaxopt import Bisection

config.update("jax_enable_x64", True)

# Implement the bracketing method using jaxopt.Bisection
@jax.jit
def find_nu_star(x, k=2, saturation=7.0, eps=1e-4, num_iter=100):
    def g(nu, x, N):
        return jnp.sum(jax.nn.sigmoid(x + nu)) - N

    x_sorted = jnp.sort(x)[::-1]
    nu_lower = -x_sorted[k-1] - saturation
    nu_upper = -x_sorted[k] + saturation

    # Using Bisection from jaxopt
    init_params = jnp.zeros(x.shape)
    bisection = Bisection(optimality_fun=g, lower=nu_lower, upper=nu_upper,
                          maxiter=num_iter, tol=eps, check_bracket=False)
    sol, _ = bisection.run(init_params, x, k)
    # nu = sol.params
    return sol

def calculate_y_star(x, K_value):
    nu_star = find_nu_star(x, K_value)
    y_star = jax.nn.sigmoid(x + nu_star)
    return y_star

data = jax.random.normal(jax.random.PRNGKey(0), shape=(3,))
n = 2  # this is k in the paper
grad_result = jax.grad(calculate_y_star)(data, n)
print(f"grad_result: {grad_result}")

zaccharieramzi commented 1 year ago

So this error is due to what I was explaining above: you are trying to differentiate through non-differentiable parts of the code, namely the lower and upper parameters of Bisection. On top of that, the sort function is also non-differentiable.

You could use, for example, x_sorted = jax.lax.stop_gradient(jnp.sort(x)[::-1]).
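
To make the placement concrete, here is find_nu_star from the snippet above with only that line changed (nothing else modified):

import jax
import jax.numpy as jnp
from jaxopt import Bisection

def find_nu_star(x, k=2, saturation=7.0, eps=1e-4, num_iter=100):
    def g(nu, x, N):
        return jnp.sum(jax.nn.sigmoid(x + nu)) - N

    # Block gradients through the bracket so that lower and upper are
    # treated as constants rather than differentiated quantities.
    x_sorted = jax.lax.stop_gradient(jnp.sort(x)[::-1])
    nu_lower = -x_sorted[k-1] - saturation
    nu_upper = -x_sorted[k] + saturation

    init_params = jnp.zeros(x.shape)
    bisection = Bisection(optimality_fun=g, lower=nu_lower, upper=nu_upper,
                          maxiter=num_iter, tol=eps, check_bracket=False)
    sol, _ = bisection.run(init_params, x, k)
    return sol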

With this I have the following error:

ValueError: Shape of cotangent input to vjp pullback function (3,) must be the same as the shape of corresponding primal input ().

which I don't understand. Ideally, you would open an issue with the bare minimum needed to reproduce this error (i.e. no need to mention which paper it refers to), to allow for easier processing.

amdee commented 1 year ago

@zaccharieramzi, @mblondel, I apologize for not being more specific in describing my question about differentiating through the KKT conditions. I am attempting to replicate an algorithm from the paper mentioned above:

$$\min_{0 < y < 1} \; -x^T y - H_b(y) \quad \text{subject to} \quad 1^T y = k$$

where

$$H_b(y) = -\sum_{i} \left( y_i \log y_i + (1 - y_i) \log(1 - y_i) \right)$$

import jax.numpy as jnp
from jaxopt import ProjectedGradient

def binary_entropy(y):
    """Calculate the binary entropy of a vector y."""
    return -jnp.sum(y * jnp.log(y) + (1 - y) * jnp.log(1 - y))

def objective(y, x):
    """Objective function to minimize."""
    return -jnp.dot(x, y) - binary_entropy(y)

def projection(y, k):
    """Project onto the set {y : 1^T y = k, 0 < y < 1}."""
    pass

# Initialize the Projected Gradient solver
solver = ProjectedGradient(fun=objective, projection=projection, maxiter=1000)

# Solve the problem
result = solver.run(y_init, hyperparams_proj=k, x=x).params

zaccharieramzi commented 12 months ago

@amdee when you solve $y^\star = \argmin_y f(x, y)$, you can get $\frac{\partial y^\star}{\partial x}$. What I meant earlier is that you cannot differentiate through hyperparameters of the optimization algorithm as you were trying to do in your first example (in addition to trying to differentiate through sorting).

Indeed, in your example, once you implement the projection you will be able to get the gradient of result w.r.t. x. The projection can be just a clipping between 0 and 1, followed by a normalization by the sum and a multiplication by k, but I am saying this without thinking too much about it. I would like to point out that answering questions like how to implement the projection is out of the scope of this project, imho.
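
For what it's worth, here is a rough sketch of that suggested projection (clip into (0, 1), then rescale by the sum and multiply by k); as said above, this is a quick heuristic rather than an exact Euclidean projection onto the constraint set, and the names are meant to match the snippet you posted:

import jax.numpy as jnp

def projection(y, hyperparams_proj):
    """Heuristic projection: clip into (0, 1), then rescale to sum to k."""
    k = hyperparams_proj
    eps = 1e-6
    y = jnp.clip(y, eps, 1.0 - eps)  # keep entries strictly inside (0, 1)
    return y * (k / jnp.sum(y))      # renormalize so that the entries sum to k

# Usage with your snippet (y_init, x and k as defined there):
# solver = ProjectedGradient(fun=objective, projection=projection, maxiter=1000)
# result = solver.run(y_init, hyperparams_proj=k, x=x).params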