amdee opened this issue 1 year ago
Hi @amdee,
When you are using solvers directly from jaxopt, as you are doing with Bisection, you don't need to worry about how to differentiate through the KKT conditions: you can use the solver as is and treat it as a differentiable operation.
What you can play with is whether to use implicit differentiation (in your case you probably want to) and how the linear system involving the Jacobian is solved (but you can also keep the defaults for this).
In other words, you shouldn't need to define a custom vjp. I think the problem in your case is that you are using non-differentiable operations inside lml_forward, like the sorting, before even creating the optimizer.
You also cannot differentiate w.r.t. parameters of the optimizer, as you are doing here with lower and upper: basically, your differentiation should be agnostic to the optimizer's parameters.
Hope this helps, Cheers
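A minimal sketch of this point (a toy square-root example for illustration only; the function F and the bracket [0, 10] are arbitrary choices, not from this thread): the root returned by Bisection can be fed straight into jax.grad, because implicit differentiation is enabled by default.

import jax
from jaxopt import Bisection

# Toy problem: the root of F(x, c) = x**2 - c on [0, 10] is sqrt(c).
def F(x, c):
    return x ** 2 - c

def sqrt_via_bisection(c):
    bisec = Bisection(optimality_fun=F, lower=0.0, upper=10.0, check_bracket=False)
    # Implicit differentiation is on by default, so the returned root is
    # differentiable w.r.t. c without writing any custom vjp.
    return bisec.run(c=c).params

print(sqrt_via_bisection(2.0))            # ~ 1.4142
print(jax.grad(sqrt_via_bisection)(2.0))  # ~ 0.3536, i.e. 1 / (2 * sqrt(2))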
@zaccharieramzi and maybe @mblondel, thanks for answering this question. That's what I thought: if I use a jaxopt solver, I should not have to worry about manually implementing differentiation through the KKT conditions, since jaxopt solvers are differentiable out of the box. The algorithm I am trying to implement from the above-mentioned paper is summarized in the image below. I removed all the Flax dependencies from the above code and left bare-bones JAX/jaxopt code. See the code below; I am running into an error with it.
import jax
from jax.config import config
import jax.numpy as jnp
from jaxopt import Bisection

config.update("jax_enable_x64", True)


def find_nu_star(x, k=2, saturation=7.0, eps=1e-4, num_iter=100):
    def g(nu, x, N):
        return jnp.sum(jax.nn.sigmoid(x + nu)) - N

    x_sorted = jnp.sort(x)[::-1]
    nu_lower = -x_sorted[k - 1] - saturation
    nu_upper = -x_sorted[k] + saturation

    # Using Bisection from jaxopt
    init_params = jnp.zeros(x.shape)
    bisection = Bisection(optimality_fun=g, lower=nu_lower, upper=nu_upper,
                          maxiter=num_iter, tol=eps, check_bracket=False)
    sol, _ = bisection.run(init_params, x, k)
    # nu = sol.params
    return sol


def calculate_y_star(x, K_value):
    nu_star = find_nu_star(x, K_value)
    y_star = jax.nn.sigmoid(x + nu_star)
    return y_star


data = jax.random.normal(jax.random.PRNGKey(0), shape=(3,))
n = 2  # this is k in the paper
grad_result = jax.grad(calculate_y_star)(data, n)
print(f"grad_result: {grad_result}")
So this error is due to what I was explaining above, i.e. you are trying to differentiate through non-differentiable parts of the code, namely the lower and upper parameters of Bisection. On top of that, the sort function is also non-differentiable.
You could use for example x_sorted = jax.lax.stop_gradient(jnp.sort(x)[::-1]).
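For concreteness, here is where that one-line change would sit inside find_nu_star from the snippet above (everything else unchanged); stopping the gradient through the sort also turns the bracket bounds into non-differentiable constants, consistent with the point above about not differentiating through the optimizer's parameters:

x_sorted = jax.lax.stop_gradient(jnp.sort(x)[::-1])  # no gradient flows through the sort
nu_lower = -x_sorted[k - 1] - saturation             # bracket bounds are now constants
nu_upper = -x_sorted[k] + saturation                 # as far as autodiff is concerned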
With this I have the following error:
ValueError: Shape of cotangent input to vjp pullback function (3,) must be the same as the shape of corresponding primal input ().
which I don't understand. Ideally, you would open an issue with the bare minimum needed to reproduce this error (i.e. no need to mention which paper this is referring to) to allow for easier processing.
@zaccharieramzi, @mblondel, I apologize for not being specific in describing my question about differentiating through the KKT condition. I am attempting to replicate an algorithm from the paper mentioned above.
In summary, the algorithm addresses a constrained convex optimization problem by defining both a forward and a backward pass. The forward pass uses a root-finding method, while the backward pass differentiates through the KKT conditions of the constrained convex optimization problem below. This explains the use of the sorting step, as it doesn't affect the backward pass, according to my understanding.
Returning to my question: how can I use jaxopt to differentiate through the KKT conditions of the optimization problem below without deriving them manually?
$$\min_{0 < y < 1} \; -x^T y - H_b(y) \quad \text{subject to} \quad 1^T y = k$$

where

$$H_b(y) = -\sum_{i} \bigl( y_i \log y_i + (1 - y_i) \log(1 - y_i) \bigr)$$
import jax.numpy as jnp
from jax import grad
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_simplex


def binary_entropy(y):
    """Calculate the binary entropy of a vector y."""
    return -jnp.sum(y * jnp.log(y) + (1 - y) * jnp.log(1 - y))


def objective(y, x):
    """Objective function to minimize."""
    return -jnp.dot(x, y) - binary_entropy(y)


def projection(y, k):
    """Project onto the set {y : 1^T y = k, 0 < y < 1}."""
    pass  # placeholder: the projection still needs to be implemented


solver = ProjectedGradient(fun=objective, projection=projection, maxiter=1000)
# y_init, k and x are assumed to be defined elsewhere
result = solver.run(y_init, hyperparams_proj=k, x=x).params
@amdee when you solve $y^\star = \arg\min_y f(x, y)$, you can get $\frac{\partial y^\star}{\partial x}$. What I meant earlier is that you cannot differentiate through hyperparameters of the optimization algorithm, as you were trying to do in your first example (in addition to trying to differentiate through the sorting).
Indeed, in your example, once you implement the projection you will be able to get the gradient of result w.r.t. x. The projection can be just a clipping between 0 and 1, followed by a normalization by the sum and a multiplication by k, but I am saying this without thinking too much about it. I would like to point out that answering questions like how to implement the projection is out of the scope of this project imho.
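A minimal sketch of the clip-then-rescale idea mentioned above, assuming the objective and solver setup from the previous snippet; this is only a heuristic, not the exact Euclidean projection onto {y : 1^T y = k, 0 <= y <= 1}, and the eps value is an arbitrary choice to keep the entropy term finite:

import jax.numpy as jnp

def projection(y, k, eps=1e-6):
    """Heuristic: clip into (eps, 1 - eps), then rescale so the entries sum to k."""
    y = jnp.clip(y, eps, 1.0 - eps)
    return k * y / jnp.sum(y)

# ProjectedGradient calls projection(params, hyperparams_proj), so
# solver.run(y_init, hyperparams_proj=k, x=x).params works with this signature,
# and the result is then differentiable w.r.t. x.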
Hi,
I am currently learning how to use jaxopt and am trying to adapt the code below to make use of its features. This code originally comes from a research paper and was ported from PyTorch to JAX. In the forward method, a root-finding problem is solved using a bracketing method, while the backward method relies on implicit differentiation through the KKT conditions.
Currently, I am only making use of Jaxopt's root-finding capabilities. I have a specific question: How can I employ Jaxopt to perform differentiation through the KKT conditions?
Please find my current implementation below: