ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

Unbalanced FGW doesn't converge when margins are provided #519

Open selmanozleyen opened 5 months ago

selmanozleyen commented 5 months ago

Describe the bug

For the application use case, see the tests from moscot: https://github.com/theislab/moscot/actions/runs/8709537760/job/23889450330?pr=677

Unbalanced FGW is unstable, especially when margins are provided. I played with epsilon and the tau's, but it still doesn't converge. I think this started after https://github.com/ott-jax/ott/commit/41906a2a1ade19aa154189fabd7c159a160c9bf3

To Reproduce

import numpy as np
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.solvers.quadratic import solve

# Generating random data for x and y
x = np.random.rand(96, 2)  # 96 points in 2D
y = np.random.rand(96, 2)  # Another 96 points in 2D

# Create PointCloud instances
geom_xx = pointcloud.PointCloud(x)
geom_yy = pointcloud.PointCloud(y)
geom_xy = pointcloud.PointCloud(x, y)

# a and b are vectors of ones with lengths matching the number of points in x and y, respectively
a = jnp.ones(x.shape[0])
b = jnp.ones(y.shape[0])

# Call solve function with the specified parameters
solve(geom_xx=geom_xx, geom_yy=geom_yy, geom_xy=geom_xy, tau_a=0.9, tau_b=0.9,
      fused_penalty=1.0, epsilon=1.0, a=a, b=b)
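When tuning `tau_a`/`tau_b` together with `epsilon`, it may help to recall how they are coupled. As far as I understand ott's convention, `tau = rho / (rho + epsilon)`, where `rho` is the weight of the KL penalty on the corresponding marginal. The sketch below (plain Python; the helper `tau_to_rho` is hypothetical and not part of ott) inverts that assumed relation, so one can see how strongly the marginals are actually enforced for a given pair of settings.

```python
def tau_to_rho(tau: float, epsilon: float) -> float:
    """Invert the assumed relation tau = rho / (rho + epsilon).

    Returns the KL marginal penalty rho. tau == 1 corresponds to a hard
    (balanced) marginal constraint, i.e. rho = infinity.
    """
    if not 0.0 < tau <= 1.0:
        raise ValueError("tau must lie in (0, 1]")
    if tau == 1.0:
        return float("inf")
    return tau * epsilon / (1.0 - tau)


# Settings from the reproduction above: tau_a = tau_b = 0.9, epsilon = 1.0.
print(tau_to_rho(0.9, 1.0))   # roughly 9.0
# At fixed tau, shrinking epsilon shrinks rho by the same factor.
print(tau_to_rho(0.9, 0.01))  # roughly 0.09
```

Under this convention, sweeping `epsilon` at a fixed `tau` also changes the marginal penalty, which is one reason the two parameters are hard to tune independently.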

michalk8 commented 5 months ago

Hi @selmanozleyen, this seems to come from numerical imprecision. More specifically, the NaNs come directly from the initialization here, where marginal_1 is an array of all 0s: this leads to a transport mass of 0 and, later, to a NaN rescaling factor. I will check whether there's a more numerically stable way of computing this; however, simply using

a = jnp.ones(x.shape[0]) / x.shape[0]
b = jnp.ones(y.shape[0]) / y.shape[0]

solves the numerical precision issues.
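The failure mode described above (a zero total transported mass that turns a rescaling factor into NaN) can be illustrated with a minimal NumPy sketch. It does not reproduce ott's initialization. It only assumes a Gibbs-kernel style coupling `a[:, None] * exp(-C / epsilon) * b[None, :]` computed in float32 (JAX's default precision unless x64 is enabled), which underflows to exactly 0 once `C / epsilon` exceeds roughly 104. Both helper names are hypothetical. The second function shows one standard remedy: computing the log of the mass with a log-sum-exp, which stays finite.

```python
import numpy as np


def naive_mass_rescaling(a, b, cost, epsilon):
    """Rescale a Gibbs coupling so its mass equals sqrt(mass(a) * mass(b)).

    Hypothetical illustration, computed in float32.
    """
    a = a.astype(np.float32)
    b = b.astype(np.float32)
    kernel = np.exp(-cost.astype(np.float32) / np.float32(epsilon))
    plan = a[:, None] * kernel * b[None, :]
    mass = plan.sum()  # exactly 0.0 once every kernel entry underflows
    target = np.sqrt(a.sum() * b.sum())
    with np.errstate(divide="ignore", invalid="ignore"):
        # target / 0 -> inf, then 0 * inf -> NaN in every entry
        return plan * (target / mass)


def log_mass(a, b, cost, epsilon):
    """Log of the same total mass, computed stably in the log domain."""
    log_plan = np.log(a)[:, None] - cost / epsilon + np.log(b)[None, :]
    shift = log_plan.max()  # subtract the max before exponentiating
    return shift + np.log(np.exp(log_plan - shift).sum())


a = np.ones(96)
b = np.ones(96)
cost = np.full((96, 96), 200.0)
print(np.isnan(naive_mass_rescaling(a, b, cost, 1.0)).all())  # True
print(log_mass(a, b, cost, 1.0))  # finite: -200 + log(96 * 96)
```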

selmanozleyen commented 4 months ago

@michalk8, as you said, it works when I normalize. But when the marginals don't sum to 1, it still fails in many cases; see the examples below. I'd expect unbalanced OT not to require marginals that sum to 1.

a = np.ones(x.shape[0])*2
a[0:4] = 1
b = np.ones(y.shape[0])*2
b[0:4] = 1
# or 
a = np.ones(x.shape[0])*2
b = np.ones(y.shape[0])*2

marcocuturi commented 2 months ago

Thanks @selmanozleyen. I think what's happening here is a problem of scales. Although it may seem that dividing or multiplying a/b by a constant should have no bearing on the optimization, in the case of entropic GW this is likely not true because of the interplay with other parameters (notably epsilon, but more generally also the scale of the cost matrix), since the unbalanced problem adds a KL term.
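A small worked example of why rescaling a and b is not harmless once KL terms are involved. Assume, for illustration only (this is not claimed to be ott's exact objective), that the entropic term is the generalized KL divergence `KL(P | a b^T) = sum(P log(P / (a b^T))) - sum(P) + sum(a b^T)`. Multiplying both marginals by `m` multiplies the reference measure `a b^T` by `m**2`, while a plan with the new marginals only grows by `m`. For the independent coupling, the term is 0 at `m = 1` but equals `m**2 - m - m*log(m)` in general, so it does not scale linearly with `m` and its balance against the transport cost (in effect, epsilon) shifts. The helpers `generalized_kl` and `scaled_kl` are hypothetical.

```python
import numpy as np


def generalized_kl(p, q):
    """Generalized KL divergence between nonnegative arrays (no normalization)."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0  # convention: 0 * log(0 / q) = 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask])) - p.sum() + q.sum()


def scaled_kl(m, a, b):
    """KL of the coupling m * outer(a, b) (marginals m*a, m*b) to outer(m*a, m*b)."""
    plan = np.outer(a, b)
    return generalized_kl(m * plan, np.outer(m * a, m * b))


rng = np.random.default_rng(0)
a = rng.random(5)
a /= a.sum()  # unit-mass marginals
b = rng.random(5)
b /= b.sum()

for m in (1.0, 2.0, 10.0):
    # A scale-invariant term would give m * 0 = 0 for every m.
    print(m, scaled_kl(m, a, b), m**2 - m - m * np.log(m))
```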

Tangentially related: I think the converged flag in GW was bugged, as discussed in https://github.com/ott-jax/ott/pull/566