huggingface / diarizers


Preprocessing consumes all available memory #6

Open Jamiroquai88 opened 1 month ago

Jamiroquai88 commented 1 month ago

Hello, I wanted to try the diarizers repo to fine-tune the model on my data. I slightly rewrote the training script to load my own data, e.g.

dataset = SpeakerDiarizationDataset(
        {
            'train': train_audio_files,
            'validation': dev_audio_files
        },
        {
            'train': train_rttm_files,
            'validation': dev_rttm_files
        }).construct_dataset()

I am able to train on a smaller set, but once I use more data, preprocessing consumes all available memory (1 TB+) while it is only at 4% progress. Any pointers on where to look so I can fix this? I should say that I use a lot of data, in this case probably more than 60,000 hours.

I trained pyannote3 segmentation models before, but with my own scripts and I never had this issue.

Jamiroquai88 commented 1 month ago

I can confirm that even when training on just my dev set (550 hours) I see very similar behaviour. In this case it finished successfully, but memory usage ramped up to 100% of what is available, probably more than 1.2 TB. I would expect this to work.

kamilakesbi commented 1 month ago

Hi @Jamiroquai88,

Thank you for submitting this issue!

Would you agree to share your data with us (at least your dev set) so that we can try to reproduce this behavior? It will be hard for us to debug if we can't reproduce it.

Here are a few pointers to look at:

Please let us know if any of these pointers helped you solve this issue!

Jamiroquai88 commented 1 month ago

Hello, unfortunately I am not able to share my data. Some statistics so you have a better understanding:

I tried to create a random subset of my dev set; some numbers:

So it seems like your preprocessing is causing the issue. I am not sure why you would load the whole dataset at once rather than on the fly. Is there an easy way to fix this, or a workaround?

kamilakesbi commented 1 month ago

Hi @Jamiroquai88,

Thank you for this information. I'll try to replicate the behaviour with a random dataset.

One possibility would be to use an IterableDataset, which is more convenient for datasets of hundreds of GBs. You could then apply the .map function to your dataset; it is performed on the fly when iterating over it (see here). The output IterableDataset can then be passed to the Trainer.
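For illustration, here is a minimal sketch of that pattern (the generator and the preprocess function are placeholders, not the diarizers API):

from datasets import IterableDataset

def example_generator(audio_paths):
    # Yield one example at a time so the full dataset is never materialised in memory.
    for path in audio_paths:
        yield {"audio": path, "timestamps_start": [], "timestamps_end": [], "speakers": []}

streamed = IterableDataset.from_generator(
    example_generator,
    gen_kwargs={"audio_paths": ["file_0.wav", "file_1.wav"]},
)

def preprocess(example):
    # Placeholder for the real chunking / label computation.
    return example

# .map on an IterableDataset is lazy: preprocess only runs while iterating,
# e.g. when the Trainer pulls batches from the dataset.
streamed = streamed.map(preprocess)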

Let me know if you need help with this and if it works!

Jamiroquai88 commented 1 month ago

Hi @kamilakesbi, I decided to try IterableDataset, but I am having issues. I wrote a generator that yields a single audio file with its timestamps, etc.:

def iterable_dataset_generator(annotations_paths, audio_paths, diarization_dataset_object):
    for annotation, audio in zip(annotations_paths, audio_paths):
        if diarization_dataset_object.annotations_type == "rttm":
            timestamps_start_file, timestamps_end_file, speakers_file = (
                diarization_dataset_object.process_rttm_file(annotation))
        elif diarization_dataset_object.annotations_type == "cha":
            timestamps_start_file, timestamps_end_file, speakers_file = (
                diarization_dataset_object.process_cha_file(annotation))
        else:
            raise ValueError("Unsupported annotations type")

        yield {
            "audio": audio,
            "timestamps_start": timestamps_start_file,
            "timestamps_end": timestamps_end_file,
            "speakers": speakers_file,
        }

I initialize it like this:

def construct_dataset(self, num_proc=1):
    """Main method to construct the dataset

    Returns:
        self.spd_dataset: HF dataset compatible with diarizers.
    """
    for subset in self.audio_paths:
        self.spd_dataset[subset] = IterableDataset.from_generator(
            iterable_dataset_generator,
            gen_kwargs={
                'annotations_paths': self.annotations_paths[subset],
                'audio_paths': self.audio_paths[subset],
                'diarization_dataset_object': self
            }
        )

In all the examples I have seen, people do

self.spd_dataset[subset] = self.spd_dataset.cast_column("audio", Audio(sampling_rate=self.sample_rate))

but this gives me an error:

Traceback (most recent call last):
  File "/shared/jprofant/Github/diarizers/train_segmentation_rev.py", line 89, in <module>
    }).construct_dataset()
  File "/shared/jprofant/Github/diarizers/src/diarizers/data/speaker_diarization.py", line 200, in construct_dataset
    self.spd_dataset[subset] = self.spd_dataset.cast_column("audio", Audio(sampling_rate=self.sample_rate))
  File "/home/ubuntu/miniconda3/envs/diarizers/lib/python3.10/site-packages/datasets/dataset_dict.py", line 2255, in cast_column
    {k: dataset.cast_column(column=column, feature=feature) for k, dataset in self.items()}
  File "/home/ubuntu/miniconda3/envs/diarizers/lib/python3.10/site-packages/datasets/dataset_dict.py", line 2255, in <dictcomp>
    {k: dataset.cast_column(column=column, feature=feature) for k, dataset in self.items()}
  File "/home/ubuntu/miniconda3/envs/diarizers/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 2131, in cast_column
    info.features[column] = feature
TypeError: 'NoneType' object does not support item assignment

So I guess the audio is not loaded yet. When I don't call .cast_column() I get further in the code, but again I am not sure what to do about the audio (since it probably isn't loaded):

File "/shared/jprofant/Github/diarizers/train_segmentation_rev.py", line 131, in <lambda>
    lambda file: preprocessor(file, random=False, overlap=0.0),
  File "/shared/jprofant/Github/diarizers/src/diarizers/data/preprocess.py", line 204, in __call__
    start_positions = self.get_start_positions(file, overlap)
  File "/shared/jprofant/Github/diarizers/src/diarizers/data/preprocess.py", line 173, in get_start_positions
    sample_rate = file["audio"][0]["sampling_rate"]
TypeError: string indices must be integers

Any idea what I am doing wrong? I guess I should somehow load the audio inside the generator, but I am not sure how. Even in this example they use .cast_column(): https://huggingface.co/docs/datasets/audio_load

kamilakesbi commented 1 month ago

Hi @Jamiroquai88,

Could you share with me a minimal reproducer of the error you get? I can try to reproduce it using the Callhome dataset.

One thing to check: in your generator, try loading each audio file yourself and yielding it as a dictionary of the form

audio = {
    'array': ...,           # audio samples as a numpy array
    'sampling_rate': ...,   # sampling rate (int)
}

This is just to make sure that you can load them correctly before using .cast_column().
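For example, something along these lines might work. This is a rough sketch only: it assumes RTTM annotations, that soundfile is installed, and that IterableDataset.from_generator accepts a features argument in your version of datasets (please double-check); the 16 kHz sampling rate and the spd_builder name are placeholders.

import soundfile as sf
from datasets import Audio, Features, IterableDataset, Sequence, Value

def generator_with_audio(annotations_paths, audio_paths, diarization_dataset_object):
    for annotation, audio_path in zip(annotations_paths, audio_paths):
        starts, ends, speakers = diarization_dataset_object.process_rttm_file(annotation)
        # Decode the audio here, so each example already carries the
        # {'array': ..., 'sampling_rate': ...} structure the preprocessor expects.
        array, sampling_rate = sf.read(audio_path)
        yield {
            "audio": {"array": array, "sampling_rate": sampling_rate},
            "timestamps_start": starts,
            "timestamps_end": ends,
            "speakers": speakers,
        }

# Declaring the schema up front should also avoid the cast_column error,
# since from_generator otherwise leaves info.features as None.
features = Features({
    "audio": Audio(sampling_rate=16000),
    "timestamps_start": Sequence(Value("float64")),
    "timestamps_end": Sequence(Value("float64")),
    "speakers": Sequence(Value("string")),
})

dataset = IterableDataset.from_generator(
    generator_with_audio,
    features=features,
    gen_kwargs={
        "annotations_paths": train_rttm_files,
        "audio_paths": train_audio_files,
        "diarization_dataset_object": spd_builder,  # your SpeakerDiarizationDataset instance
    },
)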

Hope it will help you!

Jamiroquai88 commented 1 month ago

I created a PR here: https://github.com/huggingface/diarizers/pull/9/files. Please note that when using Callhome in your training script, it does not use .construct_dataset(). Feel free to use my script; it should be pretty straightforward - just put all the wavs and rttms into one directory.

Jamiroquai88 commented 3 weeks ago

Hello @kamilakesbi, did you have a chance to look at my PR?

kamilakesbi commented 2 weeks ago

Hi @Jamiroquai88, not yet, sorry! I'll try to have a look at it early next week :)