Open edsml-hmc122 opened 1 year ago
I think I may have found a solution using torch.multinomial.
from collections.abc import Sized
from typing import Iterator, Optional

import torch
from mmengine.dataset.sampler import InfiniteSampler
from mmengine.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class WeightedInfiniteSampler(InfiniteSampler):

    def __init__(self,
                 dataset: Sized,
                 weights: torch.Tensor,
                 shuffle: bool = True,
                 seed: Optional[int] = None) -> None:
        super().__init__(dataset=dataset, shuffle=shuffle, seed=seed)
        self.weights = weights

    # We override this method and yield with torch.multinomial:
    def _infinite_indices(self) -> Iterator[int]:
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                # Weighted sampling
                yield from torch.multinomial(
                    self.weights, self.size, replacement=True,
                    generator=g).tolist()
            else:
                yield from torch.arange(self.size).tolist()
And then simply in the config:
sampler = dict(
    type='WeightedInfiniteSampler',
    shuffle=True,
    weights=sample_weights,
    seed=seed,
    dataset=cfg.train_dataloader.dataset
)
cfg.train_dataloader.sampler = sampler
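The config above assumes sample_weights already exists. One common way to build it is inverse class frequency, where each sample gets weight 1 / (number of samples in its class). The sketch below assumes that every training sample can be assigned a single class label; for segmentation this might be the dominant class in the mask, which is an assumption and not something stated in this thread. The label names and the helper inverse_frequency_weights are hypothetical.

```python
from collections import Counter
from typing import Hashable, List, Sequence


def inverse_frequency_weights(labels: Sequence[Hashable]) -> List[float]:
    """Return one weight per sample: 1 / (size of that sample's class)."""
    class_counts = Counter(labels)
    return [1.0 / class_counts[label] for label in labels]


# Hypothetical per-sample labels for an imbalanced dataset (8 vs. 2).
labels = ["background"] * 8 + ["tumor"] * 2
sample_weights = inverse_frequency_weights(labels)

# With these weights, every class carries the same total sampling mass,
# so each class is equally likely to be drawn regardless of its size.
mass_per_class: Counter = Counter()
for label, weight in zip(labels, sample_weights):
    mass_per_class[label] += weight

print(dict(mass_per_class))
```

The resulting list could then be wrapped with torch.tensor(sample_weights, dtype=torch.double) before being passed as weights, since torch.multinomial expects a floating-point tensor.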
With this sampler, training runs and the loss is more stable. I have not yet been able to verify that the sampler actually draws samples as intended; if anyone knows a good way to do this, please comment.
Also, I am new to MMSegmentation, so if this can be done better or more efficiently, please let me know!
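One way to check the sampler might be to draw a large number of indices and compare how often each index appears against its normalized weight. The sketch below does this with a stand-in generator built on random.choices, not the real WeightedInfiniteSampler, so that it runs without mmengine or torch. The names stand_in_sampler, empirical_frequencies and expected_frequencies, as well as the example weights, are all hypothetical.

```python
import itertools
import random
from collections import Counter
from typing import Iterable, Iterator, List, Sequence


def empirical_frequencies(indices: Iterable[int], num_items: int,
                          num_draws: int) -> List[float]:
    """Consume `num_draws` indices and return the observed share of each index."""
    counts = Counter(itertools.islice(indices, num_draws))
    return [counts.get(i, 0) / num_draws for i in range(num_items)]


def expected_frequencies(weights: Sequence[float]) -> List[float]:
    """Normalize the weights so that they sum to 1."""
    total = float(sum(weights))
    return [w / total for w in weights]


def stand_in_sampler(weights: Sequence[float], seed: int) -> Iterator[int]:
    """Hypothetical stand-in for iter(sampler): an endless weighted index stream."""
    rng = random.Random(seed)
    population = range(len(weights))
    while True:
        # Draw one "epoch" of indices with replacement, as the sampler does.
        yield from rng.choices(population, weights=weights, k=len(weights))


weights = [0.1, 0.1, 0.8, 1.0]
num_draws = 100_000

observed = empirical_frequencies(
    stand_in_sampler(weights, seed=0), len(weights), num_draws)
expected = expected_frequencies(weights)
max_abs_error = max(abs(o - e) for o, e in zip(observed, expected))

# A small error suggests the indices follow the requested distribution.
print("observed:", observed)
print("expected:", expected)
print("max abs error:", max_abs_error)
```

To check the real sampler, iter(sampler) could presumably be passed in place of the stand-in, assuming a single process; in distributed runs each rank only sees a slice of the index stream, so the counts would need to be collected per rank.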
Hi, I have an imbalanced dataset and I want to use PyTorch's WeightedRandomSampler or an equivalent, so that each batch is balanced during training. How can I accomplish this? Note that this is not the same as using class weights to skew the loss, because with that technique most batches would still contain only samples from the majority class.
For example (this code doesn't work):
Since WeightedRandomSampler gives an error, is it possible to write my own Sampler and use it in the config?
PyTorch documentation for WeightedRandomSampler: https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler
Tutorial on how it's used (better explanation): https://www.youtube.com/watch?v=4JFVhJyTZ44
Cheers!