pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Compatibility of subset dataset with disabled batch sampling #37409

Open dvirginz opened 4 years ago

dvirginz commented 4 years ago

I think there is a compatibility issue between disabled batch sampling and Subset datasets. The use case: define custom batch sampling and split the dataset with PyTorch's split utility function. Here's a minimal working example:

from torch.utils.data import DataLoader, BatchSampler, SequentialSampler

# Split the dataset into train/val/test subsets.
self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
    self.dataset, [100, 100, 100])

loader = DataLoader(
    dataset=self.train_dataset,
    batch_size=None,        # disable automatic batching
    batch_sampler=None,
    sampler=BatchSampler(   # custom batch sampling: yields a list of indices per batch
        SequentialSampler(self.train_dataset),
        batch_size=self.hparams.batch_size,
        drop_last=False),
    num_workers=self.hparams.num_data_workers,
)

When iterating over the subset datasets, this is the error:

Exception has occurred: TypeError
list indices must be integers or slices, not list
  File "/path/utils/data/dataset.py", line 257, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/path/utils/data/_utils/fetch.py", line 46, in fetch
    data = self.dataset[possibly_batched_index]
  File "/path/utils/data/dataloader.py", line 385, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/path/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/path/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 251, in _evaluate
    for batch_idx, batch in enumerate(dataloader):
  File "/path/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 843, in run_pretrain_routine
    False)
  File "/path/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_parts.py", line 477, in single_gpu_train
    self.run_pretrain_routine(model)
  File "/path/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 704, in fit
    self.single_gpu_train(model)
  File "/path/train.py", line 152, in main_train
    trainer.fit(model)
  File "/path/train.py", line 66, in main
    main_train(model_class_pointer, hyperparams, logger)
  File "/path/train.py", line 161, in <module>
    main()
  File "/path/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/path/lib/python3.7/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/path/lib/python3.7/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/path/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/path/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)

This happens because self.indices of the Subset object is a plain Python list, so it cannot be indexed with the list of batch indices that the fetcher passes through.
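For illustration, a minimal sketch of the failing indexing pattern outside of any DataLoader (toy values, not the actual Subset code):

# Subset.indices is a plain Python list; with automatic batching disabled,
# the fetcher forwards the whole list of batch indices to __getitem__.
indices = list(range(100))
possibly_batched_index = [0, 1, 2]  # what the BatchSampler yields per batch
indices[possibly_batched_index]     # TypeError: list indices must be integers or slices, not list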

Please refer to the forum post, where @ptrblck confirms the bug and offers a workaround.

cc @SsnL

kshitij12345 commented 4 years ago

@dvirginz Can you add a link to the forum post?

colesbury commented 4 years ago

https://discuss.pytorch.org/t/compatibility-of-subset-dataset-with-disabled-batch-sampling/78793

Quoting @ptrblck's suggested workaround:

As a workaround you could use a SubsetRandomSampler and pass the shuffled indices to it. Inside the Dataset.__getitem__ you might need to create a single index tensor via:

index = torch.stack(index)
x = self.data[index]

since the SubsetRandomSampler will pass a list of tensors to the Dataset for the BatchSampler approach.
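A rough, self-contained sketch of that workaround; the dataset class, field names, and sizes below are illustrative, not from the issue:

import torch
from torch.utils.data import Dataset, DataLoader, BatchSampler, SubsetRandomSampler

class InMemoryDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        # With a BatchSampler passed as `sampler` and batch_size=None, `index`
        # arrives as a list of index tensors, so stack it into one index tensor.
        index = torch.stack(index)
        return self.data[index]

    def __len__(self):
        return len(self.data)

data = torch.randn(300, 10)
subset_indices = torch.randperm(300)[:100]  # shuffled indices for the "train" split
loader = DataLoader(
    InMemoryDataset(data),
    batch_size=None,  # automatic batching disabled
    sampler=BatchSampler(SubsetRandomSampler(subset_indices),
                         batch_size=8, drop_last=False),
)
for batch in loader:
    pass  # each batch is a [8, 10] tensor (the last one may be smaller)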

dvirginz commented 4 years ago

The workaround is only partial, as with remote datasets a random sampler is usually a bad approach (sequential queries are much cheaper). I ended up implementing a subset sequential sampler :) But yes :)
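That sampler isn't included in the thread; a minimal sketch of what a subset sequential sampler could look like (hypothetical names, not the author's implementation):

from torch.utils.data import Sampler

class SubsetSequentialSampler(Sampler):
    """Yields the given subset indices in order, without shuffling.

    Handy when the underlying storage is remote and sequential reads
    are cheaper than random ones.
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

It can be wrapped in a BatchSampler exactly like the SubsetRandomSampler above.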

shuklaayush commented 3 years ago

I'm facing the exact same issue. I guess it can be fixed by storing the Subset indices as a tensor/numpy array instead of a plain Python list (using .numpy() instead of .tolist() here)
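A quick sketch of why that would help (toy values, not a patch against torch.utils.data): a tensor or numpy array of indices accepts a list of batch indices via fancy indexing, which a plain Python list does not:

import torch

indices = torch.arange(100)             # instead of list(range(100))
possibly_batched_index = [0, 1, 2]
print(indices[possibly_batched_index])  # tensor([0, 1, 2]) -- fancy indexing works

The underlying dataset would still need to accept a batch of indices in its own __getitem__, as in the workaround above.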