lhotse-speech / lhotse

Tools for handling speech data in machine learning projects.
https://lhotse.readthedocs.io/en/latest/
Apache License 2.0

dataloader slow with shar #1363

Closed tianchaolangzi closed 3 months ago

tianchaolangzi commented 3 months ago

Here is the dataset I implemented:

```python
import glob
import random
from functools import lru_cache

import numpy as np
import torch
from lhotse import CutSet


class XXDataset(torch.utils.data.Dataset):
    def __init__(self, shar_dir, voice_types, sample_rate=16000, ir_file=None, ir_portion=0.5):
        self.sample_rate = sample_rate
        self.cuts = self.get_cuts(shar_dir)
        self.voices = voice_types
        self.voice2index = {}
        for idx, voice in enumerate(voice_types):
            self.voice2index[voice] = idx
        self.ir_file = ir_file
        self.ir_portion = ir_portion
        assert 0 < self.ir_portion < 1, "ir_portion set wrong"

    @lru_cache()
    def get_cuts(self, shar_dir):
        cut_list = glob.glob(f"{shar_dir}/cuts.*.jsonl.gz")
        recording_list = glob.glob(f"{shar_dir}/recording.*.tar")
        cut_list.sort()
        recording_list.sort()
        return CutSet.from_shar(
            fields={
                "cuts": cut_list,
                "recording": recording_list,
            }
        )

    # @profile
    def __getitem__(self, index):
        cut = self.cuts[index]
        waveform = cut.recording.load_audio()
        waveform = waveform.reshape(-1)
        if len(waveform) > self.sample_rate * 5:  # cut to 5 s
            waveform = waveform[:self.sample_rate * 5]
        if self.ir_file and random.random() >= (1 - self.ir_portion):
            waveform = self.add_ir(waveform)  # add_ir defined elsewhere (omitted here)
        label = cut.supervisions[0].custom['audio_event']  # e.g. babycry, fire, knock
        # target = hf['target'][index_in_hdf5].astype(np.float32)
        if label in self.voices:
            target = np.array(self.voice2index[label]).astype(np.float32)
        else:
            target = np.array(self.voice2index["filler"]).astype(np.float32)
        data_dict = {'audio_name': cut.id, 'waveform': waveform, 'target': target}

        return data_dict
```
But it is very slow, and the time to load data increases batch by batch.

```
Total time: 48.4381 s
File: xx
Function: __getitem__ at line 35

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    35                                               @profile
    36                                               def __getitem__(self, index):
    37                                                   # cut = self.cuts[index].resample(self.sample_rate)
    38       389   47942446.3 123245.4     99.0          cut = self.cuts[index]
    39       388     486267.4   1253.3      1.0          waveform = cut.recording.load_audio()
    40       388       1311.3      3.4      0.0          waveform = waveform.reshape(-1)
    41       388        829.0      2.1      0.0          if len(waveform) > self.sample_rate * 5:  # cut to 5 s
    42       322        789.3      2.5      0.0              waveform = waveform[:self.sample_rate * 5]
    43       388        220.5      0.6      0.0          if self.ir_file and random.random() >= (1 - self.ir_portion):
    44                                                       waveform = self.add_ir(waveform)
    45       388        592.5      1.5      0.0          label = cut.supervisions[0].custom['audio_event']
    46                                                   # target = hf['target'][index_in_hdf5].astype(np.float32)
    47       388        472.5      1.2      0.0          if label in self.voices:
    48        39        430.2     11.0      0.0              target = np.array(self.voice2index[label]).astype(np.float32)
    49                                                   else:
    50       349       4242.6     12.2      0.0              target = np.array(self.voice2index["filler"]).astype(np.float32)
    51       388        381.7      1.0      0.0          data_dict = {'audio_name': cut.id, 'waveform': waveform, 'target': target}
    52
    53       388         94.3      0.2      0.0          return data_dict
```
```
batch 0, load data cost 1.8195204734802246 (64, 80000)
batch 1, load data cost 4.219593524932861 (64, 80000)
batch 2, load data cost 6.829158067703247 (64, 80000)
batch 3, load data cost 9.1092689037323 (64, 80000)
batch 4, load data cost 11.50891923904419 (64, 80000)
batch 5, load data cost 13.973206043243408
```

It seems that the line `cut = self.cuts[index]` takes up almost all of the time, and I don't understand why. @pzelasko, do you have any suggestions? Thank you.

pzelasko commented 3 months ago

A Lhotse Shar CutSet is meant to be iterated over, not indexed: indexing a lazily loaded CutSet has to scan the shards from the beginning on each access, which is why every batch gets slower as the index grows. You'll need to use iterable-style datasets. See this tutorial for an end-to-end example of usage: https://colab.research.google.com/github/lhotse-speech/lhotse/blob/master/examples/04-lhotse-shar.ipynb
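For illustration, a minimal sketch of the iterable approach under the same Shar layout as above (the wrapper class and its names are hypothetical, not part of Lhotse; the Colab tutorial is the reference for the full pipeline):

```python
import glob

import torch
from lhotse import CutSet


class IterableShardDataset(torch.utils.data.IterableDataset):
    """Hypothetical wrapper: streams cuts sequentially instead of indexing."""

    def __init__(self, cuts: CutSet, sample_rate: int = 16000):
        self.cuts = cuts
        self.sample_rate = sample_rate

    def __iter__(self):
        # Sequential iteration reads each shard once, front to back,
        # avoiding the scan-from-the-start cost of cuts[index].
        for cut in self.cuts:
            waveform = cut.load_audio().reshape(-1)
            waveform = waveform[: self.sample_rate * 5]  # truncate to 5 s
            yield {"audio_name": cut.id, "waveform": waveform}


shar_dir = "/path/to/shar"  # placeholder
cuts = CutSet.from_shar(
    fields={
        "cuts": sorted(glob.glob(f"{shar_dir}/cuts.*.jsonl.gz")),
        "recording": sorted(glob.glob(f"{shar_dir}/recording.*.tar")),
    }
)
loader = torch.utils.data.DataLoader(IterableShardDataset(cuts), batch_size=None)
```

Batching, padding, and label mapping are omitted here; in practice you'd pair this with a collate function or one of Lhotse's samplers, as the tutorial shows.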

tianchaolangzi commented 3 months ago

> A Lhotse Shar CutSet is meant to be iterated over, not indexed: indexing a lazily loaded CutSet has to scan the shards from the beginning on each access, which is why every batch gets slower as the index grows. You'll need to use iterable-style datasets. See this tutorial for an end-to-end example of usage: https://colab.research.google.com/github/lhotse-speech/lhotse/blob/master/examples/04-lhotse-shar.ipynb

Thank you so much, the problem is solved. One remaining question: it seems the 1000 samples within a single shard cannot be shuffled during training; shuffling is only possible at the shard level.

pzelasko commented 3 months ago

You can shuffle; just call `cuts.shuffle(buffer_size=10000)` (unless you're using a Lhotse sampler, which will do it for you). It performs an approximate streaming shuffle.
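A sketch of how that might fit into the pipeline above. The `buffer_size` is arbitrary, and `shuffle_shards` is my assumption about `CutSet.from_shar`'s signature for randomizing shard order, not something confirmed in this thread:

```python
import glob

from lhotse import CutSet

shar_dir = "/path/to/shar"  # placeholder
cuts = CutSet.from_shar(
    fields={
        "cuts": sorted(glob.glob(f"{shar_dir}/cuts.*.jsonl.gz")),
        "recording": sorted(glob.glob(f"{shar_dir}/recording.*.tar")),
    },
    shuffle_shards=True,  # assumption: reorders the shards between epochs
)

# Approximate streaming shuffle: fill a buffer of 10000 cuts and yield
# random elements from it, so samples mix beyond a single shard.
cuts = cuts.shuffle(buffer_size=10000)

for cut in cuts:
    ...  # feed into the iterable dataset / training loop
```

Shard-level reordering plus a large in-memory buffer approximates a global shuffle without ever loading the whole dataset.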