necla-ml / ML


Weighted sampler for multiple datasets #4

Open deepsworld opened 2 years ago

deepsworld commented 2 years ago

A weighted sampler that draws from multiple datasets, making it easy to combine and balance data from different sources.

deepsworld commented 2 years ago

I have a rough implementation for it:

import torch
from torch.utils.data import Dataset, ConcatDataset, DataLoader, WeightedRandomSampler

class custom_dataset0(Dataset):
    """Toy dataset: 80 samples, each paired with label 0."""
    def __init__(self):
        super().__init__()
        self.tensor_data = torch.arange(80)

    def __getitem__(self, index):
        return self.tensor_data[index], torch.tensor(0)

    def __len__(self):
        return len(self.tensor_data)

class custom_dataset1(Dataset):
    """Toy dataset: 20 samples, each paired with label 1."""
    def __init__(self):
        super().__init__()
        self.tensor_data = torch.arange(20)

    def __getitem__(self, index):
        return self.tensor_data[index], torch.tensor(1)

    def __len__(self):
        return len(self.tensor_data)

dataset0 = custom_dataset0()
dataset1 = custom_dataset1()

datasets = [dataset0, dataset1]
concat_dataset = ConcatDataset(datasets)
lengths = torch.tensor([len(dataset) for dataset in datasets])
# Weight each sample by the inverse of its dataset's length so that, in
# expectation, every dataset contributes equally per epoch.
dataset_weights = 1.0 / lengths
# dataset_weights = torch.tensor([0.2, 0.8])  # custom per-dataset weights also work
weights = torch.ones(lengths.sum().item(), dtype=torch.float32)
offset = 0
for i, length in enumerate(lengths.tolist()):
    weights[offset:offset + length] = dataset_weights[i]
    offset += length
# replacement=True lets samples from the smaller dataset repeat within an epoch
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
dataloader = DataLoader(concat_dataset, batch_size=16, sampler=sampler)
for i, (val, dataset_no) in enumerate(dataloader):
    print("batch index {}, dataset0/dataset1: {}/{}".format(
        i, (dataset_no == 0).sum().item(), (dataset_no == 1).sum().item()))