ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

Potential bug in ott.geometry.segment._segment_interface #545

Closed: soerenab closed this issue 2 months ago

soerenab commented 4 months ago

Describe the bug
segment._segment_interface() seems to raise an error if the total number of points in either dataset (x or y) is smaller than max_measure_size. (I am not sure whether the function is meant to cover this case, though.) Here is the traceback for the code snippet below:

Traceback (most recent call last):
  File "/p/project/dynadis/soeren.becker/repos/inverse_cot/drafts/test_segment_sinkhorn.py", line 47, in <module>
    main()
  File "/p/project/dynadis/soeren.becker/repos/inverse_cot/drafts/test_segment_sinkhorn.py", line 27, in main
    segment._segment_interface(
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/ott/geometry/segment.py", line 171, in _segment_interface
    segmented_y, segmented_weights_y = segment_point_cloud(
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/ott/geometry/segment.py", line 118, in segment_point_cloud
    idx = jax.lax.dynamic_slice(jnp.sort(idx), (0,), (max_measure_size,))
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/lax/slicing.py", line 167, in dynamic_slice
    return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/core.py", line 444, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/core.py", line 447, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/core.py", line 935, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
TypeError: slice slice_sizes must be less than or equal to operand shape, got slice_sizes (80,) for operand shape (70,).
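The failure can be reproduced with `jax.lax.dynamic_slice` alone: the requested slice size must be no larger than the operand, regardless of the start index. Below is a minimal sketch in which a 70-element array stands in for the sorted indices of y (an assumption based on the operand shape in the traceback); the exception type matches the `TypeError` above.

```python
import jax
import jax.numpy as jnp

# Stand-in for the 70 sorted point indices of y (assumption from the traceback).
operand = jnp.arange(70)

# A slice exactly as long as the operand is valid.
ok = jax.lax.dynamic_slice(operand, (0,), (70,))

# A slice longer than the operand is rejected when the shape is checked.
try:
    jax.lax.dynamic_slice(operand, (0,), (80,))
    error_message = None
except TypeError as e:
    error_message = str(e)
```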

To Reproduce
Steps to reproduce the behavior:

import jax.numpy as jnp
from ott.geometry import costs, pointcloud, segment

def main():

    def eval_fn(*args):
        print("eval_fn")

    dim = 10
    num_per_segment_x = jnp.array([80, 70, 50])
    num_per_segment_y = jnp.array([50, 10, 10])  # ERROR: 70 points in y < max_measure_size (80)
    # num_per_segment_y = jnp.array([60, 10, 10])  # WORKS: 80 points in y == max_measure_size (80)

    x = jnp.arange(num_per_segment_x.sum() * dim).reshape(-1, dim)
    y = jnp.arange(num_per_segment_y.sum() * dim).reshape(-1, dim)

    num_segments = len(num_per_segment_x)
    # Computed globally from both x and y; slice sizes must be concrete Python ints.
    max_measure_size = int(max(num_per_segment_x.max(), num_per_segment_y.max()))

    indices_are_sorted = False
    segment_ids_x = segment_ids_y = None
    weights_x = weights_y = None
    padding_vector = None

    segment._segment_interface(
      x,
      y,
      eval_fn,
      num_segments=num_segments,
      max_measure_size=max_measure_size,
      segment_ids_x=segment_ids_x,
      segment_ids_y=segment_ids_y,
      indices_are_sorted=indices_are_sorted,
      num_per_segment_x=num_per_segment_x,
      num_per_segment_y=num_per_segment_y,
      weights_x=weights_x,
      weights_y=weights_y,
      padding_vector=padding_vector,
    )

    print("done")


if __name__ == "__main__":
    main()

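The arithmetic behind the ERROR and WORKS cases can be checked directly. The sketch below uses a hypothetical helper, `slice_fits`, and assumes (as the operand shape in the traceback suggests) that the sliced array has one entry per point in the dataset, so the slice only fits when max_measure_size does not exceed that dataset's total number of points.

```python
import jax.numpy as jnp


def slice_fits(max_measure_size, num_per_segment):
    """Hypothetical helper: does a slice of this length fit in the dataset?"""
    return max_measure_size <= int(num_per_segment.sum())


num_per_segment_x = jnp.array([80, 70, 50])  # 200 points
failing_y = jnp.array([50, 10, 10])          # 70 points
working_y = jnp.array([60, 10, 10])          # 80 points

# Global size, as in the reproduction script: max over both datasets (80).
max_measure_size = int(max(num_per_segment_x.max(), failing_y.max()))

x_ok = slice_fits(max_measure_size, num_per_segment_x)  # 80 <= 200
y_fails = not slice_fits(max_measure_size, failing_y)   # 80 > 70
y_ok = slice_fits(max_measure_size, working_y)          # 80 <= 80

# Per-dataset size, as suggested in the issue: a dataset's largest segment
# can never exceed its total number of points, so the slice always fits.
max_size_y = int(failing_y.max())                       # 50
y_ok_separate = slice_fits(max_size_y, failing_y)       # 50 <= 70
```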

Additional context
I think the error occurs because segment._segment_interface() internally calls segment_point_cloud() twice (https://github.com/ott-jax/ott/blob/main/src/ott/geometry/segment.py#L160:L180), once for x and once for y, with the same max_measure_size, which may have been computed globally from both x and y. The fix might be as simple as computing max_measure_size separately for x and y, but I am not fully sure what segment._segment_interface() is supposed to do, or whether sharing the same max_measure_size is required here or elsewhere.
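An alternative to per-dataset sizes is to pad the sorted index array up to max_measure_size before slicing, so the slice is always in bounds. The sketch below is hypothetical: `segment_indices` is not part of ott and is not the change that was merged. It only mirrors the `jnp.sort` + `jax.lax.dynamic_slice` line from the traceback, and assumes that points outside a segment are marked with an out-of-bounds sentinel index that a `mode="fill"` gather turns into padding.

```python
import jax
import jax.numpy as jnp


def segment_indices(segment_ids, segment_id, max_measure_size):
    """Hypothetical helper: indices of one segment, padded to max_measure_size.

    Points outside the segment get the out-of-bounds sentinel `num`.
    """
    num = segment_ids.shape[0]
    idx = jnp.where(segment_ids == segment_id, jnp.arange(num), num)
    idx = jnp.sort(idx)  # segment members first, sentinels last
    # Pad with sentinels so the slice below never exceeds the operand shape.
    pad = max(0, max_measure_size - num)
    idx = jnp.pad(idx, (0, pad), constant_values=num)
    return jax.lax.dynamic_slice(idx, (0,), (max_measure_size,))


# Same sizes as the failing reproduction: 70 points in y, max_measure_size=80.
segment_ids_y = jnp.repeat(jnp.arange(3), jnp.array([50, 10, 10]))
y = jnp.arange(70 * 2).reshape(70, 2)

idx0 = segment_indices(segment_ids_y, 0, 80)
# Sentinel indices are out of bounds, so mode="fill" returns padding rows.
segment0 = y.at[idx0].get(mode="fill", fill_value=0)
```

Padding keeps every segment at the shared max_measure_size, which matters if downstream code expects x and y segments of equal length; whether ott relies on that is not stated in this thread.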

soerenab commented 4 months ago

Here is a potential fix: https://github.com/soerenab/ott/commit/39b7209a52ba612e53f616502b19ea1f498bf3fa

michalk8 commented 2 months ago

Closed via https://github.com/ott-jax/ott/commit/2011fe4d2fa5a456983c24e0ed83e21f3dd4388a