Closed. marcocuturi closed this pull request 3 months ago.
Checking if this seems ok; if it does, I will write some tests to check non-negativity, in accordance with the paper.
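For concreteness, a rough sketch of the kind of non-negativity test I have in mind; the `sinkhorn_kwargs={"rank": ...}` routing used below to request the LR solver is an assumption about this patch's API, not a confirmed signature:

```python
import jax

from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence


def test_lr_sinkhorn_divergence_nonnegative():
  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(rng_x, (13, 2))
  y = jax.random.normal(rng_y, (15, 2)) + 1.0

  # Hypothetical call: `rank` passed via `sinkhorn_kwargs` to select the LR solver.
  out = sinkhorn_divergence.sinkhorn_divergence(
      pointcloud.PointCloud, x, y, epsilon=1e-1,
      sinkhorn_kwargs={"rank": 2},
  )
  # The divergence should be non-negative, in accordance with the paper.
  assert out.divergence >= 0.0
```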
Attention: Patch coverage is 90.90909% with 3 lines in your changes missing coverage. Please review.
Project coverage is 88.41%. Comparing base (478f034) to head (b66005d). Report is 26 commits behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/ott/tools/sinkhorn_divergence.py | 90.00% | 2 Missing and 1 partial :warning: |
@marcocuturi I've modified one test, but differentiating w.r.t. the divergence is not yet working (problems when unrolling):
src/ott/math/fixed_point_loop.py:226: in fixpoint_iter_bwd
_, g_state, g_constants = jax.lax.while_loop(
src/ott/math/fixed_point_loop.py:209: in unrolled_body_fn
_, pullback = jax.vjp(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
iteration = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=3/0)>
constants = (<ott.problems.linear.linear_problem.LinearProblem object at 0x3436e2d10>, <ott.solvers.linear.sinkhorn_lr.LRSinkhorn object at 0x3436e2d70>)
state = LRSinkhornState(q=Traced<ShapedArray(float32[13,2])>with<JVPTrace(level=5/0)> with
primal = Traced<ShapedArray(float...), None)
recipe = LambdaBinding(), crossed_threshold=Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=3/0)>)
    def unrolled_body_fn_no_errors(iteration, constants, state):
      compute_error_flags = jnp.zeros((inner_iterations,), dtype=bool)
      def one_iteration(iteration_state, compute_error):
        iteration, state = iteration_state
        state = body_fn(iteration, constants, state, compute_error)
        iteration += 1
        return (iteration, state), None
>     iteration_state, _ = jax.lax.scan(
          one_iteration, (iteration, state), compute_error_flags
      )
E   jax._src.interpreters.ad.CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try pa
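Roughly, the failing pattern that produces this traceback looks like the following; the helper name and the `sinkhorn_kwargs={"rank": ...}` routing are assumptions for illustration, not the exact test code:

```python
import jax

from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence


def lr_divergence(x, y):
  # Assumed way of requesting the LR solver through the divergence tool.
  return sinkhorn_divergence.sinkhorn_divergence(
      pointcloud.PointCloud, x, y, epsilon=1e-1,
      sinkhorn_kwargs={"rank": 2},
  )


x = jax.random.normal(jax.random.PRNGKey(0), (13, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (15, 2))

# Differentiating a lambda that indexes `.divergence` out of the full output
# hits the CustomVJPException above once the backward pass is unrolled.
grad_x = jax.grad(lambda x_: lr_divergence(x_, y).divergence)(x)
```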
The bug came from computing `jax.grad(lambda x: div(x).divergence)` rather than defining a proper closure, i.e. a function that outputs the desired divergence directly, before taking the gradient. I think the syntax above builds the differentiation graph for everything returned in a `LRSinkhornOutput` (notably `q`, `r`, `g`) and only then gathers the gradient of the divergence? If that's the case, running into trouble was expected, since we haven't properly defined differentiation rules for the solutions in `LRSinkhorn`.
Computing the gradient of `LRSinkhornOutput.reg_ot_cost` w.r.t. the input locations `x` was working before this PR, so defining a proper function that outputs the divergence works.
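As a rough illustration of the pattern that was already working (a function returning only the scalar of interest, then differentiated); the rank, epsilon and point-cloud sizes below are made up:

```python
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr


def lr_reg_ot_cost(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  # Return only the scalar regularized OT cost of the LR solution.
  geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
  out = sinkhorn_lr.LRSinkhorn(rank=2)(linear_problem.LinearProblem(geom))
  return out.reg_ot_cost


x = jax.random.normal(jax.random.PRNGKey(0), (13, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (15, 2))
grad_x = jax.grad(lr_reg_ot_cost)(x)  # this pattern was differentiable before this PR
```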
As a result, I have reinstated the very primitive differentiability test.
Simple patch for the issue described in https://github.com/ott-jax/ott/issues/485, for when a user wishes to compute a Sinkhorn divergence using LR Sinkhorn as a primitive.
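For context, a hedged sketch of what this patch automates: assembling the divergence by hand from three LR Sinkhorn solves, SD(x, y) = OT(x, y) - 0.5 OT(x, x) - 0.5 OT(y, y); the exact cost terms used inside the library's helper may differ:

```python
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr


def lr_ot_cost(x: jnp.ndarray, y: jnp.ndarray, rank: int = 2) -> jnp.ndarray:
  # One LR Sinkhorn solve, returning its regularized OT cost.
  geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
  out = sinkhorn_lr.LRSinkhorn(rank=rank)(linear_problem.LinearProblem(geom))
  return out.reg_ot_cost


x = jax.random.normal(jax.random.PRNGKey(0), (13, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (15, 2))
manual_div = lr_ot_cost(x, y) - 0.5 * lr_ot_cost(x, x) - 0.5 * lr_ot_cost(y, y)
```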