open-mmlab / mmdetection

OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io
Apache License 2.0

Potential Bug of MultiDataAspectRatioBatchSampler in the mmdet v3.2.0 #11114

Open HarborYuan opened 1 year ago

HarborYuan commented 1 year ago

In the latest update, mmdet introduces MultiDataAspectRatioBatchSampler. https://github.com/open-mmlab/mmdetection/blob/fe3f809a0a514189baf889aa358c498d51ee36cd/mmdet/datasets/samplers/batch_sampler.py#L120 It is a great feature that makes training on multiple datasets possible.

However, when I used it in my own project, I found that it may introduce a bug under DDP training. Specifically, the number of batches may differ across ranks, which makes the whole job hang.

The bug may be caused by the following code snippet: https://github.com/open-mmlab/mmdetection/blob/fe3f809a0a514189baf889aa358c498d51ee36cd/mmdet/datasets/samplers/batch_sampler.py#L164-L173 Suppose we have 2 datasets, each with 100 samples, split across 4 ranks, so each rank holds 50 samples in total. With a batch size of 2 for both datasets, each rank is expected to produce 25 batches (e.g. a rank holding 24 and 26 samples from the two datasets yields 12 + 13 = 25 batches). However, if drop_last is enabled and a particular rank happens to hold 25 samples from each dataset, that rank produces only 12 batches per dataset (the remaining 1 sample of each is dropped), for a total of 24 batches. The ranks then disagree on the number of batches.
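The arithmetic above can be checked with a minimal sketch (plain Python, not mmdet code; the helper name and batch size of 2 are illustrative assumptions):

```python
def num_batches(sizes_per_dataset, batch_size=2, drop_last=True):
    """Total batches a rank produces, batching each dataset separately."""
    total = 0
    for n in sizes_per_dataset:
        # drop_last discards the trailing incomplete batch of each dataset
        total += n // batch_size if drop_last else -(-n // batch_size)
    return total

# A rank holding a 24/26 split vs. a rank holding a 25/25 split:
rank_a = num_batches([24, 26])  # 12 + 13 = 25 batches
rank_b = num_batches([25, 25])  # 12 + 12 = 24 batches
# The two ranks run a different number of iterations, so DDP hangs.
```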

I wonder if anyone else has encountered this situation while using the MultiDataAspectRatioBatchSampler?

Kindly ping @ryylcc

#10926

HarborYuan commented 1 year ago

A potential solution might be to all_reduce the number of batches to obtain the minimum across ranks and drop the remaining samples. But this may drastically reduce the number of samples per rank when the number of ranks is large.

# Imports assumed from torch and mmengine's dist utilities.
import torch
import torch.distributed as torch_dist
from torch.distributed import ReduceOp
from mmengine.dist import get_dist_info
from mmengine.dist.utils import get_comm_device, get_default_group

num_batch = torch.tensor(len(self), device='cpu')
rank, world_size = get_dist_info()
if world_size > 1:
    group = get_default_group()
    backend_device = get_comm_device(group)
    # Tensor.to is not in-place; the result must be reassigned.
    num_batch = num_batch.to(device=backend_device)
    torch_dist.all_reduce(num_batch, op=ReduceOp.MIN, group=group)
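The effect of the all_reduce with ReduceOp.MIN can be simulated without torch.distributed (the per-rank counts below are illustrative, taken from the 24/26 vs. 25/25 example):

```python
# Every rank truncates its iteration count to the global minimum, so all
# ranks run the same number of steps and DDP no longer hangs.
per_rank_batches = [25, 24, 25, 25]        # hypothetical counts on 4 ranks
global_min = min(per_rank_batches)          # what all_reduce(..., op=MIN) yields
steps_per_rank = [global_min] * len(per_rank_batches)
dropped = [n - global_min for n in per_rank_batches]  # batches discarded per rank
```

With many ranks the minimum tends to shrink, which is the downside noted above: every rank above the minimum silently discards batches.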
HarborYuan commented 1 year ago

Update:

The fix works in our project, but I am not sure whether there is a better solution, such as coordination between the Sampler and the BatchSampler.