wenet-e2e / wespeaker

Research and Production Oriented Speaker Verification, Recognition and Diarization Toolkit
Apache License 2.0
630 stars, 109 forks

Is it possible to implement dataloader for triplet loss/GE2E #235

Open mmmmayi opened 9 months ago

mmmmayi commented 9 months ago

Hi, is it possible to make a batch containing M speakers and N utterances for each speaker?

cdliang11 commented 9 months ago

I don't think it can be supported in UIO mode, but it could be implemented in dataset_deprecated.py.

mmmmayi commented 8 months ago

Thanks for your response, but what's UIO mode? And I think it could perhaps be implemented in the DistributedSampler class in dataset.py?

JiJiJiang commented 8 months ago

> Thanks for your response but what's UIO mode? And I think maybe it can be implemented in class DistributedSampler in dataset.py?

Check our paper for an introduction to UIO data management. We designed this mode for large-dataset training. BTW, DistributedSampler is designed to shuffle data.list and distribute it across different GPUs; it might not satisfy your needs. I think a proper way is to design your own collate_fn for the dataloader. You can refer to our implementation in the DINO SSL training code in collate_fn. Of course, some modification is needed.
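To make the collate_fn idea concrete, here is a minimal sketch (not wespeaker's DINO implementation; ge2e_collate_fn and its arguments are made up for illustration). It regroups a flat list of (speaker, feature) samples so that each speaker with enough utterances contributes exactly N of them, which a GE2E-style loss can then view as an (M, N, ...) batch:

```python
from collections import defaultdict

def ge2e_collate_fn(batch, num_utts=2):
    """Keep num_utts utterances for each speaker that has enough samples."""
    by_spk = defaultdict(list)
    for spk, feat in batch:
        by_spk[spk].append(feat)
    feats, labels = [], []
    for spk, utts in by_spk.items():
        if len(utts) >= num_utts:
            feats.extend(utts[:num_utts])      # N utterances per speaker
            labels.extend([spk] * num_utts)    # speaker label repeated N times
    # in a real dataloader, feats would be torch.stack-ed into an
    # (M * num_utts, ...) tensor and viewed as (M, num_utts, ...) in the loss
    return feats, labels
```

Speakers with fewer than num_utts utterances in the incoming batch are simply dropped, which is why the sampler feeding this collate_fn still needs to deliver enough utterances per speaker.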

mmmmayi commented 8 months ago

> Thanks for your response but what's UIO mode? And I think maybe it can be implemented in class DistributedSampler in dataset.py?
>
> Check our paper for introduction of the UIO data management. We design this mode for large dataset training. BTW, DistributedSampler is designed for shuffling the data.list and distributes them into different GPUs. It might not satisfy your demands. I think a proper way is to design your own collate_fn for the dataloader. You can refer to our implementation in DINO ssl training codes in collate_fn. Of course some modification is needed.

I understand DistributedSampler is designed for distributing data into different GPUs. But can we distribute according to the speaker id? For example, the code could be (with shuffle removed in processor.py):

import json
import random

import torch
import torch.distributed as dist

# Methods intended to replace those of DistributedSampler in dataset.py

def __init__(self, lists, num_utts, shuffle=True, partition=True):
    self.epoch = -1
    self.update()
    self.shuffle = shuffle
    self.partition = partition
    self.num_utts = num_utts
    # map each speaker id to the indices of its utterances in lists
    self.spk = {}
    for i in range(len(lists)):
        obj = json.loads(lists[i])
        if obj['spk'] not in self.spk:
            self.spk[obj['spk']] = []
        self.spk[obj['spk']].append(i)

def update(self):
    assert dist.is_available()
    if dist.is_initialized():
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
    else:
        self.rank = 0
        self.world_size = 1
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        self.worker_id = 0
        self.num_workers = 1
    else:
        self.worker_id = worker_info.id
        self.num_workers = worker_info.num_workers
    return dict(rank=self.rank,
                world_size=self.world_size,
                worker_id=self.worker_id,
                num_workers=self.num_workers)

def set_epoch(self, epoch):
    self.epoch = epoch

def sample(self, lists):
    spk = list(self.spk.keys())
    if self.partition:
        if self.shuffle:
            # seed with the epoch so every rank applies the same permutation
            random.Random(self.epoch).shuffle(spk)
        spk = spk[self.rank::self.world_size]
    spk = spk[self.worker_id::self.num_workers]
    data = []
    for i in spk:
        # choices() samples with replacement, so speakers with fewer than
        # num_utts utterances still contribute num_utts samples
        data = data + random.choices(self.spk[i], k=self.num_utts)
    return data

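To sanity-check the idea, the partition-and-draw step above can be pulled out into a small standalone function (illustrative only; sample_spk_batch and its arguments are not wespeaker APIs):

```python
import random

def sample_spk_batch(spk_to_idx, rank, world_size, num_utts,
                     epoch=0, shuffle=True):
    """Partition speakers across ranks, then draw num_utts indices each.

    spk_to_idx maps a speaker id to the utterance indices it owns.
    """
    spk = sorted(spk_to_idx)                   # deterministic base order
    if shuffle:
        random.Random(epoch).shuffle(spk)      # same permutation on every rank
    spk = spk[rank::world_size]                # disjoint speaker subset per rank
    data = []
    for s in spk:
        # choices() samples with replacement, so speakers with fewer than
        # num_utts utterances still yield num_utts (possibly repeated) samples
        data += random.choices(spk_to_idx[s], k=num_utts)
    return data
```

With shuffle disabled the partition is easy to inspect: with two ranks, rank 0 gets the speakers at even positions and rank 1 those at odd positions, and every result contains exactly num_utts indices per selected speaker.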
mmmmayi commented 8 months ago

I just noticed that I set the data type as 'raw', and the above code is not appropriate for 'shard'

JiJiJiang commented 8 months ago

Yeah... It makes sense in the 'raw' mode. Hope it works for you! Good luck!

wsstriving commented 8 months ago

> I just noticed that I set the data type as 'raw', and the above code is not appropriate for 'shard'

In raw mode, it's much easier to implement your function (but slower). In shard mode it's also possible, though the implementation takes some effort.

Possible approach:

  1. Rewrite the write_shard function to make sure each shard contains multiple occurrences of the same speaker.
  2. As suggested in other comments, the easiest way is to implement it in the collate_fn, but the number of samples in one batch is limited. One workaround: if you look at the shuffle function in the processor, you can write a similar processor that does the grouping in a buffer, which can be set much larger than the batch size. The collate_fn can then handle the real batch organization.

But overall, you need to balance randomness against data-processing difficulty.
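A rough sketch of point 2, modeled on the shuffle processor's buffering pattern (group_by_spk is a made-up name, not an existing wespeaker processor): accumulate samples into a large buffer, then emit them grouped by speaker so a downstream collate_fn receives whole speaker groups.

```python
from collections import defaultdict

def group_by_spk(data, buf_size=1000, num_utts=2):
    """data: iterable of sample dicts with at least a 'spk' key.

    Buffers up to buf_size samples, then yields them regrouped so that
    consecutive runs of num_utts samples share the same speaker.
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= buf_size:
            yield from _flush(buf, num_utts)
            buf = []
    yield from _flush(buf, num_utts)

def _flush(buf, num_utts):
    by_spk = defaultdict(list)
    for sample in buf:
        by_spk[sample['spk']].append(sample)
    for utts in by_spk.values():
        # emit complete groups of num_utts utterances per speaker;
        # leftovers are dropped here (they could instead be carried over
        # into the next buffer)
        for i in range(0, len(utts) - num_utts + 1, num_utts):
            yield from utts[i:i + num_utts]
```

Because the buffer is much larger than the batch size, the randomness lost by grouping is limited to within each buffer, which is the trade-off the comment above refers to.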