ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

automatic scaling of `epsilon` using `std` instead of `mean` by default. #578

Closed marcocuturi closed 1 month ago

marcocuturi commented 1 month ago

Up to now, epsilon was scaled as a fraction of the mean of the cost matrix, the motivation being to adapt epsilon to the magnitude of the entries of C.

This intuition is correct in kernel mode (which requires storing a kernel matrix), but it is not useful in log-sum-exp (LSE) mode, which can easily accommodate shifted costs (e.g., costs to which a constant is added) with no computational impact.

Indeed, imagine the cost matrix C is translated entrywise by a constant delta. In that case, the LSE implementation of (balanced) Sinkhorn handles the following two problems in exactly the same way:

argmin <P, C> - epsilon H(P)

and

argmin <P, C + delta> - epsilon H(P) = argmin (<P, C> - epsilon H(P)) + delta

(because the coupling P has total mass 1, <P, C + delta> = <P, C> + delta, so the two objectives differ only by the constant delta).
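This invariance can be checked numerically. Below is a minimal, self-contained sketch of a balanced Sinkhorn loop in log-sum-exp form, written in plain JAX rather than with ott's API (`sinkhorn_lse` and its defaults are illustrative only): running it on C and on C + delta yields the same coupling, with the shift absorbed into the dual potential f.

```python
# Minimal LSE Sinkhorn sketch (plain JAX, illustrative; not ott's API).
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_lse(C, eps=0.1, n_iter=500):
    """Balanced Sinkhorn with uniform marginals, log-sum-exp updates."""
    n, m = C.shape
    log_a = jnp.full(n, -jnp.log(n))  # uniform source marginal, in log space
    log_b = jnp.full(m, -jnp.log(m))  # uniform target marginal, in log space
    f = jnp.zeros(n)
    g = jnp.zeros(m)
    for _ in range(n_iter):
        # alternate dual updates; adding delta to C only shifts f by delta
        f = eps * (log_a - logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (log_b - logsumexp((f[:, None] - C) / eps, axis=0))
    # primal coupling P_ij = exp((f_i + g_j - C_ij) / eps)
    return jnp.exp((f[:, None] + g[None, :] - C) / eps)
```

For instance, `sinkhorn_lse(C)` and `sinkhorn_lse(C + 7.0)` agree entrywise up to floating-point error, at identical computational cost.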

Yet, the current rule would result in two different epsilons: specifically, the second would be the first plus 0.05 * delta, given that we set epsilon by default to 1/20th of the mean cost.

Hence, if "hardness/compute effort" is parameterized by epsilon, then epsilon shouldn't depend on the overall additive scale of C, but rather on its multiplicative scale. As a result, computing the std of the entries of C (i.e., the order of magnitude of the centered entries of C) is likely to be a more robust alternative.
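A toy comparison of the two scaling rules (pure NumPy, using the 5e-2 factor mentioned above): translating C entrywise shifts the mean-based epsilon by 0.05 * delta, while the std-based epsilon is unchanged.

```python
# Compare mean-based vs std-based default epsilon under a cost translation.
import numpy as np

rng = np.random.default_rng(0)
C = rng.random((64, 64))
delta = 10.0  # entrywise translation of the cost

eps_mean_before = 0.05 * C.mean()
eps_mean_after = 0.05 * (C + delta).mean()  # shifts by 0.05 * delta = 0.5

eps_std_before = 0.05 * C.std()
eps_std_after = 0.05 * (C + delta).std()    # translation-invariant

print(eps_mean_after - eps_mean_before)  # ≈ 0.5
print(eps_std_after - eps_std_before)    # ≈ 0.0
```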

marcocuturi commented 1 month ago

@michalk8 maybe we could hardcode that 5e-2 factor in some "constants.py" file? We could think of other constants too, e.g., max/min iterations or the default threshold.
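A hypothetical sketch of what such a module could look like (names, and the values other than 5e-2, are illustrative only, not ott's actual API):

```python
# constants.py (hypothetical): central place for magic numbers.
DEFAULT_EPSILON_SCALE: float = 5e-2  # fraction of the cost statistic used for epsilon
DEFAULT_MAX_ITERATIONS: int = 2_000  # illustrative iteration cap
DEFAULT_THRESHOLD: float = 1e-3      # illustrative convergence threshold
```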

codecov[bot] commented 1 month ago

Codecov Report

Attention: Patch coverage is 95.00000% with 1 line in your changes missing coverage. Please review.

Project coverage is 87.59%. Comparing base (aa33bd9) to head (5a30790).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/ott/experimental/mmsinkhorn.py | 66.66% | 0 Missing and 1 partial :warning: |

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main     #578      +/-   ##
==========================================
- Coverage   87.83%   87.59%   -0.24%
==========================================
  Files          73       73
  Lines        7826     7838      +12
  Branches     1127     1132       +5
==========================================
- Hits         6874     6866       -8
- Misses        799      818      +19
- Partials     153       154       +1
```

| Files with missing lines | Coverage Δ | |
|---|---|---|
| src/ott/geometry/epsilon_scheduler.py | 94.44% <100.00%> (+0.15%) | :arrow_up: |
| src/ott/geometry/geometry.py | 92.92% <100.00%> (+0.23%) | :arrow_up: |
| src/ott/solvers/linear/sinkhorn.py | 99.38% <ø> (ø) | |
| src/ott/tools/sinkhorn_divergence.py | 91.86% <ø> (ø) | |
| src/ott/experimental/mmsinkhorn.py | 91.56% <66.66%> (-0.56%) | :arrow_down: |

... and 1 file with indirect coverage changes.
marcocuturi commented 1 month ago

side comment: I tried subtracting jax.lax.stop_gradient(self.mean_cost_matrix) from the cost_matrix when instantiating the kernel_matrix property, to obtain the desired stabilization effect. This works very well in the forward pass, but when differentiating and checking gradients numerically, the constant offset breaks the gradient tests in sinkhorn_diff_test. I think this could be avoided by adding back that constant, but I didn't nail the right change.
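An illustrative sketch of the trick described above (plain JAX; `kernel` and `kernel_stabilized` are made up for this example). Subtracting a stop_gradient'ed offset rescales the kernel by exp(mean(C)/eps) in the forward pass, but autodiff treats the offset as a constant, so forward values and gradients are both rescaled by the same factor relative to the plain kernel; this kind of mismatch is presumably what the numerical gradient checks pick up.

```python
# Sketch: stabilizing the kernel by subtracting a constant cost offset.
import jax
import jax.numpy as jnp


def kernel(C, eps):
    """Plain Gibbs kernel exp(-C / eps)."""
    return jnp.exp(-C / eps)


def kernel_stabilized(C, eps):
    """Same kernel with a stop_gradient'ed mean-cost offset subtracted."""
    m = jax.lax.stop_gradient(jnp.mean(C))  # constant offset; no gradient flows
    return jnp.exp(-(C - m) / eps)
```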

michalk8 commented 1 month ago

LGTM, thanks a lot @marcocuturi !