busycalibrating opened this issue 1 year ago
+1. Another use case is class imbalance, where you want to under-sample majority classes and over-sample minority classes.
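For the class-imbalance case, one common way to build the per-sample weights (my assumption, FFCV does not provide this) is inverse class frequency: each sample gets `1 / count(its class)`, so every class ends up with the same total probability. `balanced_sample_weights` is a hypothetical helper; its output could be passed as `weights` to a weighted order such as the `WeightedRandomOrder` in the reply below.

```python
import numpy as np


def balanced_sample_weights(labels):
    """Per-sample weights that give every class the same total probability.

    Each sample gets 1 / (size of its class); the weights are then normalized
    to sum to 1. Drawing with replacement then over-samples minority classes
    and under-samples majority classes.
    """
    labels = np.asarray(labels)
    # inverse maps each sample to its class index, counts holds class sizes
    _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    weights = 1.0 / counts[inverse]
    return weights / weights.sum()


labels = [0, 0, 0, 0, 0, 0, 1, 1, 2]
w = balanced_sample_weights(labels)
# Probability mass per class: three equal shares of about 1/3
print(np.bincount(labels, weights=w))
```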
@crypdick I also needed this and ended up writing a custom traversal order:
from typing import Sequence

import numpy as np
from ffcv.loader import Loader
from ffcv.traversal_order.base import TraversalOrder


class WeightedRandomOrder(TraversalOrder):
    """
    TraversalOrder similar to WeightedRandomSampler from PyTorch.
    """

    def __init__(self, weights: Sequence[float], replacement: bool = True):
        """
        Args:
            weights (Sequence[float]): non-negative weight for each sample in
                loader.indices. Normalized here, because numpy's choice()
                requires probabilities that sum to 1.
            replacement (bool): whether to sample with replacement. Defaults to True.
        """
        weights = np.asarray(weights, dtype=np.float64)
        self.weights = weights / weights.sum()
        self.replacement = replacement

    def __call__(self, loader: Loader):
        # Called with the Loader: finish the TraversalOrder setup
        # (indices, seed, distributed flag) and return this same instance.
        super().__init__(loader)
        if self.distributed:
            raise NotImplementedError("WeightedRandomOrder has no implementation for distributed.")
        if len(self.weights) != len(self.indices):
            raise ValueError("Expected exactly one weight per sample in loader.indices.")
        return self

    def sample_order(self, epoch: int) -> Sequence[int]:
        # Offset the seed by the epoch: reproducible runs, different order per epoch.
        seed = self.seed + epoch if self.seed is not None else None
        generator = np.random.default_rng(seed)
        return generator.choice(
            self.indices, len(self.indices), replace=self.replacement, p=self.weights
        )
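It's a bit hacky, but it works: construct the order the same way you would construct PyTorch's WeightedRandomSampler (one weight per sample) and pass the instance to Loader as `order`; its `__call__(loader)` completes the setup and returns the instance itself. I'm sure it can be adapted for distributed training and for other samplers as well.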
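On the distributed point, one possible adaptation (an assumption, not something FFCV ships) is for every rank to draw the same global weighted order from a shared seed and keep only a strided slice of it, similar to how PyTorch's `DistributedSampler` splits indices across replicas. Below is a numpy-only sketch of what `sample_order` could return on each rank; `distributed_weighted_order`, `rank` and `world_size` are hypothetical names, and how you get rank and world size inside a `TraversalOrder` is left open. It needs an integer seed, since all ranks must agree on the draw.

```python
import numpy as np


def distributed_weighted_order(indices, weights, seed, epoch, rank, world_size,
                               replacement=True):
    """Hypothetical per-rank weighted order.

    Every rank draws the same global order (same seed + epoch), then keeps
    positions rank, rank + world_size, ... so ranks get disjoint shares of it.
    """
    indices = np.asarray(indices)
    p = np.asarray(weights, dtype=np.float64)
    p = p / p.sum()
    rng = np.random.default_rng(seed + epoch)  # identical on every rank
    order = rng.choice(indices, len(indices), replace=replacement, p=p)
    # Pad by wrapping around so all ranks get the same number of samples
    per_rank = -(-len(order) // world_size)  # ceiling division
    total = per_rank * world_size
    order = np.resize(order, total)  # np.resize fills with repeated copies
    return order[rank:total:world_size]


# Example: 10 samples split over 3 ranks -> 4 samples per rank (2 are padding)
for r in range(3):
    print(r, distributed_weighted_order(np.arange(10), np.ones(10), seed=0,
                                        epoch=1, rank=r, world_size=3))
```

Strided slicing keeps each rank's share a uniform subsample of the global draw, so the weighting is preserved per rank rather than only in aggregate.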
I was wondering if there is any easy way to use a custom sampler (inheriting from `torch.utils.data.Sampler`). At first glance it seems like I'd have to implement something that extends `ffcv.traversal_order.base.TraversalOrder`. Are there any tricks to keep in mind while implementing this?

My specific use case: I'm using a dataset with real and generated images. I know that samples `0:m` in the dataset are real data and the following samples `m:n` are generated, and I want to sample randomly so that roughly a fraction `x` of the drawn samples are real and `1-x` are generated, i.e. converting this to support FFCV.
Thanks!
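For the real/generated split in the question, the weights follow directly: each of the `m` real samples gets `x / m` and each of the `n - m` generated samples gets `(1 - x) / (n - m)`, so a single draw is real with probability `x`. This treats `x` as a fraction in [0, 1] rather than a percentage. `real_vs_generated_weights` is a hypothetical helper; its output could be passed as `weights` to the `WeightedRandomOrder` from the reply above.

```python
import numpy as np


def real_vs_generated_weights(m, n, x):
    """Per-sample weights when samples 0..m-1 are real and m..n-1 generated.

    A draw picks a real sample with probability x and a generated one with
    probability 1 - x; samples within each group are equally likely.
    """
    if not (0 < m < n) or not (0.0 <= x <= 1.0):
        raise ValueError("need 0 < m < n and 0 <= x <= 1")
    weights = np.empty(n, dtype=np.float64)
    weights[:m] = x / m                # total mass x on the real block
    weights[m:] = (1.0 - x) / (n - m)  # total mass 1 - x on the generated block
    return weights


w = real_vs_generated_weights(m=1_000, n=5_000, x=0.3)
draws = np.random.default_rng(0).choice(5_000, size=100_000, p=w)
print((draws < 1_000).mean())  # fraction of real draws, roughly 0.3
```

With replacement, an epoch still contains `n` draws, so each real sample appears about `n * x / m` times per epoch on average (1.5 times in this example).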