Describe the bug
segment._segment_interface() seems to raise an error if the total number of points in either dataset is smaller than max_measure_size. (I am not sure whether this case is meant to be covered by the function.) Here is the traceback for the code snippet below:
Traceback (most recent call last):
  File "/p/project/dynadis/soeren.becker/repos/inverse_cot/drafts/test_segment_sinkhorn.py", line 47, in <module>
    main()
  File "/p/project/dynadis/soeren.becker/repos/inverse_cot/drafts/test_segment_sinkhorn.py", line 27, in main
    segment._segment_interface(
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/ott/geometry/segment.py", line 171, in _segment_interface
    segmented_y, segmented_weights_y = segment_point_cloud(
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/ott/geometry/segment.py", line 118, in segment_point_cloud
    idx = jax.lax.dynamic_slice(jnp.sort(idx), (0,), (max_measure_size,))
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/lax/slicing.py", line 167, in dynamic_slice
    return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/core.py", line 444, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/core.py", line 447, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/core.py", line 935, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/p/project/dynadis/soeren.becker/envs/env_icot/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
TypeError: slice slice_sizes must be less than or equal to operand shape, got slice_sizes (80,) for operand shape (70,).
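For reference, the failing call can be isolated from ott entirely: `jax.lax.dynamic_slice` requires each slice size to be at most the operand's size along that axis, so asking for 80 entries from an index array of length 70 raises the same `TypeError`. This is a minimal sketch; the helper `dynamic_slice_fails` is hypothetical and only mirrors the shape of the call at `segment.py` line 118.

```python
import jax
import jax.numpy as jnp


def dynamic_slice_fails(num_points: int, max_measure_size: int) -> bool:
    """Hypothetical helper: return True if slicing `max_measure_size` entries
    from a sorted index array of length `num_points` raises a TypeError,
    mirroring the failing call in the traceback."""
    idx = jnp.arange(num_points)
    try:
        jax.lax.dynamic_slice(jnp.sort(idx), (0,), (max_measure_size,))
    except TypeError:
        # JAX's shape rule rejects slice sizes larger than the operand.
        return True
    return False


too_large = dynamic_slice_fails(70, 80)  # same sizes as in the traceback
fits = dynamic_slice_fails(70, 70)       # slice size equal to operand size
print(too_large, fits)
```

The check happens in the shape rule at bind time, so it fails eagerly with concrete arrays just as it does in the traceback.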
Desktop
OS: Linux
Additional context
I think the error occurs because segment._segment_interface() internally calls segment_point_cloud() twice (https://github.com/ott-jax/ott/blob/main/src/ott/geometry/segment.py#L160:L180), once for x and once for y, with the same max_measure_size, which may have been computed globally from both x and y. In that case it can exceed the total number of points in the smaller cloud (in the traceback, a slice of 80 is requested from a cloud of 70 points). The fix might be as simple as computing max_measure_size separately for x and y, but I am not fully sure what segment._segment_interface() is supposed to do, or whether using the same max_measure_size is required here or elsewhere.
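A minimal sketch of what computing max_measure_size separately for x and y could look like, assuming integer segment-id arrays with values in `[0, num_segments)`. The helper `per_cloud_max_measure_size` and the example arrays are hypothetical (not part of ott), and because the helper uses NumPy on concrete values it cannot run inside `jax.jit`.

```python
import numpy as np


def per_cloud_max_measure_size(segment_ids, num_segments: int) -> int:
    """Size of the largest segment within a single point cloud.

    Hypothetical helper, not part of ott. Runs eagerly with NumPy, so
    `segment_ids` must be concrete (not a traced value inside jax.jit).
    """
    counts = np.bincount(np.asarray(segment_ids), minlength=num_segments)
    return int(counts.max())


# Hypothetical data: x has 70 points in two segments; y has 100 points,
# 80 of which fall into segment 1.
segment_ids_x = np.array([0] * 30 + [1] * 40)
segment_ids_y = np.array([0] * 20 + [1] * 80)

size_x = per_cloud_max_measure_size(segment_ids_x, 2)
size_y = per_cloud_max_measure_size(segment_ids_y, 2)

# A single size taken over both clouds exceeds len(x) == 70 ...
global_size = max(size_x, size_y)
# ... whereas each per-cloud size is bounded by that cloud's point count.
print(global_size, size_x, size_y)
```

This only sketches the size computation; whether _segment_interface() can accept differently padded x and y (i.e., whether downstream code relies on both having the same shape) remains the open question in this issue.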