google-research / ott

Apache License 2.0

GW with default `sinkhorn_kwargs` raises AttributeError #17

Closed: michalk8 closed this issue 2 years ago

michalk8 commented 2 years ago

Calling `gromov_wasserstein` without `sinkhorn_kwargs` raises an `AttributeError`: the argument defaults to `None`, but the function then calls `sinkhorn_kwargs.get(...)` on it. To reproduce:

```python
from ott.core.gromov_wasserstein import gromov_wasserstein
from ott.geometry.geometry import Geometry
import jax.numpy as jnp

x = Geometry(cost_matrix=jnp.ones((10, 10)))
y = Geometry(cost_matrix=jnp.ones((5, 5)))
gromov_wasserstein(x, y, sinkhorn_kwargs={})  # works as expected
gromov_wasserstein(x, y)  # raises the error below
```
```pytb
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_246820/4027485735.py in <module>
      6 y = Geometry(cost_matrix=jnp.ones((5, 5)))
      7 gromov_wasserstein(x, y, sinkhorn_kwargs={})  # works as expected
----> 8 gromov_wasserstein(x, y)  # raises the error below

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/core/gromov_wasserstein.py in gromov_wasserstein(geom_x, geom_y, a, b, epsilon, loss, max_iterations, jit, warm_start, sinkhorn_kwargs, **kwargs)
    151     raise ValueError('Unknown loss. Either pass an instance of GWLoss or '
    152                      f'a string among: [{",".join(GW_LOSSES.keys())}]')
--> 153   tau_a = sinkhorn_kwargs.get('tau_a', 1.0)
    154   tau_b = sinkhorn_kwargs.get('tau_b', 1.0)
    155   if tau_a != 1.0 or tau_b != 1.0:

AttributeError: 'NoneType' object has no attribute 'get'
```

Version 0.1.17.
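The failure in the traceback is the common "`None` default used as a dict" pattern. Below is a minimal sketch of that pattern and the usual guard, using hypothetical stand-in functions `gw_buggy` and `gw_fixed` (not part of ott); it only assumes the two `.get('tau_a', 1.0)` / `.get('tau_b', 1.0)` lines visible in the traceback, not the actual contents of the fix in ott.

```python
# Hypothetical stand-ins, not ott's API: they only mirror the
# sinkhorn_kwargs handling shown in the traceback.

def gw_buggy(sinkhorn_kwargs=None):
    # Fails with AttributeError when sinkhorn_kwargs is left as None.
    tau_a = sinkhorn_kwargs.get('tau_a', 1.0)
    tau_b = sinkhorn_kwargs.get('tau_b', 1.0)
    return tau_a, tau_b


def gw_fixed(sinkhorn_kwargs=None):
    # Normalize None to a fresh dict before any .get() call. Keeping None
    # (rather than {}) as the default avoids sharing one mutable dict
    # across calls; dict(...) also avoids mutating the caller's dict.
    sinkhorn_kwargs = {} if sinkhorn_kwargs is None else dict(sinkhorn_kwargs)
    tau_a = sinkhorn_kwargs.get('tau_a', 1.0)
    tau_b = sinkhorn_kwargs.get('tau_b', 1.0)
    return tau_a, tau_b


try:
    gw_buggy()
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'get'

print(gw_fixed())                # (1.0, 1.0)
print(gw_fixed({'tau_a': 0.9}))  # (0.9, 1.0)
```

With the guard in place, omitting the argument behaves the same as passing `sinkhorn_kwargs={}`, which is the behavior the report expects.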

marcocuturi commented 2 years ago

Thanks! Indeed, this code path was not properly tested. I just opened pull request #19.

marcocuturi commented 2 years ago

Thanks again Michal for noticing this one!