google-research / ott

Apache License 2.0

Creating a PointCloud using np.ndarrays leads to exceptions in the sinkhorn algorithm #13

Closed: GoogleGeoff closed this issue 2 years ago.

GoogleGeoff commented 3 years ago

If I do the following:

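```python
import numpy as np
from ott.core import sinkhorn
from ott.geometry import pointcloud

# Plain NumPy inputs (not jax.numpy) trigger the error below.
x1 = np.zeros((2048, 64))
x2 = np.zeros((120, 64))
geom = pointcloud.PointCloud(x1, x2, epsilon=1.e-3)
output = sinkhorn.sinkhorn(geom)
print(output.reg_ot_cost)
```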

I get the exception below:

```
...
  File "ott/geometry/geometry.py", line 311, in _center
    return f[:, jnp.newaxis] + g[jnp.newaxis, :] - self.cost_matrix
  File "jax/_src/numpy/lax_numpy.py", line 6589, in deferring_binary_op
    return binary_op(self, other)
  File "bug_demo/ott_bug.runfiles/google3/third_party/py/jax/_src/numpy/lax_numpy.py", line 679, in <lambda>
    fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
TypeError: sub got incompatible shapes for broadcasting: (0, 0), (2048, 120).
```
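The TypeError is a plain broadcasting failure. In `_center`, `f[:, jnp.newaxis] + g[jnp.newaxis, :]` has shape `(len(f), len(g))`, and that shape must match the `(2048, 120)` cost matrix. The traceback reports the left operand as `(0, 0)`, so the potentials `f` and `g` must have been empty. The NumPy sketch below reproduces the same arithmetic. It assumes zero-length potentials, as the traceback implies, and `center` is a hypothetical helper, not OTT's API.

```python
import numpy as np


def center(f, g, cost_matrix):
    """Hypothetical re-creation of the arithmetic in Geometry._center:
    broadcast potentials f (n,) and g (m,) into an (n, m) matrix and
    subtract the (n, m) cost matrix."""
    return f[:, np.newaxis] + g[np.newaxis, :] - cost_matrix


cost = np.zeros((2048, 120))

# Correctly sized potentials broadcast to the cost matrix's shape.
ok = center(np.zeros(2048), np.zeros(120), cost)
print(ok.shape)  # (2048, 120)

# Zero-length potentials produce a (0, 0) operand, which cannot broadcast
# against (2048, 120). NumPy raises ValueError where JAX raises TypeError.
try:
    center(np.zeros(0), np.zeros(0), cost)
except ValueError as err:
    print("broadcast error:", err)
```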

If I use jnp.zeros instead of np.zeros above, everything works as expected.
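The report suggests a workaround for code written against an affected version: pass JAX arrays instead of NumPy arrays. This minimal sketch assumes `jax` is installed. `jnp.asarray` is a standard JAX call, but whether converting up front is enough for a particular OTT version is an assumption that rests only on the observation above.

```python
import numpy as np
import jax.numpy as jnp

# Data that starts out as NumPy arrays, as in the report.
x1_np = np.zeros((2048, 64))
x2_np = np.zeros((120, 64))

# Convert to JAX arrays before building the geometry; per the report,
# jnp inputs do not trigger the broadcasting error.
x1 = jnp.asarray(x1_np)
x2 = jnp.asarray(x2_np)

print(x1.shape, x2.shape)  # (2048, 64) (120, 64)
# geom = pointcloud.PointCloud(x1, x2, epsilon=1.e-3)  # as in the report
```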

marcocuturi commented 2 years ago

Thanks, Geoff, for pointing this out. It seems the issue was caused by the way the shape property of a geometry object was generated. The CL you proposed takes care of this, specifically in

https://github.com/google-research/ott/commit/04d089f45154cbc50a13c994bcf07e5e9e78bb5f

Thanks!
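The maintainer attributes the bug to how the geometry's shape property was generated. The code changed in the linked commit is not reproduced here. As a purely hypothetical illustration of that failure mode, consistent with the `(0, 0)` operand in the traceback, this sketch contrasts two shape properties. The first depends on a framework-specific isinstance check and silently falls back to `(0, 0)` for NumPy inputs. The second reads the leading dimension of any array-like. All names (`JaxLikeArray`, `FragileCloud`, `RobustCloud`) are invented for this example and are not OTT's API.

```python
import numpy as np


class JaxLikeArray:
    """Hypothetical stand-in for a framework-specific array type.

    NumPy arrays are deliberately NOT instances of this class.
    """

    def __init__(self, data):
        self.data = np.asarray(data)

    @property
    def shape(self):
        return self.data.shape


class FragileCloud:
    """Hypothetical point cloud whose shape relies on an isinstance check."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    @property
    def shape(self):
        if isinstance(self.x, JaxLikeArray) and isinstance(self.y, JaxLikeArray):
            return (self.x.shape[0], self.y.shape[0])
        # Silent fallback: plain NumPy inputs end up here.
        return (0, 0)


class RobustCloud(FragileCloud):
    """Hypothetical fix: read the leading dimension from any array-like."""

    @property
    def shape(self):
        return (np.shape(self.x)[0], np.shape(self.y)[0])


x1, x2 = np.zeros((2048, 64)), np.zeros((120, 64))
print(FragileCloud(x1, x2).shape)  # (0, 0)
print(FragileCloud(JaxLikeArray(x1), JaxLikeArray(x2)).shape)  # (2048, 120)
print(RobustCloud(x1, x2).shape)  # (2048, 120)
```

Under the fragile version, any potentials sized from `shape` would be empty. That would match the `(0, 0)` operand seen in `_center`.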