huggingface / datasets


Slow iteration speeds when using IterableDataset.shuffle with load_dataset(data_files=..., streaming=True) #7102

Open lajd opened 3 months ago

lajd commented 3 months ago

Describe the bug

When I load a dataset from a number of arrow files, as in:

random_dataset = load_dataset(
    "arrow",
    data_files={split: shard_filepaths},
    streaming=True,
    split=split,
)

I'm able to get fast iteration speeds when iterating over the dataset without shuffling.

When I shuffle the dataset, the iteration speed is reduced by ~1000x.

It's very possible that the way I'm loading the dataset shards is not appropriate; if so, please advise!

Thanks for the help
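
In case the loading approach is the problem, here is a minimal sketch of an alternative way to build the same streaming dataset, included only for reference. It assumes Dataset.from_file, concatenate_datasets, and Dataset.to_iterable_dataset are available in this version of datasets, and it reuses the shard_filepaths list produced by the repro script below; it is not used in any of the timings.

from datasets import Dataset, concatenate_datasets

# Open each Arrow shard directly (memory-mapped), concatenate them, and
# expose the result as an IterableDataset with one sub-shard per file.
shards = [Dataset.from_file(fp) for fp in shard_filepaths]
full_dataset = concatenate_datasets(shards)
iterable_dataset = full_dataset.to_iterable_dataset(num_shards=len(shards))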

Steps to reproduce the bug

Here's the full code to reproduce the issue:

import time
from pathlib import Path
from multiprocessing import Pool, cpu_count

import torch
from datasets import Dataset, load_dataset

split = "train"
split_save_dir = "/tmp/random_split"

def generate_random_example():
    return {
        'inputs': torch.randn(128).tolist(),
        'indices': torch.randint(0, 10000, (2, 20000)).tolist(),
        'values': torch.randn(20000).tolist(), 
    }

def generate_shard_dataset(examples_per_shard: int = 512):
    dataset_dict = {
        'inputs': [],
        'indices': [],
        'values': []
    }

    for _ in range(examples_per_shard):
        example = generate_random_example()
        dataset_dict['inputs'].append(example['inputs'])
        dataset_dict['indices'].append(example['indices'])
        dataset_dict['values'].append(example['values'])

    return Dataset.from_dict(dataset_dict)

def save_shard(shard_idx, save_dir, examples_per_shard):
    shard_dataset = generate_shard_dataset(examples_per_shard)
    shard_write_path = Path(save_dir) / f"shard_{shard_idx}"
    shard_dataset.save_to_disk(shard_write_path)
    return str(shard_write_path / "data-00000-of-00001.arrow")

def generate_split_shards(save_dir, num_shards: int = 16, examples_per_shard: int = 512):
    with Pool(cpu_count()) as pool:
        args = [(m, save_dir, examples_per_shard) for m in range(num_shards)]
        shard_filepaths = pool.starmap(save_shard, args)

    return shard_filepaths

shard_filepaths = generate_split_shards(split_save_dir)

Load the dataset as an IterableDataset:

random_dataset = load_dataset(
    "arrow",
    data_files={split: shard_filepaths},
    streaming=True,
    split=split,
)
random_dataset = random_dataset.with_format("numpy")

Observe the iterations/second when iterating over the dataset directly, and when applying shuffling before iterating:

Without shuffling, this gives ~1500 iterations/second:

start_time = time.time()
for count, item in enumerate(random_dataset):
    if count > 0 and count % 100 == 0:
        elapsed_time = time.time() - start_time
        iterations_per_second = count / elapsed_time
        print(f"Processed {count} items at an average of {iterations_per_second:.2f} iterations/second")
Processed 100 items at an average of 705.74 iterations/second
Processed 200 items at an average of 1169.68 iterations/second
Processed 300 items at an average of 1497.97 iterations/second
Processed 400 items at an average of 1739.62 iterations/second
Processed 500 items at an average of 1931.11 iterations/second

When shuffling, this gives ~3 iterations/second:


random_dataset = random_dataset.shuffle(buffer_size=100, seed=42)

start_time = time.time()
for count, item in enumerate(random_dataset):
    if count > 0 and count % 100 == 0:
        elapsed_time = time.time() - start_time
        iterations_per_second = count / elapsed_time
        print(f"Processed {count} items at an average of {iterations_per_second:.2f} iterations/second")
Processed 100 items at an average of 3.75 iterations/second
Processed 200 items at an average of 3.93 iterations/second

Expected behavior

Iteration speed should be barely affected by shuffling, especially with a small buffer size.
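
For intuition, buffer shuffling by itself should be cheap. Here is a rough illustration of the mechanism (a sketch only, not the datasets implementation): each incoming example either fills the buffer or replaces a randomly chosen buffered example that gets yielded, so the per-item cost is constant regardless of buffer size.

import random

def buffered_shuffle(iterable, buffer_size=100, seed=42):
    # Approximate shuffling with a fixed-size buffer: fill the buffer, then
    # for each new example yield a random buffered element and put the new
    # example in its place. Each step is O(1), so a buffer of 100 should add
    # very little overhead on its own.
    rng = random.Random(seed)
    buffer = []
    for example in iterable:
        if len(buffer) < buffer_size:
            buffer.append(example)
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = example
    rng.shuffle(buffer)
    yield from buffer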

Environment info

datasets version: 2.21.0
Python version: 3.10
OS: Ubuntu 22.04

brijow commented 3 months ago

Hi @lajd, I was skeptical about how the script above saves each shard as its own dataset (Arrow file), so I updated the script to try saving the shards in a few different file formats. In the experiments I ran, the binary format performed significantly better than the others, with Arrow and Parquet about the same. However, I was unable to reproduce a drastically slower iteration speed after shuffling in any case with the revised script -- pasting below:

import time
from datasets import load_dataset, Dataset, IterableDataset
from pathlib import Path
import torch
import pandas as pd
import pickle
import pyarrow as pa
import pyarrow.parquet as pq

def generate_random_example():
    return {
        'inputs': torch.randn(128).tolist(),
        'indices': torch.randint(0, 10000, (2, 20000)).tolist(),
        'values': torch.randn(20000).tolist(),
    }

def generate_shard_data(examples_per_shard: int = 512):
    return [generate_random_example() for _ in range(examples_per_shard)]

def save_shard_as_arrow(shard_idx, save_dir, examples_per_shard):
    # Generate shard data
    shard_data = generate_shard_data(examples_per_shard)

    # Convert data to a Hugging Face Dataset
    dataset = Dataset.from_dict({
        'inputs': [example['inputs'] for example in shard_data],
        'indices': [example['indices'] for example in shard_data],
        'values': [example['values'] for example in shard_data],
    })

    # Define the shard save path
    shard_write_path = Path(save_dir) / f"shard_{shard_idx}"

    # Save the dataset to disk using the Arrow format
    dataset.save_to_disk(str(shard_write_path))

    return str(shard_write_path)

def save_shard_as_parquet(shard_idx, save_dir, examples_per_shard):
    # Generate shard data
    shard_data = generate_shard_data(examples_per_shard)

    # Convert data to a pandas DataFrame for easy conversion to Parquet
    df = pd.DataFrame(shard_data)

    # Define the shard save path
    shard_write_path = Path(save_dir) / f"shard_{shard_idx}.parquet"

    # Convert DataFrame to PyArrow Table for Parquet saving
    table = pa.Table.from_pandas(df)

    # Save the table as a Parquet file
    pq.write_table(table, shard_write_path)

    return str(shard_write_path)

def save_shard_as_binary(shard_idx, save_dir, examples_per_shard):
    # Generate shard data
    shard_data = generate_shard_data(examples_per_shard)

    # Define the shard save path
    shard_write_path = Path(save_dir) / f"shard_{shard_idx}.bin"

    # Save each example as a serialized binary object using pickle
    with open(shard_write_path, 'wb') as f:
        for example in shard_data:
            f.write(pickle.dumps(example))

    return str(shard_write_path)

def generate_split_shards(save_dir, filetype="parquet", num_shards: int = 16, examples_per_shard: int = 512):
    shard_filepaths = []
    for shard_idx in range(num_shards):
        if filetype == "parquet":
            shard_filepaths.append(save_shard_as_parquet(shard_idx, save_dir, examples_per_shard))
        elif filetype == "binary":
            shard_filepaths.append(save_shard_as_binary(shard_idx, save_dir, examples_per_shard))
        elif filetype == "arrow":
            shard_filepaths.append(save_shard_as_arrow(shard_idx, save_dir, examples_per_shard))
        else:
            raise ValueError(f"Unsupported filetype: {filetype}. Choose 'parquet', 'binary', or 'arrow'.")
    return shard_filepaths

def _binary_dataset_generator(files):
    for filepath in files:
        with open(filepath, 'rb') as f:
            while True:
                try:
                    example = pickle.load(f)
                    yield example
                except EOFError:
                    break

def load_binary_dataset(shard_filepaths):
    return IterableDataset.from_generator(
        _binary_dataset_generator, gen_kwargs={"files": shard_filepaths},
    )

def load_parquet_dataset(shard_filepaths):
    # Load the dataset as an IterableDataset
    return load_dataset(
        "parquet",
        data_files={split: shard_filepaths},
        streaming=True,
        split=split,
    )

def load_arrow_dataset(shard_filepaths):
    # Load the dataset as an IterableDataset
    shard_filepaths = [f + "/data-00000-of-00001.arrow" for f in shard_filepaths]
    return load_dataset(
        "arrow",
        data_files={split: shard_filepaths},
        streaming=True,
        split=split,
    )

def load_dataset_wrapper(filetype: str, shard_filepaths: list[str]):
    if filetype == "parquet":
        return load_parquet_dataset(shard_filepaths)
    elif filetype == "binary":
        return load_binary_dataset(shard_filepaths)
    elif filetype == "arrow":
        return load_arrow_dataset(shard_filepaths)
    else:
        raise ValueError("Unsupported filetype")

# Example usage:
split = "train"
split_save_dir = "/tmp/random_split"

filetype = "binary" # or "parquet", or "arrow"
num_shards = 16

shard_filepaths = generate_split_shards(split_save_dir, filetype=filetype, num_shards=num_shards)
dataset = load_dataset_wrapper(filetype=filetype, shard_filepaths=shard_filepaths)

dataset = dataset.shuffle(buffer_size=100, seed=42)

start_time = time.time()
for count, item in enumerate(dataset):
    if count > 0 and count % 100 == 0:
        elapsed_time = time.time() - start_time
        iterations_per_second = count / elapsed_time
        print(f"Processed {count} items at an average of {iterations_per_second:.2f} iterations/second")
brijow commented 3 months ago

Update: I was able to reproduce the issue you described -- but ONLY if I do

random_dataset = random_dataset.with_format("numpy")

If I do this, I see numbers similar to what you reported. If I do not use the numpy format, Parquet and Arrow both run at about 17 iterations/second regardless of whether or not we shuffle. The binary format (not yet tried with the numpy format) still shows the fastest speeds on average, about 850 iterations/second with and without shuffling.
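
A minimal sketch of that comparison grid, reusing load_dataset_wrapper and generate_split_shards from the script above (arrow_shard_dirs is a new name; exact numbers will vary by machine):

# Regenerate Arrow shards so the grid tests the same format as the original report.
arrow_shard_dirs = generate_split_shards(split_save_dir, filetype="arrow", num_shards=num_shards)

for use_numpy in (False, True):
    for use_shuffle in (False, True):
        ds = load_dataset_wrapper(filetype="arrow", shard_filepaths=arrow_shard_dirs)
        if use_numpy:
            ds = ds.with_format("numpy")
        if use_shuffle:
            ds = ds.shuffle(buffer_size=100, seed=42)
        start = time.time()
        for count, _ in zip(range(1, 201), ds):
            pass
        print(f"numpy={use_numpy} shuffle={use_shuffle}: "
              f"{count / (time.time() - start):.2f} iterations/second")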

I suspect Arrow and numpy are optimized for sequential reads, and shuffling disrupts that... hmm
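
One way to narrow that down might be to time the shuffled stream without with_format("numpy") and do the numpy conversion manually per example, separating the cost of formatting from the cost of non-sequential reads. A sketch, reusing arrow_shard_dirs from the grid above; np.asarray is just a stand-in for whatever conversion is actually needed:

import numpy as np

ds = load_dataset_wrapper(filetype="arrow", shard_filepaths=arrow_shard_dirs)
ds = ds.shuffle(buffer_size=100, seed=42)

start = time.time()
for count, item in zip(range(1, 201), ds):
    # Convert the raw python lists to numpy arrays by hand instead of
    # relying on with_format("numpy").
    _ = {key: np.asarray(value) for key, value in item.items()}
print(f"shuffled + manual numpy conversion: {count / (time.time() - start):.2f} iterations/second")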