Calling `gromov_wasserstein` without specifying `sinkhorn_kwargs` raises an `AttributeError`. To reproduce:
```python
from ott.core.gromov_wasserstein import gromov_wasserstein
from ott.geometry.geometry import Geometry
import jax.numpy as jnp

x = Geometry(cost_matrix=jnp.ones((10, 10)))
y = Geometry(cost_matrix=jnp.ones((5, 5)))
gromov_wasserstein(x, y, sinkhorn_kwargs={})  # works as expected
gromov_wasserstein(x, y)  # raises the error below
```
```pytb
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/tmp/ipykernel_246820/4027485735.py in <module>
6 y = Geometry(cost_matrix=jnp.ones((5, 5)))
7 gromov_wasserstein(x, y, sinkhorn_kwargs={}) # works as expected
----> 8 gromov_wasserstein(x, y) # raises the error below
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/core/gromov_wasserstein.py in gromov_wasserstein(geom_x, geom_y, a, b, epsilon, loss, max_iterations, jit, warm_start, sinkhorn_kwargs, **kwargs)
151 raise ValueError('Unknown loss. Either pass an instance of GWLoss or '
152 f'a string among: [{",".join(GW_LOSSES.keys())}]')
--> 153 tau_a = sinkhorn_kwargs.get('tau_a', 1.0)
154 tau_b = sinkhorn_kwargs.get('tau_b', 1.0)
155 if tau_a != 1.0 or tau_b != 1.0:
AttributeError: 'NoneType' object has no attribute 'get'
```
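The failing lines call `.get` on `sinkhorn_kwargs`, and the error message shows its value was `None` when the argument was omitted. A minimal sketch of one possible fix, assuming the default really is `None`: normalize it to an empty dict before reading `tau_a` and `tau_b`. `read_tau_params` is a hypothetical stand-alone helper written only to illustrate the pattern; it is not part of `ott` and not the upstream patch.

```python
from typing import Any, Mapping, Optional, Tuple


def read_tau_params(
    sinkhorn_kwargs: Optional[Mapping[str, Any]] = None,
) -> Tuple[float, float]:
    """Hypothetical helper mirroring lines 153-154 of the traceback.

    Treats a missing ``sinkhorn_kwargs`` (``None``) as an empty mapping
    instead of calling ``.get`` on ``None``.
    """
    # Normalize the default: ``None`` becomes an empty dict.
    sinkhorn_kwargs = {} if sinkhorn_kwargs is None else sinkhorn_kwargs
    tau_a = sinkhorn_kwargs.get("tau_a", 1.0)
    tau_b = sinkhorn_kwargs.get("tau_b", 1.0)
    return tau_a, tau_b


print(read_tau_params())                # (1.0, 1.0)
print(read_tau_params({"tau_a": 0.5}))  # (0.5, 1.0)
```

Using `None` as the sentinel (rather than a `{}` default in the signature) avoids sharing one mutable dict across calls.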
Version: 0.1.17