lhotse-speech / lhotse

Tools for handling speech data in machine learning projects.
https://lhotse.readthedocs.io/en/latest/
Apache License 2.0

Pytorch dataloader cannot compute length #1337

Closed: njellinas closed this issue 4 weeks ago

njellinas commented 1 month ago

I have prepared a dataset with Cuts as mentioned in the tutorial:

import lhotse
from lhotse import CutSet
from torch.utils.data import DataLoader

recs = CutSet(...)
trainset = lhotse.dataset.unsupervised.UnsupervisedWaveformDataset(recs)
sampler = lhotse.dataset.sampling.SimpleCutSampler(recs, max_cuts=16, shuffle=True)
trainloader = DataLoader(trainset, sampler=sampler, batch_size=None)

I want a batch size of 16, so I set the max_cuts argument. However, when I compute the total number of iterations for my training loop with len(trainloader), I get the error TypeError: object of type 'SimpleCutSampler' has no len(). When I define my own samplers without lhotse, there is always a __len__ method that returns the total number of batches. Is this not implemented here?

pzelasko commented 1 month ago

You seem to have a very outdated example; I see now that I missed a few places to update in the docs.

Samplers don't support len() because of dynamic batch sizes in lhotse. In the general case, you can't know the exact number of iterations up-front.
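
To illustrate why the batch count can't be known up-front when batch sizes are dynamic, here is a minimal sketch in plain Python. It is not lhotse's actual sampling algorithm: `count_batches` is a hypothetical helper that greedily packs items into batches capped by total duration, which is enough to show that the same items yield a different number of batches depending on their (shuffled) order.

```python
def count_batches(durations, max_duration):
    """Greedily pack items, in order, into batches whose summed duration
    does not exceed max_duration. Hypothetical helper, not lhotse code."""
    num_batches = 0
    items_in_batch = 0
    batch_duration = 0.0
    for d in durations:
        # Close the current batch if adding this item would exceed the cap.
        if items_in_batch > 0 and batch_duration + d > max_duration:
            num_batches += 1
            items_in_batch = 0
            batch_duration = 0.0
        items_in_batch += 1
        batch_duration += d
    # Count the final, possibly partial, batch.
    if items_in_batch > 0:
        num_batches += 1
    return num_batches


# Same four items, two different orderings, 10-second cap per batch.
print(count_batches([9.0, 1.0, 9.0, 1.0], 10.0))  # 2 batches: [9, 1], [9, 1]
print(count_batches([9.0, 9.0, 1.0, 1.0], 10.0))  # 3 batches: [9], [9, 1], [1]
```

With shuffling enabled, the ordering changes every epoch, so the only way to get the exact count in this setting is to actually run the packing.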

njellinas commented 4 weeks ago

I created this custom class and it works:

import lhotse
import numpy as np


class SimpleCutSampler(lhotse.dataset.sampling.SimpleCutSampler):
    def __len__(self):
        # Number of batches, valid only when max_cuts is the sole batch-size constraint.
        return int(np.ceil(self.num_cuts / self.time_constraint.max_cuts))
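
As a quick sanity check of the formula used in `__len__` above, the following plain-Python sketch compares it against explicit fixed-size chunking. The helper names `expected_num_batches` and `chunk` are made up for illustration; the check only holds under the assumption that every batch is capped by a fixed item count and the total number of items is known in advance.

```python
import math


def expected_num_batches(num_items, max_items_per_batch):
    """ceil(num_items / max_items_per_batch), the same formula as in __len__."""
    if max_items_per_batch <= 0:
        raise ValueError("max_items_per_batch must be positive")
    return math.ceil(num_items / max_items_per_batch)


def chunk(items, size):
    """Split items into consecutive batches of at most `size` elements."""
    return [items[i:i + size] for i in range(0, len(items), size)]


items = list(range(50))
batches = chunk(items, 16)

# 50 items with at most 16 per batch: three full batches and one of 2 items.
print(len(batches), expected_num_batches(len(items), 16))  # 4 4
```

If a duration-based constraint is set as well, or the item count is not available, this formula no longer gives the true number of iterations.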