huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

Audio dataset loads everything into RAM and is very slow #7117

Open Jourdelune opened 3 weeks ago

Jourdelune commented 3 weeks ago

Hello, I'm working with an audio dataset. I want to transcribe the audio that the dataset contains, and for that I use Whisper. My issue is that the dataset loads everything into RAM when I map it, and, obviously, when RAM usage gets too high, the program crashes.

To work around this, I set writer_batch_size to 10, but then mapping the dataset becomes extremely slow. To illustrate: on 50 examples, with writer_batch_size set to 10, it takes 123.24 seconds to process the dataset. Without writer_batch_size set, processing takes about ten seconds, but then the process hangs (I assume it is writing the dataset and therefore hits the same problem as with writer_batch_size).

Steps to reproduce the bug

High RAM usage but fast (though actually slow when saving the dataset):

from datasets import load_dataset
import time

ds = load_dataset("WaveGenAI/audios2", split="train[:50]")

# map the dataset
def transcribe_audio(row):
    audio = row["audio"]  # get the audio but do nothing with it
    row["transcribed"] = True
    return row

time1 = time.time()
ds = ds.map(
    transcribe_audio
) 

for row in ds:
    pass  # do nothing, just iterate to trigger the map function

print(f"Time taken: {time.time() - time1:.2f} seconds")

Low RAM usage but very, very slow:

from datasets import load_dataset
import time

ds = load_dataset("WaveGenAI/audios2", split="train[:50]")

# map the dataset
def transcribe_audio(row):
    audio = row["audio"]  # get the audio but do nothing with it
    row["transcribed"] = True
    return row

time1 = time.time()
ds = ds.map(
    transcribe_audio, writer_batch_size=10
)  # set low writer_batch_size to avoid memory issues

for row in ds:
    pass  # do nothing, just iterate to trigger the map function

print(f"Time taken: {time.time() - time1:.2f} seconds")

Expected behavior

I think the processing should be much faster: on only 50 audio examples, the mapping takes several minutes even though nothing is done (the audio is just loaded).

Environment info

Extra

The dataset was generated with the audiofolder loader, so I don't think anything specific in my code is causing this problem.

import argparse

from datasets import load_dataset

parser = argparse.ArgumentParser()
parser.add_argument("--folder", help="folder path", default="/media/works/test/")
args = parser.parse_args()

dataset = load_dataset("audiofolder", data_dir=args.folder)

# push the dataset to hub
dataset.push_to_hub("WaveGenAI/audios")

Also, it's the combination of audio = row["audio"] and row["transcribed"] = True that causes problems: row["transcribed"] = True alone does nothing, and audio = row["audio"] alone sometimes causes problems, sometimes not.
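For clarity, the two isolated variants described above would look something like this (a minimal sketch; the function names are illustrative, not from the original report):

# Variant A: write-only - adds a column, never touches the audio.
# Reportedly fine on its own.
def write_only(row):
    row["transcribed"] = True
    return row

# Variant B: read-only - accesses the audio column, returns the row unchanged.
# Reportedly causes problems only sometimes.
def read_only(row):
    audio = row["audio"]
    return row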

lhoestq commented 3 weeks ago

Hi! I think the issue comes from the fact that you return row entirely, and therefore the dataset has to re-encode the audio data in row.

Can you try this instead?

# map the dataset
def transcribe_audio(row):
    audio = row["audio"]  # get the audio but do nothing with it
    return {"transcribed": True}

PS: no need to iterate over the dataset to trigger the map function on a Dataset - map runs eagerly when it's called (contrary to the IterableDataset that you get when streaming, which is lazy)
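For reference, a minimal sketch of the lazy behavior described above, using streaming mode (the dataset name is taken from the snippets in this thread; take(50) mirrors the train[:50] slice):

from datasets import load_dataset

# streaming=True returns an IterableDataset, which keeps memory usage low:
# nothing is downloaded or decoded until you iterate.
ids = load_dataset("WaveGenAI/audios2", split="train", streaming=True)

def transcribe_audio(row):
    return {"transcribed": True}

# On an IterableDataset, map() is lazy: it only registers the transformation.
ids = ids.map(transcribe_audio)

# Iterating is what actually triggers the map, one example at a time.
for row in ids.take(50):
    pass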

Jourdelune commented 2 weeks ago

No, that doesn't change anything. I managed to work around the problem by setting with_indices=True in the map function and retrieving the audio directly by index:

from datasets import load_dataset
import time

ds = load_dataset("WaveGenAI/audios2", split="train[:50]")

# map the dataset
def transcribe_audio(row, idx):
    audio = ds[idx]["audio"]  # get the audio but do nothing with it
    row["transcribed"] = True
    return row

time1 = time.time()
ds = ds.map(
    transcribe_audio, with_indices=True
)  # fetch the audio by index instead of through the mapped row

for row in ds:
    pass  # do nothing, just iterate to trigger the map function

print(f"Time taken: {time.time() - time1:.2f} seconds")

lhoestq commented 2 weeks ago

Hmm, maybe accessing row["audio"] makes map() re-encode what's inside row["audio"] in case there are in-place modifications.
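If that hypothesis is right, one possible workaround (a sketch, not a fix confirmed in this thread) is to disable decoding of the audio column so map() never holds decoded arrays it might re-encode:

from datasets import load_dataset, Audio

ds = load_dataset("WaveGenAI/audios2", split="train[:50]")

# With decode=False, row["audio"] is just {"path": ..., "bytes": ...};
# nothing is decoded into arrays, so there is nothing to re-encode.
ds = ds.cast_column("audio", Audio(decode=False))

def transcribe_audio(row):
    audio = row["audio"]  # cheap dict access, no decoding
    return {"transcribed": True}

ds = ds.map(transcribe_audio)

A real transcription step would then have to decode the bytes or path itself, e.g. by handing the file path to Whisper.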