google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
Apache License 2.0

Error when taking gradient wrt parameters in BoxOSQP #588

Open deasmhumhna opened 3 months ago

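import jax
import jax.numpy as jnp
import jaxopt

# Objective 0.5 * z.z + a.z - 1, with params_obj = (a,)
fun = lambda z, params_obj: 0.5 * z @ z + params_obj[0] @ z - 1
# A = I: the box constraint l <= z <= u acts on z directly; params_A is unused
matvec_A = lambda params_A, z: (z, )
solver = jaxopt.BoxOSQP(matvec_A=matvec_A, fun=fun, tol=1e-5)

def test_loss(a):
    params_obj = (jnp.atleast_1d(a),)
    l = (jnp.array([0.]),)
    u = (jnp.array([1.]),)

    init_params = solver.init_params(
        jnp.zeros(1), params_obj=params_obj, params_eq=None,
        params_ineq=(l, u))
    sol = solver.run(init_params, params_obj=params_obj, params_eq=None,
                     params_ineq=(l, u))
    zopt = sol.params.primal[-1][-1]
    return fun(zopt, params_obj)

print(test_loss(jnp.array([-0.5]))) # -1.125
print(jax.grad(test_loss)(jnp.array([1.]))) # error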
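Since this toy problem is one-dimensional with a box constraint, its solution has a closed form, z* = clip(-a, 0, 1), which gives an independent reference for the value and gradient that BoxOSQP's implicit differentiation would be expected to reproduce. A minimal plain-JAX sketch (closed_form_loss is a hypothetical name, not part of jaxopt):

```python
import jax
import jax.numpy as jnp


def fun(z, params_obj):
    # Same objective as in the report: 0.5 * z.z + a.z - 1
    return 0.5 * z @ z + params_obj[0] @ z - 1


def closed_form_loss(a):
    # Hypothetical reference: with an identity Hessian the problem is separable,
    # so the box-constrained minimizer is the unconstrained one (-a) clipped to [0, 1].
    params_obj = (jnp.atleast_1d(a),)
    z_star = jnp.minimum(jnp.maximum(-params_obj[0], 0.0), 1.0)
    return fun(z_star, params_obj)


print(closed_form_loss(jnp.array([-0.5])))            # -1.125
print(jax.grad(closed_form_loss)(jnp.array([1.])))    # [0.]
print(jax.grad(closed_form_loss)(jnp.array([-0.5])))  # [0.5]
```

At a = 1 the lower bound is active, so the reference gradient is 0; at a = -0.5 the solution is interior and the gradient equals z* = 0.5, as the envelope theorem predicts.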

Relevant traceback:

JaxStackTraceBeforeTransformation: TypeError: unsupported operand type(s) for @: 'tuple' and 'tuple'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.


The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
<ipython-input-24-f2a2a3471989> in <cell line: 28>()
     27 print(test_loss(jnp.array([-0.5]))) # -1.0625
---> 28 print(jax.grad(test_loss)(jnp.array([1.]))) # error

    [... skipping hidden 12 frame]

/usr/local/lib/python3.10/dist-packages/jaxopt/_src/ in solver_fun_bwd(tup, cotangent)
    235       # Compute VJPs w.r.t. args.
--> 236       vjps = root_vjp(optimality_fun=optimality_fun, sol=sol,
    237                       args=ba_args[1:], cotangent=cotangent, solve=solve)
    238       # Prepend None as the vjp for init_params.

/usr/local/lib/python3.10/dist-packages/jaxopt/_src/ in root_vjp(optimality_fun, sol, args, cotangent, solve)
     58     return optimality_fun(sol, *args)
---> 60   _, vjp_fun_sol = jax.vjp(fun_sol, sol)
     62   # Compute the multiplication A^T u = (u^T A)^T.

    [... skipping hidden 7 frame]

/usr/local/lib/python3.10/dist-packages/jaxopt/_src/ in fun_sol(sol)
     56   def fun_sol(sol):
     57     # We close over the arguments.
---> 58     return optimality_fun(sol, *args)
     60   _, vjp_fun_sol = jax.vjp(fun_sol, sol)

/usr/local/lib/python3.10/dist-packages/jaxopt/_src/ in optimality_fun(params, params_obj, params_eq, params_ineq)
    352     primal_var, eq_dual_var, ineq_dual_var = params
--> 354     stationarity = grad_fun(primal_var, params_obj)
    356     if eq_dual_var is not None:

    [... skipping hidden 10 frame]

<ipython-input-24-f2a2a3471989> in <lambda>(z, params_obj)
      1 import jaxopt
----> 3 fun = lambda z, params_obj: 0.5 * z @ z + params_obj[0] @ z - 1
      4 matvec_A = lambda params_A, z: (z, )
      5 solver = jaxopt.BoxOSQP(matvec_A=matvec_A, fun=fun, tol=1e-5)

TypeError: unsupported operand type(s) for @: 'tuple' and 'tuple'
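The root_vjp frames above differentiate the optimality conditions rather than the solver iterations: jax.vjp is taken of optimality_fun with respect to sol, and, as the optimality_fun frame shows, that function unpacks sol and calls grad_fun(primal_var, params_obj), so the user's fun is re-invoked inside this VJP with whatever it is handed. A simplified sketch of the idea, assuming an unconstrained version of the toy problem whose optimality condition is F(x, a) = x + a = 0; implicit_vjp is a hypothetical helper and not jaxopt's implementation:

```python
import jax
import jax.numpy as jnp


def fun(x, a):
    # Unconstrained toy objective: 0.5 * x.x + a.x - 1
    return 0.5 * x @ x + a @ x - 1


# Optimality condition F(x, a) = grad_x fun(x, a); it is zero at the minimizer.
optimality_fun = jax.grad(fun, argnums=0)


def implicit_vjp(x_star, a, cotangent):
    """Hypothetical helper: VJP of a -> x*(a) via the implicit function theorem.

    From F(x*(a), a) = 0: dx*/da = -(dF/dx)^{-1} dF/da, so
    v^T dx*/da = -u^T dF/da, where (dF/dx)^T u = v.
    """
    # Dense Jacobian dF/dx at the solution (acceptable for a tiny example).
    jac_x = jax.jacobian(optimality_fun, argnums=0)(x_star, a)
    # Solve the transposed linear system (dF/dx)^T u = v.
    u = jnp.linalg.solve(jac_x.T, cotangent)
    # Pull u back through F viewed as a function of a only.
    _, vjp_a = jax.vjp(lambda a_: optimality_fun(x_star, a_), a)
    return -vjp_a(u)[0]


a = jnp.array([0.3, -0.7])
x_star = -a  # minimizer of the unconstrained objective
print(implicit_vjp(x_star, a, jnp.ones(2)))  # [-1. -1.]
```

Here x*(a) = -a, so the result matches differentiating that map directly; the point is that only F and its VJPs are needed, never the solver's iterations.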

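Doesn't optimality_fun/grad_fun adapt the original fun so that it handles tangents properly? During the backward pass, the traceback shows fun being called with tuple operands where it expects arrays.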
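One way to picture what "adapting fun" could mean: if the primal variable is a pair (x, z) with z = Ax (the snippet's sol.params.primal[-1][-1] indexing is consistent with that layout, but this is an assumption), the objective would need to be lifted so it reads only x and treats z as unused. A hypothetical adapter, lift_fun, sketching that idea (not jaxopt's code):

```python
import jax
import jax.numpy as jnp


def fun(z, params_obj):
    # The user's objective, written for a single array argument.
    return 0.5 * z @ z + params_obj[0] @ z - 1


def lift_fun(fun):
    """Hypothetical adapter: evaluate fun on x only, assuming primal = (x, z)."""
    def lifted(primal, params_obj):
        x, _z = primal  # the slack z = Ax does not enter the objective
        return fun(x, params_obj)
    return lifted


x = jnp.array([0.2])
primal = (x, (x,))  # assumed (x, z) layout, with z a 1-tuple as returned by matvec_A
params_obj = (jnp.array([-0.5]),)

# The gradient has the same (x, z) structure; the z part is all zeros.
grad_x, grad_z = jax.grad(lift_fun(fun))(primal, params_obj)
print(grad_x)     # [-0.3]
print(grad_z[0])  # [0.]
```

Calling fun(primal, params_obj) directly instead would fail with a TypeError, since fun then tries arithmetic on a tuple.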

I can successfully get the gradient using the (Q, c) and matvec_Q paths, and I could write my actual function in either form. However, I imagine this would be difficult for other operations, which I assume is the reason for including fun.
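For this particular objective the matvec_Q path is a direct translation: BoxOSQP's quadratic objective is 0.5 * x^T Q x + c^T x, so Q is the identity and c = a, while the constant -1 shifts only the value, not the minimizer. A plain-JAX sketch that checks the translation pointwise, assuming the quadratic form is parameterized as params_obj = (params_Q, c) when matvec_Q is given; quad_obj is a hypothetical helper:

```python
import jax
import jax.numpy as jnp


def fun(z, params_obj):
    # Objective as written in the report.
    return 0.5 * z @ z + params_obj[0] @ z - 1


# Identity Hessian as a matrix-free product; params_Q is unused.
matvec_Q = lambda params_Q, x: x


def quad_obj(x, params_obj):
    # Hypothetical helper: 0.5 * x^T Q x + c^T x, with Q applied through matvec_Q.
    params_Q, c = params_obj
    return 0.5 * x @ matvec_Q(params_Q, x) + c @ x


a = jnp.array([-0.5])
xs = jnp.linspace(0.0, 1.0, 5)[:, None]  # sample points inside the box [0, 1]
diff = jax.vmap(lambda x: fun(x, (a,)) - quad_obj(x, (None, a)))(xs)
print(diff)  # approximately -1 everywhere: only the constant differs
```

Because the two objectives differ by a constant, their gradients in x agree, which is all the solver and its implicit differentiation depend on.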