ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

sink div for LR #568

Closed: marcocuturi closed this pull request 3 months ago

marcocuturi commented 3 months ago

A simple patch for the issue described in https://github.com/ott-jax/ott/issues/485, for when a user wishes to compute a Sinkhorn divergence using LR Sinkhorn as the primitive.

marcocuturi commented 3 months ago

Checking whether this seems OK; if it does, I will write some tests to check non-negativity, in accordance with the paper.
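
For context, a minimal sketch (not this PR's API or tests) of the quantity such a test would check: the divergence assembled by hand from three LR Sinkhorn solves, using OT(x, y) - (OT(x, x) + OT(y, y)) / 2.

```python
# Sketch only: assemble a Sinkhorn divergence from three LR Sinkhorn solves.
# Illustrative; the PR's goal is to let the sinkhorn_divergence tool do this
# in a single call with LR Sinkhorn as the primitive.
import jax

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr


def lr_reg_ot_cost(x, y, rank=3):
  """Regularized OT cost between point clouds x and y, via LR Sinkhorn."""
  geom = pointcloud.PointCloud(x, y)
  prob = linear_problem.LinearProblem(geom)
  out = sinkhorn_lr.LRSinkhorn(rank=rank)(prob)
  return out.reg_ot_cost


rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (13, 2))
y = jax.random.normal(rng_y, (15, 2)) + 1.0

div = lr_reg_ot_cost(x, y) - 0.5 * (lr_reg_ot_cost(x, x) + lr_reg_ot_cost(y, y))
# A non-negativity test would check div >= 0 (up to solver tolerance).
print("LR Sinkhorn divergence:", div)
```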

codecov[bot] commented 3 months ago

Codecov Report

Attention: Patch coverage is 90.90909% with 3 lines in your changes missing coverage. Please review.

Project coverage is 88.41%. Comparing base (478f034) to head (b66005d). Report is 26 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/ott/tools/sinkhorn_divergence.py | 90.00% | 2 missing, 1 partial |
Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main     #568      +/-   ##
==========================================
+ Coverage   88.36%   88.41%   +0.05%
==========================================
  Files          72       72
  Lines        7699     7716      +17
  Branches     1102     1107       +5
==========================================
+ Hits         6803     6822      +19
+ Misses        745      743       -2
  Partials      151      151
```

| Files with missing lines | Coverage Δ |
|---|---|
| src/ott/math/utils.py | 91.30% <100.00%> (+0.19%) ↑ |
| src/ott/solvers/linear/sinkhorn_lr.py | 98.98% <100.00%> (+0.32%) ↑ |
| src/ott/tools/progot.py | 29.07% <ø> (ø) |
| src/ott/tools/sinkhorn_divergence.py | 91.66% <90.00%> (+0.62%) ↑ |

... and 1 file with indirect coverage changes.
michalk8 commented 3 months ago

@marcocuturi I've modified one test, but differentiating w.r.t. the divergence is not yet working (problems when unrolling):

src/ott/math/fixed_point_loop.py:226: in fixpoint_iter_bwd
    _, g_state, g_constants = jax.lax.while_loop(
src/ott/math/fixed_point_loop.py:209: in unrolled_body_fn
    _, pullback = jax.vjp(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

iteration = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=3/0)>
constants = (<ott.problems.linear.linear_problem.LinearProblem object at 0x3436e2d10>, <ott.solvers.linear.sinkhorn_lr.LRSinkhorn object at 0x3436e2d70>)
state = LRSinkhornState(q=Traced<ShapedArray(float32[13,2])>with<JVPTrace(level=5/0)> with
  primal = Traced<ShapedArray(float...), None)
    recipe = LambdaBinding(), crossed_threshold=Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=3/0)>)

    def unrolled_body_fn_no_errors(iteration, constants, state):
      compute_error_flags = jnp.zeros((inner_iterations,), dtype=bool)

      def one_iteration(iteration_state, compute_error):
        iteration, state = iteration_state
        state = body_fn(iteration, constants, state, compute_error)
        iteration += 1
        return (iteration, state), None

>     iteration_state, _ = jax.lax.scan(
          one_iteration, (iteration, state), compute_error_flags
      )
E     jax._src.interpreters.ad.CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try pa

marcocuturi commented 3 months ago

Thanks!

The bug came from computing jax.grad(lambda x: div(x).divergence) rather than defining a proper closure, i.e. a function that outputs the desired divergence directly, before taking the gradient.

I think the syntax above builds the entire differentiation graph for all elements output in an LRSinkhornOutput (notably q, r, g) and only then gathers the divergence gradient? If that's the case, running into trouble was to be expected, since we haven't properly defined differentiation rules for the solutions in LRSinkhorn.

Computing the gradient of an LRSinkhornOutput.reg_ot_cost w.r.t. the input locations x was already working before this PR, so defining a proper function that outputs the divergence works.

As a result, I have reinstated the very primitive differentiability test.
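
For illustration only, a minimal sketch of the working gradient pattern described above; the helper names and data are hypothetical and do not come from this PR:

```python
# Sketch only: differentiate a function that returns the scalar divergence
# directly (built here from LRSinkhornOutput.reg_ot_cost terms), rather than
# differentiating through a full output object and then reading its
# .divergence field, which is the pattern that failed above.
import jax

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr


def lr_reg_ot_cost(x, y, rank=3):
  prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y))
  return sinkhorn_lr.LRSinkhorn(rank=rank)(prob).reg_ot_cost


def lr_divergence(x, y):
  # Scalar output: the gradient only has to flow through reg_ot_cost terms.
  return lr_reg_ot_cost(x, y) - 0.5 * (lr_reg_ot_cost(x, x) + lr_reg_ot_cost(y, y))


rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (13, 2))
y = jax.random.normal(rng_y, (15, 2)) + 1.0

grad_x = jax.grad(lr_divergence)(x, y)  # gradient w.r.t. the locations x
```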