MUCDK / ott

Optimal Transport tools implemented with the JAX framework, to get auto-diff, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

OTDataSet indexing error when number of source and target samples differ #3

Open: soerenab opened this issue 8 months ago

soerenab commented 8 months ago

Describe the bug

OTDataSet appears to assume that the numbers of source and target samples are equal. If they are not, in particular if there are more source samples than target samples, iterating over the dataloader raises an IndexError:

  File "/p/project/dynadis/soeren.becker/repos/ott/src/ott/neural/flow_models/genot.py", line 400, in __call__
    for batch in tqdm(train_loader):
  File "/p/project/dynadis/soeren.becker/envs/env_genot_spim2/lib/python3.11/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/p/project/dynadis/soeren.becker/envs/env_genot_spim2/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/p/project/dynadis/soeren.becker/envs/env_genot_spim2/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/project/dynadis/soeren.becker/envs/env_genot_spim2/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/project/dynadis/soeren.becker/envs/env_genot_spim2/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/p/project/dynadis/soeren.becker/repos/ott/src/ott/neural/data/dataloaders.py", line 77, in __getitem__
    self.target_lin[idx] if self.target_lin is not None else [],
    ~~~~~~~~~~~~~~~^^^^^
IndexError: index 20615 is out of bounds for axis 0 with size 16361

To Reproduce

Initialize the dataloader with fewer target samples than source samples and train genot with it.
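The failure mode can be reproduced without ott or torch. The class below is a hypothetical, simplified stand-in for the pattern described in this report, not the actual OTDataSet code: its length is taken from the source array only, and one index is used for both arrays. The sizes (10 source, 6 target) are arbitrary.

```python
import numpy as np


class PairedDataset:
    """Hypothetical stand-in (not ott's actual OTDataSet) for the pattern in
    the report: __len__ counts only source samples, and __getitem__ uses the
    same index for the source and target arrays."""

    def __init__(self, source: np.ndarray, target: np.ndarray) -> None:
        self.source = source
        self.target = target

    def __len__(self) -> int:
        # Length is derived from the source array only.
        return len(self.source)

    def __getitem__(self, idx: int):
        # The same index hits both arrays, so idx >= len(target) fails.
        return self.source[idx], self.target[idx]


rng = np.random.default_rng(0)
# 10 source samples but only 6 target samples.
ds = PairedDataset(rng.normal(size=(10, 2)), rng.normal(size=(6, 2)))

failed_at = None
for i in range(len(ds)):
    try:
        ds[i]
    except IndexError as err:
        failed_at = i
        print(err)  # index 6 is out of bounds for axis 0 with size 6
        break
```

A DataLoader iterating over such a dataset requests every index in `range(len(ds))`, so it hits the same error as soon as it reaches the first index past the end of the target array.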

Expected behavior

No indexing error.

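Additional context

The error arises because the length of the dataset is set based only on the number of source samples: https://github.com/MUCDK/ott/blob/draft/neural_base_solver/src/ott/neural/data/dataloaders.py#L88:L90 . Indices valid for the source arrays are then also used on the smaller target arrays (here, index 20615 into an array of size 16361).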
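One possible way to avoid the out-of-bounds access, sketched under the assumption that keeping a single shared index is acceptable, is to let the length cover the larger array and index each array modulo its own size. `WrappedPairedDataset` is a hypothetical name, not part of ott.

```python
import numpy as np


class WrappedPairedDataset:
    """Hypothetical sketch: the length covers the larger of the two arrays,
    and each array is indexed modulo its own size, so no index is out of
    bounds. Index pairing between source and target is still kept."""

    def __init__(self, source, target) -> None:
        self.source = np.asarray(source)
        self.target = np.asarray(target)

    def __len__(self) -> int:
        # One epoch visits every sample of the larger array at least once.
        return max(len(self.source), len(self.target))

    def __getitem__(self, idx: int):
        # Wrap around the smaller array instead of indexing past its end.
        return (
            self.source[idx % len(self.source)],
            self.target[idx % len(self.target)],
        )


ds = WrappedPairedDataset(np.arange(10), np.arange(6))
items = [ds[i] for i in range(len(ds))]
```

With 10 source and 6 target samples, target samples 0 to 3 are visited twice per epoch and 4 to 5 once. Using `min(...)` as the length instead would avoid the repetition but silently drop samples from the larger array. Neither variant addresses the pairing concern raised below.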

Aside from this, using a single index for both source and target, as in https://github.com/MUCDK/ott/blob/draft/neural_base_solver/src/ott/neural/data/dataloaders.py#L70:L86 , means that you always load "pairs": source sample i and target sample i always end up in the same minibatch. The source and target samples in a minibatch are hence not fully i.i.d., I think.
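A minimal sketch of breaking that pairing, assuming it is acceptable to sample uniformly with replacement: draw source and target indices independently for each minibatch. `independent_batches` is a hypothetical helper, not an ott function.

```python
import numpy as np


def independent_batches(source, target, batch_size, num_batches, seed=0):
    """Hypothetical sketch: sample source and target indices independently
    (uniformly, with replacement) for every minibatch, so a source sample is
    never tied to the target sample that shares its index."""
    source = np.asarray(source)
    target = np.asarray(target)
    rng = np.random.default_rng(seed)
    for _ in range(num_batches):
        # Two separate draws: the source and target indices are independent.
        src_idx = rng.integers(0, len(source), size=batch_size)
        tgt_idx = rng.integers(0, len(target), size=batch_size)
        yield source[src_idx], target[tgt_idx]


source = np.arange(10)
target = np.arange(100, 106)
batches = list(independent_batches(source, target, batch_size=4, num_batches=3))
```

This also sidesteps the length mismatch, since each array is only ever indexed within its own size. With torch, a comparable option might be two separate `DataLoader`s (one per side, each with `shuffle=True`) zipped together; note that `zip` stops at the shorter loader, so the epoch length would then follow the smaller dataset.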