ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

Possible bug in `fused_penalty` #546

Closed (gjhuizing closed this issue 4 months ago)

gjhuizing commented 5 months ago

Hello,

I would expect these two blocks of code to provide the same answer:

# Block 1: fold the penalty into the linear cost, keep fused_penalty=1.0
geom_ab = PointCloud(a, b, scale_cost=1/fused_penalty)
problem = QuadraticProblem(geom_xx, geom_yy, geom_ab, fused_penalty=1.0)
solver = GromovWasserstein(epsilon=epsilon)
solver(problem).reg_gw_cost

# Block 2: unscaled linear cost, penalty passed via fused_penalty
geom_ab = PointCloud(a, b)
problem = QuadraticProblem(geom_xx, geom_yy, geom_ab, fused_penalty=fused_penalty)
solver = GromovWasserstein(epsilon=epsilon)
solver(problem).reg_gw_cost

However, they output different values, and the first block seems to be equivalent to

geom_ab = PointCloud(a, b)
problem = QuadraticProblem(geom_xx, geom_yy, geom_ab, fused_penalty=np.sqrt(fused_penalty))
solver = GromovWasserstein(epsilon=epsilon)
solver(problem).reg_gw_cost

I thought the FGW solver minimized <~C_quad,P> + fused_penalty*<C_lin,P> - epsilon*E(P). Is that not the case?

Thanks a lot!

GJ

gjhuizing commented 5 months ago

Is it possible this line counts the fused penalty twice, leading to fused_penalty being squared?

https://github.com/ott-jax/ott/blob/84a1f1b0b0ccac6eddf22503dd82ced64cf81b0f/src/ott/problems/quadratic/quadratic_problem.py#L415
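
The arithmetic behind this hypothesis can be sketched on a toy example. The snippet below is plain Python and makes no OTT calls; the helper names and the toy matrices are hypothetical, and it assumes that `scale_cost=1/fused_penalty` multiplies the linear cost matrix by `fused_penalty`. It only looks at the linear term `fused_penalty*<C_lin,P>` for a fixed coupling `P`, which is a simplification of what the solver actually does.

```python
import math

def inner(C, P):
    """Frobenius inner product <C, P> of two equally shaped matrices."""
    return sum(c * p for row_c, row_p in zip(C, P) for c, p in zip(row_c, row_p))

def scale(C, s):
    """Multiply every entry of C by the scalar s."""
    return [[s * c for c in row] for row in C]

def linear_term_expected(C_lin, P, fused_penalty):
    # Expected behaviour: the penalty is applied exactly once.
    return fused_penalty * inner(C_lin, P)

def linear_term_double_counted(C_lin, P, fused_penalty):
    # Hypothetical bug: the penalty is applied twice, i.e. squared.
    return fused_penalty ** 2 * inner(C_lin, P)

# Toy data (made up for illustration).
C_lin = [[0.0, 1.0], [4.0, 9.0]]
P = [[0.25, 0.25], [0.25, 0.25]]
fp = 4.0

# Penalty applied once: blocks 1 and 2 agree.
expected_block1 = linear_term_expected(scale(C_lin, fp), P, 1.0)
expected_block2 = linear_term_expected(C_lin, P, fp)

# Penalty applied twice: block 1 now matches fused_penalty=sqrt(fp), not fp.
buggy_block1 = linear_term_double_counted(scale(C_lin, fp), P, 1.0)
buggy_block2 = linear_term_double_counted(C_lin, P, fp)
buggy_block3 = linear_term_double_counted(C_lin, P, math.sqrt(fp))

print(expected_block1, expected_block2)
print(buggy_block1, buggy_block2, buggy_block3)
```

If the penalty were squared somewhere, the pre-scaled variant would be unaffected (since 1.0 squared is still 1.0), which would match the observed equivalence with `fused_penalty=np.sqrt(fused_penalty)`.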

michalk8 commented 4 months ago

Hi @gjhuizing , thanks, will look into this!

michalk8 commented 4 months ago

> Is it possible this line counts the fused penalty twice, leading to fused_penalty being squared?
>
> https://github.com/ott-jax/ott/blob/84a1f1b0b0ccac6eddf22503dd82ced64cf81b0f/src/ott/problems/quadratic/quadratic_problem.py#L415

I just checked this by running

scaled_pc = PointCloud(a, b, scale_cost=1.0/fp)
lr_pc = PointCloud(a, b).to_LRCGeometry(scale=fp)
jnp.abs(scaled_pc.cost_matrix - lr_pc.cost_matrix).max()
# Array(4.41424291e-14, dtype=float64)

The line in question seems to be correct, so the problem must be elsewhere in the GW solver or problem itself.
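
As a rough illustration of why a scaled dense cost and a scaled low-rank cost can agree to machine precision, here is a minimal plain-Python sketch for the squared Euclidean cost. The factorization and all helper names are assumptions made for illustration; how OTT's `to_LRCGeometry` applies `scale` internally is not shown in this thread and may differ.

```python
def sq_euclidean_cost(a, b, scale=1.0):
    """Dense cost matrix: scale * ||a_i - b_j||^2."""
    return [[scale * sum((x - y) ** 2 for x, y in zip(ai, bj)) for bj in b]
            for ai in a]

def low_rank_factors(a, b, scale=1.0):
    """Factors A, B with A @ B.T == scale * ||a_i - b_j||^2.

    Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>; the scale is
    folded into the first factor only.
    """
    A = [[scale * sum(x * x for x in ai), scale, *(-2.0 * scale * x for x in ai)]
         for ai in a]
    B = [[1.0, sum(y * y for y in bj), *bj] for bj in b]
    return A, B

def matmul_transpose(A, B):
    """Compute A @ B.T for matrices stored as lists of rows."""
    return [[sum(u * v for u, v in zip(ra, rb)) for rb in B] for ra in A]

# Toy point clouds (made up for illustration).
a = [[0.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
b = [[1.0, 1.0], [-1.0, 3.0]]
fp = 3.0

dense = sq_euclidean_cost(a, b, scale=fp)
A, B = low_rank_factors(a, b, scale=fp)
low_rank = matmul_transpose(A, B)

max_abs_diff = max(abs(d - l)
                   for row_d, row_l in zip(dense, low_rank)
                   for d, l in zip(row_d, row_l))
print(max_abs_diff)
```

Under these assumptions the scale enters each representation exactly once, so any remaining difference is floating-point rounding.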

michalk8 commented 4 months ago

Thanks for noticing this @gjhuizing , fixed in #558!

gjhuizing commented 4 months ago

Great, thanks for fixing it! I think this may affect Moscot downstream, so it may be worth checking whether it changes the behavior for the default value of alpha.

marcocuturi commented 4 months ago

thanks to both of you!