open-mmlab / mmsegmentation

OpenMMLab Semantic Segmentation Toolbox and Benchmark.
https://mmsegmentation.readthedocs.io/en/main/
Apache License 2.0

How to use WeightedRandomSampler from PyTorch or write custom sampler #3104

Open edsml-hmc122 opened 1 year ago

edsml-hmc122 commented 1 year ago

Hi, I have an imbalanced dataset and want to use PyTorch's WeightedRandomSampler (or an equivalent) so that each batch is balanced during training. How can I accomplish this? Note that this is not the same as using class weights to skew the loss: with that technique, most batches would still contain only samples from the majority class.

For example (this code doesn't work):

cfg.train_dataloader.sampler = dict(
    type='WeightedRandomSampler',
    shuffle=True,
    weights=sample_weights,  # 1 weight for each sample in the dataset
    num_samples=len(sample_weights),
    replacement=True  # True for oversampling
)

Since WeightedRandomSampler gives an error, is it possible to write my own Sampler and use it in the config?
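
The snippet above uses sample_weights without showing where it comes from. A minimal sketch of one common choice, inverse class frequency, follows. It assumes each sample can be given a single class label (for segmentation this could be, for example, the dominant or rarest class in the image); the `labels` list and the helper name `inverse_frequency_weights` are hypothetical and not part of MMSegmentation.

```python
from collections import Counter
from typing import Hashable, List, Sequence


def inverse_frequency_weights(labels: Sequence[Hashable]) -> List[float]:
    """Return one weight per sample, inversely proportional to its class count."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# Hypothetical per-image labels: 4 majority-class images, 1 minority-class image.
labels = ["background", "background", "background", "background", "defect"]
sample_weights = inverse_frequency_weights(labels)
print(sample_weights)  # [0.25, 0.25, 0.25, 0.25, 1.0]
```

With these weights, the total weight of each class is the same (4 x 0.25 = 1.0 for the majority class, 1.0 for the minority class), so sampling with replacement draws both classes about equally often.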

PyTorch documentation for WeightedRandomSampler: https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler
Tutorial on how it's used (better explanation): https://www.youtube.com/watch?v=4JFVhJyTZ44

Cheers!

edsml-hmc122 commented 1 year ago

I think I may have found a solution using torch.multinomial.

from collections.abc import Sized
from typing import Iterator, Optional, Sequence, Union

import torch
from mmengine.dataset.sampler import InfiniteSampler
from mmengine.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class WeightedInfiniteSampler(InfiniteSampler):
    """InfiniteSampler that draws indices according to per-sample weights."""

    def __init__(self,
                 dataset: Sized,
                 weights: Union[torch.Tensor, Sequence[float]],
                 shuffle: bool = True,
                 seed: Optional[int] = None) -> None:
        super().__init__(dataset=dataset, shuffle=shuffle, seed=seed)
        # Accept a list or a tensor; torch.multinomial needs a float tensor.
        self.weights = torch.as_tensor(weights, dtype=torch.double)

    # We override this method and yield with torch.multinomial:
    def _infinite_indices(self) -> Iterator[int]:
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                # Weighted sampling with replacement, self.size indices per pass
                yield from torch.multinomial(
                    self.weights, self.size, replacement=True,
                    generator=g).tolist()
            else:
                # No shuffling: sequential order, the weights are ignored
                yield from torch.arange(self.size).tolist()

And then simply in the config:

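# `dataset` is left out on purpose: the Runner passes the built dataset to the
# sampler when it constructs the dataloader. Setting it to the dataset config
# dict would make len(dataset) count the config keys, not the samples.
sampler = dict(
    type='WeightedInfiniteSampler',
    shuffle=True,
    weights=sample_weights,  # 1 weight for each sample in the dataset
    seed=seed
)

cfg.train_dataloader.sampler = sampler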

With this sampler, training runs and the loss is more stable. I have not yet been able to verify that the sampler actually draws the samples as intended. If anyone knows a good way to do this, please comment.
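
One possible way to check the draws is to count how often each index comes out of the sampler and compare that with the normalized weights. The sketch below is only an illustration under stated assumptions: the helper names are hypothetical, and `stand_in_sampler` uses Python's random.choices as a stand-in for the real sampler so that the example runs without MMEngine.

```python
import random
from collections import Counter
from itertools import islice
from typing import Iterable, Iterator, List


def empirical_frequencies(indices: Iterable[int], num_items: int,
                          num_draws: int) -> List[float]:
    """Fraction of the first `num_draws` draws that landed on each index."""
    counts = Counter(islice(indices, num_draws))
    return [counts[i] / num_draws for i in range(num_items)]


def expected_frequencies(weights: List[float]) -> List[float]:
    """Normalize the weights so that they sum to 1."""
    total = sum(weights)
    return [w / total for w in weights]


def stand_in_sampler(weights: List[float], seed: int) -> Iterator[int]:
    """Hypothetical stand-in: endless weighted draws with replacement."""
    rng = random.Random(seed)
    population = range(len(weights))
    while True:
        yield rng.choices(population, weights=weights, k=1)[0]


weights = [0.25, 0.25, 0.25, 0.25, 1.0]
observed = empirical_frequencies(
    stand_in_sampler(weights, seed=0), len(weights), num_draws=20_000)
expected = expected_frequencies(weights)
max_gap = max(abs(o - e) for o, e in zip(observed, expected))
print(f"largest deviation from expected frequency: {max_gap:.4f}")
```

To test the actual sampler, the stand-in stream would be replaced by iter(sampler) for a constructed WeightedInfiniteSampler, assuming it yields plain integer indices. If the deviation stays small as num_draws grows, the draws follow the weights. Mapping the drawn indices back to class labels and counting those gives the per-class balance directly.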

Also, I am new to MMSegmentation so if this can be done better/more efficiently, please let me know!