Sorry, my bad. For implicit differentiation I saw that there are two passes through custom_linear_solve (which you had mentioned before, in fact), so I guess the two passes from _iterations_implicit_bwd are not needed, and I also suppose that the custom_vjp is seen during rederivation. What is still unclear is why there are these initial passes through scan_jvp and _while_loop_jvp. (In fact, even for jax.jacfwd(jax.jacrev) it goes through both scan_jvp and _while_loop_jvp.) If there is an explanation for that, it would be very helpful. Do you think it is possible that going through _while_loop_jvp should not be happening, e.g. that a stop_gradient is missing somewhere? Also, if the forced scan is inevitable, maybe it would make sense to allow a different number of iterations for Pxy than for Pxx and Pyy? Many thanks.
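For what it's worth, the forward/reverse asymmetry around the loop can be reproduced on a bare while_loop in isolation; a minimal sketch (toy function, unrelated to the OTT code):

```python
import jax

# Toy value computed with lax.while_loop, standing in for fixpoint_iter,
# just to show which differentiation modes the loop itself supports.
def f(theta):
    return jax.lax.while_loop(lambda x: x < 10.0, lambda x: x * theta, 1.0)

print(jax.jacfwd(f)(1.5))  # works: while_loop has a JVP rule, hence the _while_loop_jvp pass
# jax.jacrev(f)(1.5)       # would raise: while_loop alone is not reverse-mode
                           # differentiable, which is why a custom_vjp or a scan is needed
```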
Actually, it seems that the two passes (in the case of implicit differentiation) through custom_linear_solve are due to the solution of the two linear systems corresponding to the two Schur complements. Also, the passes through scan_jvp and _while_loop_jvp are, I think, due to the fact that both scan and while_loop are used in the fixpoint_iter implementation. Therefore, I am under the impression that the custom_vjp is in fact not being used during rederivation.
Although the unrolling case perhaps falls under what is described in the JAX documentation (quoted in my first comment above), because the checkpointed states ("intermediate values of f") are used in the vjp rule, this should not be the case with implicit diff, right? In that case only the optimal (final) state is used in the vjp rule.
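To make that last point concrete, here is a small sketch (toy map T and illustrative names, not the Sinkhorn updates) of why the implicit-diff vjp rule only ever needs the optimal state x*:

```python
import jax
import jax.numpy as jnp

# Implicit differentiation of a fixed point x* = T(x*, theta): the cotangent
# for theta is (dT/dtheta)^T (I - dT/dx)^{-T} g, built only from x*.
def T(x, theta):
    return 0.5 * jnp.tanh(x) + theta          # contraction in x

def implicit_vjp(theta, x_star, g):
    dT_dx = jax.jacobian(T, argnums=0)(x_star, theta)
    dT_dtheta = jax.jacobian(T, argnums=1)(x_star, theta)
    u = jnp.linalg.solve((jnp.eye(x_star.size) - dT_dx).T, g)   # adjoint linear solve
    return dT_dtheta.T @ u

theta = jnp.array([0.1, 0.2])
x_star = theta
for _ in range(200):                           # crude forward fixed-point solve
    x_star = T(x_star, theta)
print(implicit_vjp(theta, x_star, jnp.ones(2)))
```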
Thanks in advance and apologies for the multiple comments :)
Hi Ersi, thanks a lot for your comments, and apologies for the late reply. You are right indeed to count those two linear solves in the solve function in implicit_differentiation.py. These two smaller linear solves are there to solve the larger system more efficiently. When differentiating again (computing the Hessian), I expect the solutions to these linear systems to be differentiated again, but this would call the custom differentiation rules for linear systems (as described here). Have you found anything else that's suspicious or buggy in that pipeline? I do expect some numerical instabilities to arise (implicit diff of those linear systems will obviously be impacted by bad conditioning of these systems), have you experienced those?
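As a small sanity check (a toy system, not the OTT code), those built-in rules of lax.custom_linear_solve are what keep a linear solve differentiable at higher orders: each extra differentiation triggers extra (transposed) solves rather than unrolling the solver.

```python
import jax
import jax.numpy as jnp

# x(theta) = A(theta)^{-1} b(theta), solved matrix-free with CG inside
# lax.custom_linear_solve; derivatives go through the solve's custom rules.
def solution(theta):
    A = jnp.array([[2.0 + theta, 1.0], [1.0, 3.0]])   # symmetric, positive definite
    b = jnp.array([1.0, theta])
    matvec = lambda v: A @ v
    cg_solve = lambda mv, rhs: jax.scipy.sparse.linalg.cg(mv, rhs)[0]
    return jax.lax.custom_linear_solve(matvec, b, solve=cg_solve, symmetric=True)

print(solution(0.1))
print(jax.jacobian(solution)(0.1))   # first derivative: one more linear solve
print(jax.hessian(solution)(0.1))    # second derivative: the rules are invoked again
```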
Hi @marcocuturi and thanks for the reply!
re: custom_vjp and re-derivation: It seems that by adding an @custom_vjp decorator to _iterations_implicit and then defining the custom_vjp rule as _iterations_implicit.defvjp(_iterations_taped, _iterations_implicit_bwd), rederivation is possible using the custom_vjp rule of implicit diff, and one does not get the lax.scan error. A similar approach can be used for unrolling: one can define an iterations function in sinkhorn.py with an @custom_vjp decorator. In that case, a bit more care is needed in the iterations_bwd definition, because two more pull-backs are needed (due to the inputs/outputs of fixpoint_iter_fwd and fixpoint_iter_bwd). However, for unrolling, even if the custom_vjp is seen during rederivation, one again gets the lax.scan error. So I think that for unrolling it is not really possible to use the custom_vjp in re-derivation. This could be related to the saving of intermediate states that I mentioned before. I would not call this a bug, though: the Hessian computation is correct (as shown also in the tests). It simply would be nice to have a lighter way of computing higher-order derivatives.
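For concreteness, here is a rough sketch of the wrapping pattern described above (toy fixed point and illustrative names, not the actual OTT functions): the forward rule calls the decorated function itself, so a second reverse-mode differentiation re-enters the custom rule instead of trying to reverse-differentiate the underlying while_loop.

```python
import jax
import jax.numpy as jnp

def _run(theta):
    # stand-in for the Sinkhorn loop: iterate x <- 0.5*cos(x) + theta to convergence
    def cond(carry):
        x, x_prev = carry
        return jnp.abs(x - x_prev) > 1e-6
    def body(carry):
        x, _ = carry
        return 0.5 * jnp.cos(x) + theta, x
    x_star, _ = jax.lax.while_loop(cond, body, (theta, theta + 1.0))
    return x_star

@jax.custom_vjp
def fixed_point(theta):
    return _run(theta)

def fixed_point_fwd(theta):
    x_star = fixed_point(theta)   # call the decorated primal, not _run directly
    return x_star, x_star         # residual: only the optimal state

def fixed_point_bwd(x_star, g):
    # implicit differentiation of x = 0.5*cos(x) + theta at the solution
    return (g / (1.0 + 0.5 * jnp.sin(x_star)),)

fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)

print(jax.grad(fixed_point)(0.3))               # first order: uses the custom rule
print(jax.jacrev(jax.grad(fixed_point))(0.3))   # reverse-over-reverse: no while_loop error
```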
re: conditioning of the linear system: Since one in fact needs to tune the ridge parameters to ensure that the linear system is well-conditioned in implicit diff (both for the gradient and for higher orders), I am using unrolling for the computation of the higher-order derivatives, to be on the safe side :) .
I might make a PR for the re-derivation with implicit diff once I look into it more carefully. I do think it would be nice to have a more computationally efficient way to compute higher order derivatives :)
That would be an excellent contribution for sure. I will leave the issue open in case there's some news on your side.
Closing this for now, can reopen later.
Hi @ersisimou, @marcocuturi, @michalk8. We looked into this bad-conditioning issue when computing higher-order derivatives of the entropic OT distance. We derived the analytical expression of the Hessian and found that it involves the pseudo-inverse of [diag(P1), P; P^T, diag(P^T1)]. The numerical instability is due to the large condition number of that matrix. The stable and proper way of conditioning it is to perform a truncated SVD on this matrix.
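For reference, a rough sketch of that conditioning step (illustrative code only, not our actual implementation; the truncation threshold rtol is an arbitrary choice here):

```python
import jax.numpy as jnp

# Build M = [diag(P1), P; P^T, diag(P^T 1)] and form a pseudo-inverse by
# truncating the small singular values that cause the ill-conditioning.
def truncated_pinv_block(P, rtol=1e-6):
    M = jnp.block([
        [jnp.diag(P.sum(axis=1)), P],
        [P.T, jnp.diag(P.sum(axis=0))],
    ])
    U, s, Vt = jnp.linalg.svd(M, full_matrices=False)
    keep = s > rtol * s[0]
    s_inv = jnp.where(keep, 1.0 / s, 0.0)   # drop the ill-conditioned directions
    return Vt.T @ (s_inv[:, None] * U.T)
```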
We tested it with point-cloud datasets sampled from the uniform distribution on the unit square with N = 10, 20, 120, 1600 and eps = 0.005. Both the unroll and implicit methods fail in many test runs, due either to ill-conditioning or to memory limitations. Our analytical expression with truncated SVD not only passes all test runs but is also the fastest and most accurate of all.
Our results are posted on arXiv: http://arxiv.org/abs/2407.02015. I hope this work will draw your attention.
Hello,

It seems that the custom_vjp (defined either with _iterations_implicit.defvjp(_iterations_taped, _iterations_implicit_bwd) in sinkhorn.py for the implicit differentiation, or with fixpoint_iter_backprop.defvjp(fixpoint_iter_fwd, fixpoint_iter_bwd) in fixed_point_loop for the unrolling) is not used for the computation of higher-order derivatives. I believe that this is also related to the error discussed previously in this issue. This is not evident with the computation of jax.hessian (which is implemented as jax.jacfwd(jax.jacrev)), since the second differentiation uses forward-mode automatic differentiation, which is compatible with jax.lax.while_loop. Therefore, no error is raised even if the custom_vjp is ignored the second time.

I think this can be seen by adding a breakpoint in the _while_loop_jvp of JAX's control_flow.py. For instance, the computation of jax.hessian passes through _while_loop_jvp first, then once through fixpoint_iter for implicit diff (or fixpoint_iter_fwd for unrolling), and then once through _iterations_implicit_bwd for implicit diff (or fixpoint_iter_bwd for unrolling). In the case where a jax.lax.scan is forced by setting min_iterations equal to max_iterations (and both Jacobians are computed with reverse mode), instead of the initial pass through _while_loop_jvp one gets, at the end, a pass through _scan_transpose of JAX's control_flow.py.

In either case, I think that if the custom_vjp were not ignored during rederivation, one should get two passes through _iterations_implicit_bwd (equivalently fixpoint_iter_bwd), right?

I am not sure how this could be fixed. The only relevant info I could find in JAX's documentation for custom_vjp was this:
"Notice that f_jvp calls f to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original f to compute the primal outputs. (This represents a kind of fundamental tradeoff, where we can't make use of intermediate values from the evaluation of f in our rule and also have the rule apply in all orders of higher-order differentiation.)"
I hope I am not missing something. If there is a solution to this it would be great, as what I need to compute is jax.jacrev(jax.grad) for the Sinkhorn divergence, and the forced scan option causes a significant computational overhead, especially for the two autocorrelation terms Pxx, Pyy. Many thanks!