Open nomoriel opened 1 year ago
Thanks for the bug report. I don't see any mistake in your code a priori. I tried to run your code and got a linear transpose rule error:
Traceback (most recent call last):
  File "/Users/mblondel/Desktop/playground/pg_bug.py", line 122, in <module>
    print(jax.grad(outer_with_proj, argnums=1)(g0, theta))
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/api.py", line 659, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/api.py", line 741, in value_and_grad_f
    g = vjp_py(lax_internal._one(ans))
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/tree_util.py", line 303, in __call__
    return self.fun(*args, **kw)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/api.py", line 2183, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/tree_util.py", line 303, in __call__
    return self.fun(*args, **kw)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 146, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 253, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 751, in _custom_lin_transpose
    cts_in = bwd(*res, *cts_out)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/custom_derivatives.py", line 683, in <lambda>
    bwd_ = lambda *args: bwd(*args)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/mblondel/Desktop/projects/jaxopt/jaxopt/_src/implicit_diff.py", line 236, in solver_fun_bwd
    vjps = root_vjp(optimality_fun=optimality_fun, sol=sol,
  File "/Users/mblondel/Desktop/projects/jaxopt/jaxopt/_src/implicit_diff.py", line 69, in root_vjp
    u = solve(matvec, v)
  File "/Users/mblondel/Desktop/projects/jaxopt/jaxopt/_src/linear_solve.py", line 193, in solve_normal_cg
    Ab = rmatvec(b)  # A.T b
  File "/Users/mblondel/Desktop/projects/jaxopt/jaxopt/_src/linear_solve.py", line 145, in <lambda>
    return lambda y: transpose(y)[0]
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/api.py", line 2353, in transposed_fun
    in_cts = ad.backward_pass(jaxpr, reduce_axes, True, const, dummies, out_cts)
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 253, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/Users/mblondel/envs/work/lib/python3.9/site-packages/jax/_src/lax/control_flow/solves.py", line 370, in _linear_solve_transpose_rule
Actually, you're trying something a bit advanced. The issue is that `projection_polyhedron` calls our OSQP solver. Therefore, computing the Jacobian of `projection_polyhedron` itself requires implicit differentiation: there is a nested implicit differentiation inside the implicit differentiation you're trying to do, and it seems like the two don't compose correctly at the moment. I think second-order derivatives probably do not work correctly in OSQP at the moment.
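
To make the nesting concrete, here is the standard implicit function theorem computation, as a sketch in generic notation. The symbols F (optimality condition), x* (solution), θ (parameters), v (incoming cotangent) and u (solution of the linear system) are illustrative and not jaxopt's own names.

```latex
F(x^\star(\theta), \theta) = 0
\;\Longrightarrow\;
\partial_x F \, \frac{\partial x^\star}{\partial \theta} + \partial_\theta F = 0
\;\Longrightarrow\;
\frac{\partial x^\star}{\partial \theta} = -(\partial_x F)^{-1} \, \partial_\theta F,
\qquad
v^\top \frac{\partial x^\star}{\partial \theta} = -u^\top \partial_\theta F
\quad \text{where} \quad (\partial_x F)^\top u = v .
```

Judging from the traceback above, the linear system is solved by `solve_normal_cg`, which needs the transpose of the matvec. When F itself contains a projection that is differentiated implicitly, that matvec appears to include another linear solve, and transposing it is where the trace ends (`_linear_solve_transpose_rule`). This is only a reading of the trace, not a confirmed diagnosis.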
CC: @Algue-Rythme @froystig
As a temporary workaround, you can set `implicit_diff=False` in `ProjectedGradient`. This way, there is only one implicit differentiation involved, the one of `projection_polyhedron`.
Gotcha! The workaround works :) thank you very much!
Hi! Thank you for this very cool repo. Sorry, I'm new to "Issues" and to JAX.
I'm getting a `JaxStackTraceBeforeTransformation` error when trying to take the derivative with a `ProjectedGradient` solver involved. The solver itself runs flawlessly, and derivatives seem to work without the solver. I want to later use `outer_with_proj` (with a slightly different objective) for implicit differentiation of `theta`. Any advice?
The code:
The error:
Thanks again!