Closed · marcocuturi closed this 1 month ago
@michalk8 maybe we could put that `5e-2` factor, currently hardcoded, in some "constants.py" file? We could also think of other constants, e.g. max/min iterations or the default threshold.
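For instance, such a file could look like this (hypothetical sketch; module path, names and values are illustrative, not actual ott defaults):

```python
# src/ott/constants.py -- hypothetical module, as floated above;
# names and values are illustrative rather than actual ott defaults.
DEFAULT_EPSILON_SCALE: float = 5e-2  # fraction of the cost scale used for epsilon
DEFAULT_THRESHOLD: float = 1e-3      # default convergence threshold
DEFAULT_MIN_ITERATIONS: int = 0
DEFAULT_MAX_ITERATIONS: int = 2000
```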
Attention: Patch coverage is 95.00000% with 1 line in your changes missing coverage. Please review.
Project coverage is 87.59%. Comparing base (aa33bd9) to head (5a30790).
| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/ott/experimental/mmsinkhorn.py | 66.66% | 0 Missing and 1 partial :warning: |
Side comment: I tried subtracting `jax.lax.stop_gradient(self.mean_cost_matrix)` from the `cost_matrix` when instantiating the `kernel_matrix` property, to get the desired stabilization effect. This works very well in the forward pass, but when differentiating and checking gradients numerically, the constant offset breaks gradient tests in `sinkhorn_diff_test`. I think this could be avoided by adding that constant back, but I didn't nail the right change.
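A minimal self-contained sketch of the attempt described above (illustrative class, not the actual `ott.geometry.Geometry` code):

```python
import jax
import jax.numpy as jnp

# The mean cost is subtracted under stop_gradient, so it stabilizes the
# kernel in the forward pass without contributing to gradients w.r.t. C.
class ToyGeometry:

    def __init__(self, cost_matrix: jnp.ndarray, epsilon: float):
        self.cost_matrix = cost_matrix
        self.epsilon = epsilon

    @property
    def mean_cost_matrix(self) -> jnp.ndarray:
        return jnp.mean(self.cost_matrix)

    @property
    def kernel_matrix(self) -> jnp.ndarray:
        # Offset the cost by its (gradient-blocked) mean before
        # exponentiating, to avoid underflow when costs are large.
        offset = jax.lax.stop_gradient(self.mean_cost_matrix)
        return jnp.exp(-(self.cost_matrix - offset) / self.epsilon)
```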
LGTM, thanks a lot @marcocuturi !
Up to now, `epsilon` was scaled as a fraction of the mean of the cost matrix, in order to account for scale. The motivation for this was: one can afford a small `epsilon` value when `C` has small entries, without having `exp(-C/epsilon)` underflow. Likewise, large `C` values should also come with a large `epsilon`.

This intuition is correct when using kernel mode (which requires storing a kernel matrix) but not useful in `logsumexp` mode, which can easily account for costs that are rescaled (e.g. by adding a constant) with no computational impact.
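To make the underflow point concrete, a small illustrative snippet (not library code):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# With a large cost scale and a small epsilon, the kernel exp(-C / epsilon)
# underflows to exactly 0, so kernel-mode Sinkhorn breaks down ...
C = jnp.array([[1000.0, 1200.0], [1100.0, 1300.0]])
epsilon = 1.0  # far too small relative to the scale of C

print(jnp.exp(-C / epsilon))  # [[0. 0.] [0. 0.]]

# ... whereas LSE mode manipulates -C / epsilon directly; e.g. a row-wise
# "softmin" stays finite because logsumexp factors out the max internally.
print(-epsilon * logsumexp(-C / epsilon, axis=1))  # close to [1000. 1100.]
```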
Indeed, imagine one translates cost `C` entrywise by a `delta` factor. In that case, the LSE implementation of (balanced) Sinkhorn would handle these two problems in exactly the same way:

`argmin <P, C> - epsilon H(P)`

and

`argmin <P, C + delta> - epsilon H(P) = argmin <P, C> + delta - epsilon H(P)`

(because the coupling `P` has total mass 1, so `<P, delta> = delta`).

Yet, the current rule would result in two different `epsilon`'s (specifically, the second would be the first `+ 0.05 * delta`, given that we set `epsilon` by default to 1/20th of the mean cost).
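A quick numerical sanity check of this invariance, using a bare-bones log-domain Sinkhorn loop (a sketch, not ott's solver):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def lse_sinkhorn(C, epsilon, a, b, num_iters=500):
    # Bare-bones balanced Sinkhorn on dual potentials (f, g), in log space.
    f = jnp.zeros(C.shape[0])
    g = jnp.zeros(C.shape[1])
    log_a, log_b = jnp.log(a), jnp.log(b)
    for _ in range(num_iters):
        f = epsilon * (log_a - logsumexp((g[None, :] - C) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - C) / epsilon, axis=0))
    # Optimal coupling P_ij = exp((f_i + g_j - C_ij) / epsilon).
    return jnp.exp((f[:, None] + g[None, :] - C) / epsilon)

key = jax.random.PRNGKey(0)
C = jax.random.uniform(key, (5, 5))
a = b = jnp.full(5, 1.0 / 5)
delta = 10.0

P = lse_sinkhorn(C, epsilon=0.1, a=a, b=b)
P_shift = lse_sinkhorn(C + delta, epsilon=0.1, a=a, b=b)
print(jnp.max(jnp.abs(P - P_shift)))  # ~0: same coupling for C and C + delta

# ... whereas the mean-based default epsilon changes by 0.05 * delta:
print(5e-2 * jnp.mean(C), 5e-2 * jnp.mean(C + delta))
```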
Hence, if "hardness/compute effort" is parameterized by `epsilon`, then `epsilon` shouldn't depend on the overall additive scale of `C` but, instead, on its multiplicative scale. As a result, computing the `std` of the entries in `C` (i.e. the order of magnitude of the centered entries of `C`) is likely to be a more robust alternative.
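In code, the proposed rule would amount to something like the following sketch (the `5e-2` fraction matches the old default; the exact constant and code path in ott may differ):

```python
import jax
import jax.numpy as jnp

def default_epsilon(C: jnp.ndarray, scale: float = 5e-2) -> jnp.ndarray:
    # Scale epsilon by the std of the cost entries instead of their mean.
    return scale * jnp.std(C)

key = jax.random.PRNGKey(0)
C = jax.random.uniform(key, (16, 16))

# Invariant to additive shifts of the cost ...
print(default_epsilon(C), default_epsilon(C + 100.0))  # identical
# ... but still tracks its multiplicative scale:
print(default_epsilon(2.0 * C))  # twice as large
```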