ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

custom_vjp seems to not be used during computation of higher order derivatives of Sinkhorn #15

Closed ersisimou closed 2 years ago

ersisimou commented 2 years ago

Hello,

It seems that the custom_vjp (defined either with _iterations_implicit.defvjp(_iterations_taped, _iterations_implicit_bwd) in sinkhorn.py for implicit differentiation, or with fixpoint_iter_backprop.defvjp(fixpoint_iter_fwd, fixpoint_iter_bwd) in fixed_point_loop for unrolling) is not used when computing higher-order derivatives. I believe this is also related to the error discussed previously in this issue. The problem is not visible when computing jax.hessian, which is implemented as jax.jacfwd(jax.jacrev(...)): the second differentiation uses forward-mode automatic differentiation, which is compatible with jax.lax.while_loop. Therefore, no error is raised even if the custom_vjp is ignored the second time.
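A minimal JAX sketch of the asymmetry described above, using a toy Newton iteration for sqrt(a) inside a plain lax.while_loop rather than Sinkhorn (newton_sqrt_while and reverse_mode_fails are illustrative names, not ott code). Forward mode (jax.jacfwd) differentiates through the loop, while reverse mode (jax.grad) raises an error, which is why a forward-over-reverse jax.hessian can run without complaint.

```python
import jax
import jax.numpy as jnp
from jax import lax


def newton_sqrt_while(a):
    """sqrt(a) via Newton steps in a plain lax.while_loop (no custom rule)."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def cond(carry):
        x_prev, x = carry
        return jnp.abs(x - x_prev) > 1e-6  # stop once iterates stop moving

    def body(carry):
        _, x = carry
        return x, 0.5 * (x + a / x)

    _, x = lax.while_loop(cond, body, (a, 0.5 * (a + a / a)))
    return x


def reverse_mode_fails(f, x):
    """True if jax.grad(f)(x) raises, as it does for reverse mode through while_loop."""
    try:
        jax.grad(f)(x)
    except Exception:
        return True
    return False


# Forward mode differentiates through the while_loop: approx 1 / (2 * sqrt(2)).
print(jax.jacfwd(newton_sqrt_while)(2.0))
# Reverse mode raises inside JAX, so this prints True.
print(reverse_mode_fails(newton_sqrt_while, 2.0))
```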

I think this can be seen by adding a breakpoint in _while_loop_jvp in JAX's control_flow.py. For instance, computing jax.hessian first passes through _while_loop_jvp, then once through fixpoint_iter for implicit diff (or fixpoint_iter_fwd for unrolling), and then once through _iterations_implicit_bwd for implicit diff (or fixpoint_iter_bwd for unrolling). When a jax.lax.scan is forced by setting min_iterations equal to max_iterations (and both Jacobians are computed in reverse mode), there is no initial pass through _while_loop_jvp; instead, there is a final pass through _scan_transpose in JAX's control_flow.py.
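Forcing a fixed iteration count turns the loop into a lax.scan, which reverse mode can differentiate to any order because it stores every iterate. The sketch below uses a hypothetical newton_sqrt_scan with a fixed num_iters as a toy stand-in for Sinkhorn with min_iterations == max_iterations; it shows jax.jacrev(jax.grad(...)) working, at the cost of keeping num_iters intermediate states in memory.

```python
import jax
import jax.numpy as jnp
from jax import lax


def newton_sqrt_scan(a, num_iters=20):
    """sqrt(a) via a fixed number of Newton steps in lax.scan (unrolled AD)."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def step(x, _):
        # One Newton step; reverse mode keeps each x as a residual.
        return 0.5 * (x + a / x), None

    x, _ = lax.scan(step, a, xs=None, length=num_iters)
    return x


# Reverse-over-reverse through the scan: approx -0.25 * 2 ** -1.5.
print(jax.jacrev(jax.grad(newton_sqrt_scan))(2.0))
```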

In either case, I think that if the custom_vjp were not ignored during re-derivation, one should see two passes through _iterations_implicit_bwd (equivalently, fixpoint_iter_bwd), right?

I am not sure how this could be fixed. The only relevant info that I could find in JAX's documentation for custom_vjp was this:

"Notice that f_jvp calls f to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original f to compute the primal outputs. (This represents a kind of fundamental tradeoff, where we can't make use of intermediate values from the evaluation of f in our rule and also have the rule apply in all orders of higher-order differentiation.)"
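To illustrate the tradeoff in the quoted passage, here is a sketch with jax.custom_jvp (the names _exp_series, exp_good and exp_bad are hypothetical). Both functions compute exp(x) with a while_loop that reverse mode cannot differentiate. The rule of exp_good obtains its primal output by calling exp_good itself, so the rule is reapplied at every order and jax.grad(jax.grad(...)) works. The rule of exp_bad calls the raw loop instead, so the first derivative works but the second one falls back to differentiating the while_loop in reverse mode and fails.

```python
import jax
import jax.numpy as jnp
from jax import lax


def _exp_series(x):
    """exp(x) by summing its Taylor series inside a lax.while_loop."""
    x = jnp.asarray(x, dtype=jnp.float32)
    one = jnp.ones_like(x)

    def cond(carry):
        _, term, _ = carry
        return jnp.abs(term) > 1e-8  # stop when the next term is negligible

    def body(carry):
        k, term, total = carry
        k = k + 1.0
        term = term * x / k
        return k, term, total + term

    _, _, total = lax.while_loop(cond, body, (jnp.zeros_like(x), one, one))
    return total


@jax.custom_jvp
def exp_good(x):
    return _exp_series(x)


@exp_good.defjvp
def _exp_good_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = exp_good(x)  # primal via the custom_jvp function: rule reused at all orders
    return ans, ans * x_dot


@jax.custom_jvp
def exp_bad(x):
    return _exp_series(x)


@exp_bad.defjvp
def _exp_bad_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = _exp_series(x)  # primal via the raw loop: the rule is bypassed next order
    return ans, ans * x_dot


print(jax.grad(jax.grad(exp_good))(1.0))  # approx e
```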

I hope I am not missing something. If there is a solution to this, it would be great, since what I need to compute is jax.jacrev(jax.grad(...)) for the Sinkhorn divergence, and the forced-scan option causes significant computational overhead, especially for the two autocorrelation terms Pxx and Pyy.

Many thanks!

ersisimou commented 2 years ago

Sorry, my bad. For implicit differentiation I saw that there are two passes through custom_linear_solve (which you had in fact mentioned before). So I guess the two passes through _iterations_implicit_bwd are not needed, and I suppose custom_vjp is used during re-derivation. What is still unclear is why there are these initial passes through scan_jvp and _while_loop_jvp (in fact, even for jax.jacfwd(jax.jacrev(...)) the computation goes through both scan_jvp and _while_loop_jvp). An explanation would be very helpful. Do you think there is any possibility that the pass through _while_loop_jvp should not be happening? For example, maybe there should be a stop_gradient somewhere? Also, if the forced scan is unavoidable, might it make sense to use a different number of iterations for Pxy than for Pxx and Pyy? Many thanks

ersisimou commented 2 years ago

Actually, it seems that the two passes through custom_linear_solve (in the case of implicit differentiation) come from solving the two linear systems corresponding to the two Schur complements. Also, I think the passes through scan_jvp and _while_loop_jvp are due to the fact that the fixpoint_iter implementation uses both scan and while_loop. Therefore, I am under the impression that the custom_vjp is in fact not being used during re-derivation.

Unrolling may fall under the case described in the JAX documentation (quoted in my first comment above), because the checkpointed states ("intermediate values of f") are used in the vjp rule. But this should not be the case for implicit diff, right? There, only the optimal (final) state is used in the vjp rule.

Thanks in advance and apologies for the multiple comments :)

marcocuturi commented 2 years ago

Hi Ersi, thanks a lot for your comments, and apologies for the late reply. You are indeed right to count those 2 linear solves in the solve function in implicit_differentiation.py. These two small linear solves are there to solve the larger system more efficiently. When differentiating again (computing the Hessian), I expect the solutions to these linear systems to be differentiated again, but this would call the custom differentiation rules for linear systems (as described here). Have you found anything else that's suspicious or buggy in that pipeline? I do expect some numerical instabilities to arise (implicit differentiation of those linear systems will obviously be affected by poor conditioning of these systems); have you experienced those?
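A sketch of the Schur-complement idea, assuming a block system of the form [[diag(a), P], [P^T, diag(b) + ridge * I]] with a = P 1 and b = P^T 1. The ridge placement and the function name block_solve_schur are assumptions for illustration; this is not the code in ott's implicit_differentiation.py. Eliminating x reduces the (n + m)-dimensional solve to a single m-dimensional one.

```python
import jax
import jax.numpy as jnp


def block_solve_schur(P, u, v, ridge=1e-1):
    """Solve [[diag(a), P], [P^T, diag(b) + ridge*I]] [x; y] = [u; v].

    Here a = P 1 and b = P^T 1 (hypothetical ridge placement on the second block).
    """
    a = P.sum(axis=1)
    b = P.sum(axis=0)
    # Schur complement of diag(a): diag(b) - P^T diag(1/a) P, plus the ridge term.
    schur = jnp.diag(b) - P.T @ (P / a[:, None]) + ridge * jnp.eye(P.shape[1])
    y = jnp.linalg.solve(schur, v - P.T @ (u / a))
    # Back-substitute into the first block row: diag(a) x + P y = u.
    x = (u - P @ y) / a
    return x, y
```

Without the ridge term the reduced matrix is singular (each of its rows sums to zero), which is one way to see why the ridge parameters discussed in this thread matter for conditioning.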

ersisimou commented 2 years ago

Hi @marcocuturi and thanks for the reply!

re: custom_vjp and re-derivation: It seems that by adding a @custom_vjp decorator to _iterations_implicit and then defining the custom_vjp rule as _iterations_implicit.defvjp(_iterations_taped, _iterations_implicit_bwd), re-derivation uses the custom_vjp rule of implicit diff and one does not get the lax.scan error. A similar approach can be used for unrolling: one can define an iterations function in sinkhorn.py with a @custom_vjp decorator. In that case, a bit more care is needed in the iterations_bwd definition, because you need two more pull-backs (due to the inputs and outputs of fixpoint_iter_fwd and fixpoint_iter_bwd). However, for unrolling one again gets the lax.scan error, even though the custom_vjp is seen during re-derivation. So I think that for unrolling it is not possible to really use the custom_vjp in re-derivation. This could be related to the saving of intermediate states that I mentioned before. However, I would not call this a bug: the Hessian computation is correct (as also shown in the tests). It would simply be nice to have a lighter way of computing higher-order derivatives.
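A self-contained sketch of the same pattern on a toy problem, adapted from the fixed-point example in JAX's custom derivatives tutorial rather than from ott's code. The forward rule calls the custom_vjp function itself, and the backward rule solves the adjoint (implicit-function) equation with the same fixed_point, so reverse-over-reverse differentiation reuses the implicit rule instead of differentiating the while_loop. All names are illustrative, and the stopping test assumes scalar iterates.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_init):
    """Iterate x <- f(a, x) in a while_loop until successive (scalar) iterates agree."""
    def cond(carry):
        x_prev, x = carry
        return jnp.abs(x - x_prev) > 1e-6

    def body(carry):
        _, x = carry
        return x, f(a, x)

    _, x_star = lax.while_loop(cond, body, (x_init, f(a, x_init)))
    return x_star


def fixed_point_fwd(f, a, x_init):
    # Calls fixed_point itself, so the custom rule applies again when re-differentiated.
    x_star = fixed_point(f, a, x_init)
    return x_star, (a, x_star)


def _adjoint_step(f, packed, u):
    a, x_star, x_star_bar = packed
    _, vjp_x = jax.vjp(lambda x: f(a, x), x_star)
    return x_star_bar + vjp_x(u)[0]


def fixed_point_bwd(f, res, x_star_bar):
    a, x_star = res
    _, vjp_a = jax.vjp(lambda a_: f(a_, x_star), a)
    # Adjoint equation u = x_star_bar + (df/dx)^T u, solved with the same fixed_point.
    u = fixed_point(partial(_adjoint_step, f), (a, x_star, x_star_bar), x_star_bar)
    (a_bar,) = vjp_a(u)
    return a_bar, jnp.zeros_like(x_star)


fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)


def newton_sqrt(a):
    a = jnp.asarray(a, dtype=jnp.float32)
    return fixed_point(lambda a_, x: 0.5 * (x + a_ / x), a, a)


print(jax.grad(jax.grad(newton_sqrt))(2.0))  # approx -0.25 * 2 ** -1.5
```

Only the final state x_star is kept as a residual, which matches the point made earlier in the thread that implicit differentiation should not depend on intermediate values of the loop.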

re: conditioning of the linear system: Since one in fact needs to tune the ridge parameters to ensure that the linear system in implicit diff is well-conditioned (both for the gradient and for higher orders), I am using unrolling to compute higher-order derivatives, to be on the safe side :) .

I might make a PR for the re-derivation with implicit diff once I look into it more carefully. I do think it would be nice to have a more computationally efficient way to compute higher order derivatives :)

marcocuturi commented 2 years ago

That would be an excellent contribution for sure. I will leave the issue open in case there's some news on your side.

marcocuturi commented 2 years ago

Closing this for now, can reopen later.

yexf308 commented 3 weeks ago

Hi, @ersisimou, @marcocuturi, @michalk8. We looked into this bad-conditioning issue when computing higher-order derivatives of the entropic OT distance. We derived an analytical expression for the Hessian and found that it involves the pseudo-inverse of [diag(P1), P; P^T, diag(P^T1)], where 1 is the all-ones vector. The numerical instability is due to the large condition number of that matrix. The stable and proper way to handle this conditioning is to apply a truncated SVD to this matrix.
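A generic sketch of a truncated-SVD pseudo-inverse for the block matrix above; it is not the method from the linked paper, and the helper names and the relative threshold rel_tol are assumptions. Because diag(P1) 1 = P1 and P^T 1 = diag(P^T1) 1, the vector (1, ..., 1, -1, ..., -1) always lies in the null space, so the matrix is singular and only a pseudo-inverse exists.

```python
import jax
import jax.numpy as jnp


def block_matrix(P):
    """[diag(P 1), P; P^T, diag(P^T 1)] for a transport plan P of shape (n, m)."""
    return jnp.block([
        [jnp.diag(P.sum(axis=1)), P],
        [P.T, jnp.diag(P.sum(axis=0))],
    ])


def truncated_pinv(M, rel_tol=1e-5):
    """Pseudo-inverse keeping only singular values above rel_tol * largest (hypothetical tol)."""
    U, s, Vt = jnp.linalg.svd(M, full_matrices=False)  # s is sorted in descending order
    keep = s > rel_tol * s[0]
    # Invert the kept singular values; zero out the rest instead of dividing by ~0.
    s_inv = jnp.where(keep, 1.0 / jnp.where(keep, s, 1.0), 0.0)
    return (Vt.T * s_inv) @ U.T
```

Since this matrix is symmetric positive semidefinite, jnp.linalg.eigh could replace the SVD with the same truncation logic.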

We tested it on point-cloud datasets sampled from the uniform distribution on the unit square with N = 10, 20, 120, 1600 and eps = 0.005. Both the unroll and implicit methods fail in many test runs, either due to ill-conditioning or memory limitations. Our analytical expression with truncated SVD not only passes all test runs but is also the fastest and the most accurate of all the methods.

Our results are posted on arXiv: http://arxiv.org/abs/2407.02015. I hope this work will be of interest to you.
