ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0
521 stars 80 forks

AttributeError in sinkhorn_divergence when passing sinkhorn_kwargs={"rank": <some integer>} #485

Closed Farbodch closed 2 months ago

Farbodch commented 9 months ago

Describe the bug

When

ot = sinkhorn_divergence.sinkhorn_divergence(
    geom,
    x=geom.x,
    y=geom.y,
    static_b=True,
    sinkhorn_kwargs={"rank":10,"initializer":'random'})
return ot.divergence, ot

is called (through jax.jit(jax.value_and_grad(...))), the low-rank Sinkhorn solver (LRSinkhorn) is expected to be used. Instead, an AttributeError is raised:

    179     out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost) +
    180     0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b)) ** 2
    181 )
    182 out = (out_xy, out_xx, out_yy)
    183 return SinkhornDivergenceOutput(
--> 184     div, tuple([s.f, s.g] for s in out),
    185     (geometry_xy, geometry_xx, geometry_yy), tuple(s.errors for s in out),
    186     tuple(s.converged for s in out), a, b
    187 )

AttributeError: 'LRSinkhornOutput' object has no attribute 'f'

Full Error Output

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1262], line 109
--> 109 geom_grad = get_geom(my_source, my_target)

Cell In[1262], line 29, in get_geom(source, target)
     27 def get_geom(source, target):
     28     geom = pointcloud.PointCloud(source, target, epsilon=epsilon)
---> 29     (cost, ot), geom_g = cost_fn_vg(geom)
     30     assert ot.converged

    [... skipping hidden 20 frame]

Cell In[1262], line 9, in sink_div(geom)
      4 def sink_div(geom):
      5     """Return the Sinkhorn divergence cost and OT output given a geometry.
      6     Since y is fixed, we can use static_b=True to avoid computing
      7     the OT(b, b) term."""
----> 9     ot = sinkhorn_divergence.sinkhorn_divergence(
     10         geom,
     11         x=geom.x,
     12         y=geom.y,
     13         static_b=True,
     14         sinkhorn_kwargs={"rank":10,"initializer":'random'})
     15     return ot.divergence, ot

File ~/Miniconda3-py311_23.5.2-0-Linux-x86_64/envs/env1/lib/python3.11/site-packages/ott/tools/sinkhorn_divergence.py:103, in sinkhorn_divergence(geom, a, b, sinkhorn_kwargs, static_b, share_epsilon, symmetric_sinkhorn, *args, **kwargs)
    101 a = jnp.ones(num_a) / num_a if a is None else a
    102 b = jnp.ones(num_b) / num_b if b is None else b
--> 103 return _sinkhorn_divergence(
    104     geom_xy,
    105     geom_x,
    106     geom_y,
    107     a=a,
    108     b=b,
    109     symmetric_sinkhorn=symmetric_sinkhorn,
    110     **sinkhorn_kwargs
    111 )

File ~/Miniconda3-py311_23.5.2-0-Linux-x86_64/envs/env1/lib/python3.11/site-packages/ott/tools/sinkhorn_divergence.py:184, in _sinkhorn_divergence(geometry_xy, geometry_xx, geometry_yy, a, b, symmetric_sinkhorn, **kwargs)
    178 div = (
    179     out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost) +
    180     0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b)) ** 2
    181 )
    182 out = (out_xy, out_xx, out_yy)
    183 return SinkhornDivergenceOutput(
--> 184     div, tuple([s.f, s.g] for s in out),
    185     (geometry_xy, geometry_xx, geometry_yy), tuple(s.errors for s in out),
    186     tuple(s.converged for s in out), a, b
    187 )

File ~/Miniconda3-py311_23.5.2-0-Linux-x86_64/envs/env1/lib/python3.11/site-packages/ott/tools/sinkhorn_divergence.py:184, in <genexpr>(.0)
    178 div = (
    179     out_xy.reg_ot_cost - 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost) +
    180     0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b)) ** 2
    181 )
    182 out = (out_xy, out_xx, out_yy)
    183 return SinkhornDivergenceOutput(
--> 184     div, tuple([s.f, s.g] for s in out),
    185     (geometry_xy, geometry_xx, geometry_yy), tuple(s.errors for s in out),
    186     tuple(s.converged for s in out), a, b
    187 )

AttributeError: 'LRSinkhornOutput' object has no attribute 'f'

To Reproduce

Relevant code snippet used to reproduce the behavior:

import jax
from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence

def sink_div(geom):
    """Return the Sinkhorn divergence cost and OT output given a geometry.
    Since y is fixed, we can use static_b=True to avoid computing
    the OT(b, b) term."""

    ot = sinkhorn_divergence.sinkhorn_divergence(
        geom,
        x=geom.x,
        y=geom.y,
        static_b=True,
        sinkhorn_kwargs={"rank":10,"initializer":'random'})
    return ot.divergence, ot

cost_fn_vg = jax.jit(jax.value_and_grad(sink_div, has_aux=True))

# `epsilon`, `source` and `target` are defined elsewhere in the script.
def get_geom(source, target):
    geom = pointcloud.PointCloud(source, target, epsilon=epsilon)
    (cost, ot), geom_g = cost_fn_vg(geom)
    assert ot.converged

    return geom_g.x

Additional information

The overall script works as expected when sink_div (as stated above) is replaced with a direct call to the low-rank Sinkhorn solver (sink_lr_cost below):

def sink_lr_cost(geom):
    """Return the OT cost and OT output given a geometry"""
    ot = sinkhorn_lr.LRSinkhorn(rank=10, initializer='random')(linear_problem.LinearProblem(geom))
    return ot.reg_ot_cost, ot

cost_fn_vg = jax.jit(jax.value_and_grad(sink_lr_cost, has_aux=True))
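For versions affected by this bug, one possible workaround is to assemble the divergence by hand from the regularized costs, following the formula visible in the traceback (reg_ot_cost of (x, y) minus half the sum of the (x, x) and (y, y) terms). The sketch below is only an illustration: manual_divergence, reg_ot_cost and toy_cost are hypothetical names, not part of OTT. In practice the callable could wrap the LRSinkhorn call from sink_lr_cost above. The mass term from the traceback is left out, on the assumption that both weight vectors have the same total mass.

```python
from typing import Callable, Sequence

Points = Sequence[float]


def manual_divergence(
    reg_ot_cost: Callable[[Points, Points], float],
    x: Points,
    y: Points,
    static_b: bool = False,
) -> float:
    """Combine three regularized OT costs into a Sinkhorn-style divergence.

    `reg_ot_cost` is a hypothetical callable returning the regularized cost
    between two point sets, e.g. a wrapper around a low-rank solver.
    """
    cost_xy = reg_ot_cost(x, y)
    cost_xx = reg_ot_cost(x, x)
    # Assumption: with static_b the (y, y) term is skipped. It does not
    # depend on x, so dropping it only shifts the value by a constant.
    cost_yy = 0.0 if static_b else reg_ot_cost(y, y)
    return cost_xy - 0.5 * (cost_xx + cost_yy)


def toy_cost(u: Points, v: Points) -> float:
    """Toy stand-in for a solver, used only to exercise the helper."""
    return abs(sum(u) - sum(v)) + 1.0


full = manual_divergence(toy_cost, [0.0, 0.0], [1.0, 2.0])
shifted = manual_divergence(toy_cost, [0.0, 0.0], [1.0, 2.0], static_b=True)
```

Because the skipped term is constant in x, gradients with respect to x should be unaffected by static_b, which is why the original snippet uses it.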


marcocuturi commented 9 months ago

Hi @Farbodch, sorry about this. We did not expect this usage, but it is indeed very valid, especially since we explored it here: https://proceedings.neurips.cc/paper_files/paper/2022/hash/2d69e771d9f274f7c624198ea74f5b98-Abstract-Conference.html

Essentially this is just an API bug, and it should work all right; we just shouldn't try to pull the .f and .g potentials from the LR Sinkhorn output in that case.
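To illustrate the failure mode, here is a minimal sketch with hypothetical stand-in classes (DenseOutput and LowRankOutput are not OTT types, and this is not the fix that was eventually merged). It shows why the generator expression in the traceback fails for a low-rank output, and one way potentials could be collected only from outputs that actually carry them.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class DenseOutput:
    """Hypothetical stand-in for a solver output with dual potentials."""
    f: List[float]
    g: List[float]


@dataclass
class LowRankOutput:
    """Hypothetical stand-in for a low-rank output: factors, no `f`."""
    q: List[float]
    r: List[float]


Potentials = Optional[Tuple[List[float], List[float]]]


def collect_potentials(outs) -> Tuple[Potentials, ...]:
    """Return (f, g) for outputs that carry potentials, None otherwise."""
    return tuple(
        (o.f, o.g) if isinstance(o, DenseOutput) else None for o in outs
    )


outs = (DenseOutput([0.1], [0.2]), LowRankOutput([1.0], [1.0]))

# Unconditional access, as in the traceback, fails on the low-rank output.
try:
    tuple([s.f, s.g] for s in outs)
    unguarded_failed = False
except AttributeError:
    unguarded_failed = True

potentials = collect_potentials(outs)
```

The check is done on the output type rather than with hasattr, since an output class may expose an attribute with the same name but a different meaning.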

How urgent is this? If you need this for ICML, let us know; we can come up with a slightly dumb patch.

Farbodch commented 9 months ago

Hi @marcocuturi

Thank you for the reply! It’s not super urgent/it won’t make it to ICML, but I would greatly appreciate any updates!

michalk8 commented 2 months ago

Hi @Farbodch, it's finally implemented!

Closed via #568.