microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

Add DistributedGeoSamplers #305

Status: Open. RitwikGupta opened this issue 2 years ago.

RitwikGupta commented 2 years ago

set_epoch is used by a lot of codebases in an effort to be deterministic.

adamjstewart commented 2 years ago

> a lot of codebases

Can you give some examples? I would like to see how they implement this.

RitwikGupta commented 2 years ago

https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L270
https://github.com/microsoft/Swin-Transformer/blob/b05e6214a37d33846903585c9e83b694ef411587/main.py#L142
https://github.com/facebookresearch/maskrcnn-benchmark/blob/57eec25b75144d9fb1a6857f32553e1574177daf/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py#L23

And more: https://github.com/search?l=Python&q=sampler.set_epoch&type=Code

adamjstewart commented 2 years ago

Okay, so all of the examples above use torch.utils.data.distributed.DistributedSampler. This feature was added in https://github.com/pytorch/pytorch/pull/39628 with little explanation beyond what is in the docs and source code. The docs contain this note:

> In distributed mode, calling the set_epoch() method at the beginning of each epoch before creating the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be always used.

The set_epoch method docstring (which doesn't make it into the docs for some reason) contains:

> Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

TL;DR: So this feature is not used for reproducibility/determinism, but for randomness across replicas. It's unclear to me why PyTorch needs this to prevent subsequent iterations from having the same ordering. Until I understand that, I'm not sure whether we need this feature or not.

Our samplers definitely weren't created with distributed sampling in mind, but they should still work, since our datasets don't have a finite length that needs to be subsampled (other than GridGeoSampler, which would need special treatment). RandomGeoSampler and RandomBatchGeoSampler imply shuffle=True.
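As a rough illustration of the "special treatment" GridGeoSampler would need: a grid sampler produces a finite, deterministic list of tiles, so under distributed training each replica should take a disjoint share of that list, much as DistributedSampler does with dataset indices. The sketch below is a simplified assumption, not TorchGeo's GridGeoSampler: `grid_tiles` and `shard_for_rank` are hypothetical helpers, and bounds are plain `(minx, miny, maxx, maxy)` tuples rather than TorchGeo's BoundingBox.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (minx, miny, maxx, maxy)


def grid_tiles(bounds: Box, size: float, stride: float) -> List[Box]:
    """Hypothetical helper: enumerate size x size tiles on a regular grid.

    This sketch skips any leftover strip narrower than one tile.
    """
    minx, miny, maxx, maxy = bounds
    nx = max(0, int((maxx - minx - size) // stride) + 1)
    ny = max(0, int((maxy - miny - size) // stride) + 1)
    return [
        (minx + j * stride, miny + i * stride,
         minx + j * stride + size, miny + i * stride + size)
        for i in range(ny)
        for j in range(nx)
    ]


def shard_for_rank(tiles: List[Box], rank: int, world_size: int) -> List[Box]:
    """Hypothetical helper: give each rank every world_size-th tile."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return tiles[rank::world_size]


tiles = grid_tiles((0, 0, 4, 4), size=2, stride=2)
shards = [shard_for_rank(tiles, r, world_size=2) for r in range(2)]
print(len(tiles), [len(s) for s in shards])  # 4 [2, 2]
```

A plain strided split can leave ranks with counts that differ by one. DistributedSampler instead pads (or, with drop_last=True, drops) so that every rank runs the same number of steps, which a real distributed grid sampler would likely also need.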

P.S. The other non-distributed samplers still have a generator arg that can be used to control random sampling. We may want to add this.
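Here is a minimal sketch of what a generator argument could look like on a random sampler. It assumes a hypothetical `random_boxes` function that draws fixed-size windows uniformly inside a bounds tuple and accepts a `random.Random` instance in place of a torch.Generator. The name and signature are illustrative only, not TorchGeo's API.

```python
import random
from typing import Iterator, Optional, Tuple

Box = Tuple[float, float, float, float]  # (minx, miny, maxx, maxy)


def random_boxes(bounds: Box, size: float, length: int,
                 generator: Optional[random.Random] = None) -> Iterator[Box]:
    """Hypothetical sampler: yield `length` random size x size windows.

    Passing a seeded `generator` makes the draws reproducible; without one,
    a fresh, unseeded random.Random is used.
    """
    rng = generator if generator is not None else random.Random()
    minx, miny, maxx, maxy = bounds
    for _ in range(length):
        x = rng.uniform(minx, maxx - size)
        y = rng.uniform(miny, maxy - size)
        yield (x, y, x + size, y + size)


a = list(random_boxes((0, 0, 100, 100), 10, 3, generator=random.Random(42)))
b = list(random_boxes((0, 0, 100, 100), 10, 3, generator=random.Random(42)))
print(a == b)  # True: same seed, same windows
```

In a distributed run, each rank could pass something like `random.Random(seed + rank)` so replicas draw different windows while staying reproducible. If every rank uses an identical seed, this sketch makes them all draw exactly the same windows.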

isaaccorley commented 2 years ago

PyTorch Lightning (PL) automatically wraps samplers in a DistributedSampler, so we don't need to handle any of this ourselves since we use PL. You would only need to deal with this if you were rolling your own distributed training scripts. See https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#replace-sampler-ddp

adamjstewart commented 2 years ago

@RitwikGupta does PL satisfy your use case?

RitwikGupta commented 2 years ago

@adamjstewart Not entirely. I'd have to refactor this entire FB codebase into PyTorch Lightning, which would be a massive pain.

isaaccorley commented 2 years ago

Adding a set_epoch method to our samplers wouldn't actually solve this. The above links use DistributedSampler, which splits the dataset indices to be sampled across nodes/GPUs. I've done this for another project, but we would need to create our own modification of a DistributedSamplerWrapper, similar to this
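One way to read the DistributedSamplerWrapper idea, sketched in plain Python under stated assumptions: materialize everything the wrapped sampler yields for the epoch, shuffle the positions with a seed shared by all replicas, pad so that every replica gets the same count, and give each rank a strided slice. `ToyDistributedSamplerWrapper` and its parameters are hypothetical. This is a simplified illustration, not the wrapper from any particular library.

```python
import math
import random
from typing import Generic, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


class ToyDistributedSamplerWrapper(Generic[T]):
    """Hypothetical wrapper that splits another sampler's output across ranks.

    Assumes the wrapped sampler yields the same sequence on every rank
    (deterministic, or seeded identically everywhere); otherwise the
    per-rank slices are not guaranteed to be disjoint.
    """

    def __init__(self, sampler: Iterable[T], num_replicas: int, rank: int,
                 shuffle: bool = True, seed: int = 0) -> None:
        if not 0 <= rank < num_replicas:
            raise ValueError("rank must be in [0, num_replicas)")
        self.sampler = sampler
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __iter__(self) -> Iterator[T]:
        items: List[T] = list(self.sampler)
        order = list(range(len(items)))
        if self.shuffle:
            # Same seed + epoch on every rank -> same permutation everywhere.
            random.Random(self.seed + self.epoch).shuffle(order)
        per_rank = math.ceil(len(order) / self.num_replicas)
        total = per_rank * self.num_replicas
        padding = total - len(order)
        if padding > 0:
            # Repeat positions so every rank gets exactly per_rank items.
            order += (order * math.ceil(padding / len(order)))[:padding]
        for pos in order[self.rank:total:self.num_replicas]:
            yield items[pos]


tiles = [f"tile{i}" for i in range(10)]
wrappers = [ToyDistributedSamplerWrapper(tiles, num_replicas=3, rank=r)
            for r in range(3)]
for w in wrappers:
    w.set_epoch(1)
shards = [list(w) for w in wrappers]
print([len(s) for s in shards])  # [4, 4, 4]
```

Materializing the wrapped sampler each epoch keeps the wrapper agnostic to what it wraps (indices, bounding boxes, batches), at the cost of holding one epoch's samples in memory.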

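Edit: the torch.utils.data.DistributedSampler source also provides some insight. set_epoch exists because the epoch is added to the seed used to shuffle the indices: every node/GPU computes the same permutation for a given epoch and takes its own slice of it, and the permutation only changes when the epoch changes.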
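A simplified, stdlib-only sketch of the index logic described above, for the shuffle=True, drop_last=False case. PyTorch itself uses torch.Generator and torch.randperm rather than Python's random module, so the actual orderings differ, and `rank_indices` is a hypothetical name. The structure is the point: seed with seed + epoch, pad to a multiple of the replica count, then take a strided slice per rank.

```python
import math
import random
from typing import List


def rank_indices(dataset_len: int, num_replicas: int, rank: int,
                 seed: int = 0, epoch: int = 0) -> List[int]:
    """Indices one replica would visit in one epoch (hypothetical helper)."""
    indices = list(range(dataset_len))
    # Every rank seeds identically, so they all agree on the permutation.
    random.Random(seed + epoch).shuffle(indices)
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    padding = total_size - len(indices)
    if padding > 0:
        # Repeat indices so every replica gets exactly num_samples of them.
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Strided slice: rank r takes positions r, r + num_replicas, ...
    return indices[rank:total_size:num_replicas]


# Without set_epoch the epoch stays 0, so every "epoch" repeats the order.
first = rank_indices(20, num_replicas=2, rank=0, epoch=0)
again = rank_indices(20, num_replicas=2, rank=0, epoch=0)
later = rank_indices(20, num_replicas=2, rank=0, epoch=1)
print(first == again, first == later)
```

This is why forgetting set_epoch does not break training outright: the replicas still agree and still split the data, but each epoch silently reuses the epoch-0 ordering.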

RitwikGupta commented 2 years ago

Right, I should say that set_epoch isn't the fix here; the larger fix is to create a distributed wrapper for the existing samplers. This issue just documents the symptom.

dylanrstewart commented 1 year ago

I think this is still an issue. I am unable to train using multiple GPUs when I use a RandomBatchGeoSampler:

AttributeError: 'RandomBatchGeoSampler' object has no attribute 'drop_last'