lajd opened this issue 3 months ago
Hi @lajd, I was skeptical about the way the script above saves each shard as its own dataset (arrow file), so I updated the script to try saving the shards in a few different file formats. In the experiments I ran, the binary format performed significantly better, with arrow and parquet about the same. However, with the revised script I was unable to reproduce a drastically slower iteration speed after shuffling in any case -- pasting it below:
import time
from datasets import load_dataset, Dataset, IterableDataset
from pathlib import Path
import torch
import pandas as pd
import pickle
import pyarrow as pa
import pyarrow.parquet as pq
def generate_random_example():
    return {
        'inputs': torch.randn(128).tolist(),
        'indices': torch.randint(0, 10000, (2, 20000)).tolist(),
        'values': torch.randn(20000).tolist(),
    }


def generate_shard_data(examples_per_shard: int = 512):
    return [generate_random_example() for _ in range(examples_per_shard)]


def save_shard_as_arrow(shard_idx, save_dir, examples_per_shard):
    # Generate shard data
    shard_data = generate_shard_data(examples_per_shard)

    # Convert data to a Hugging Face Dataset
    dataset = Dataset.from_dict({
        'inputs': [example['inputs'] for example in shard_data],
        'indices': [example['indices'] for example in shard_data],
        'values': [example['values'] for example in shard_data],
    })

    # Define the shard save path
    shard_write_path = Path(save_dir) / f"shard_{shard_idx}"

    # Save the dataset to disk using the Arrow format
    dataset.save_to_disk(str(shard_write_path))

    return str(shard_write_path)


def save_shard_as_parquet(shard_idx, save_dir, examples_per_shard):
    # Generate shard data
    shard_data = generate_shard_data(examples_per_shard)

    # Convert data to a pandas DataFrame for easy conversion to Parquet
    df = pd.DataFrame(shard_data)

    # Define the shard save path
    shard_write_path = Path(save_dir) / f"shard_{shard_idx}.parquet"

    # Convert DataFrame to PyArrow Table for Parquet saving
    table = pa.Table.from_pandas(df)

    # Save the table as a Parquet file
    pq.write_table(table, shard_write_path)

    return str(shard_write_path)


def save_shard_as_binary(shard_idx, save_dir, examples_per_shard):
    # Generate shard data
    shard_data = generate_shard_data(examples_per_shard)

    # Define the shard save path
    shard_write_path = Path(save_dir) / f"shard_{shard_idx}.bin"

    # Save each example as a serialized binary object using pickle
    with open(shard_write_path, 'wb') as f:
        for example in shard_data:
            f.write(pickle.dumps(example))

    return str(shard_write_path)


def generate_split_shards(save_dir, filetype="parquet", num_shards: int = 16, examples_per_shard: int = 512):
    shard_filepaths = []
    for shard_idx in range(num_shards):
        if filetype == "parquet":
            shard_filepaths.append(save_shard_as_parquet(shard_idx, save_dir, examples_per_shard))
        elif filetype == "binary":
            shard_filepaths.append(save_shard_as_binary(shard_idx, save_dir, examples_per_shard))
        elif filetype == "arrow":
            shard_filepaths.append(save_shard_as_arrow(shard_idx, save_dir, examples_per_shard))
        else:
            raise ValueError(f"Unsupported filetype: {filetype}. Choose 'parquet', 'binary', or 'arrow'.")
    return shard_filepaths


def _binary_dataset_generator(files):
    # Yield pickled examples one at a time from each binary shard file
    for filepath in files:
        with open(filepath, 'rb') as f:
            while True:
                try:
                    example = pickle.load(f)
                    yield example
                except EOFError:
                    break


def load_binary_dataset(shard_filepaths):
    return IterableDataset.from_generator(
        _binary_dataset_generator, gen_kwargs={"files": shard_filepaths},
    )


def load_parquet_dataset(shard_filepaths):
    # Load the dataset as an IterableDataset
    # NOTE: `split` is the module-level variable defined in the example usage below
    return load_dataset(
        "parquet",
        data_files={split: shard_filepaths},
        streaming=True,
        split=split,
    )


def load_arrow_dataset(shard_filepaths):
    # Load the dataset as an IterableDataset
    # NOTE: `split` is the module-level variable defined in the example usage below
    shard_filepaths = [f + "/data-00000-of-00001.arrow" for f in shard_filepaths]
    return load_dataset(
        "arrow",
        data_files={split: shard_filepaths},
        streaming=True,
        split=split,
    )


def load_dataset_wrapper(filetype: str, shard_filepaths: list[str]):
    if filetype == "parquet":
        return load_parquet_dataset(shard_filepaths)
    if filetype == "binary":
        return load_binary_dataset(shard_filepaths)
    if filetype == "arrow":
        return load_arrow_dataset(shard_filepaths)
    else:
        raise ValueError(f"Unsupported filetype: {filetype}")


# Example usage:
split = "train"
split_save_dir = "/tmp/random_split"
filetype = "binary"  # or "parquet", or "arrow"
num_shards = 16

shard_filepaths = generate_split_shards(split_save_dir, filetype=filetype, num_shards=num_shards)
dataset = load_dataset_wrapper(filetype=filetype, shard_filepaths=shard_filepaths)
dataset = dataset.shuffle(buffer_size=100, seed=42)

start_time = time.time()
for count, item in enumerate(dataset):
    if count > 0 and count % 100 == 0:
        elapsed_time = time.time() - start_time
        iterations_per_second = count / elapsed_time
        print(f"Processed {count} items at an average of {iterations_per_second:.2f} iterations/second")
update: I was able to reproduce the issue you described -- but ONLY if I do
random_dataset = random_dataset.with_format("numpy")
If I do this, I see similar numbers to what you reported. If I do not use the numpy format, parquet and arrow both run at about 17 iterations per second, regardless of whether or not we shuffle. Binary (not yet tried with the numpy format) still shows the fastest speeds on average (shuffled and unshuffled), at about 850 it/sec.
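For concreteness, this is where the format change slots into the script above -- a sketch, assuming the shards were generated with filetype="arrow":

# Sketch: same benchmark as above; the only new line is with_format("numpy")
dataset = load_dataset_wrapper(filetype="arrow", shard_filepaths=shard_filepaths)
dataset = dataset.with_format("numpy")  # this is the line that triggers the slowdown for me
dataset = dataset.shuffle(buffer_size=100, seed=42)
# ...then time the iteration exactly as in the loop above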
I suspect arrow and numpy are optimized for sequential reads, and that shuffling causes issues... hmm
Describe the bug
When I load a dataset from a number of arrow files, as in:
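(a sketch of the load call; the paths and shard count are illustrative, with each shard saved as its own dataset directory on disk)

from datasets import load_dataset

split = "train"
# one saved dataset directory per shard, each containing a single arrow file
shard_filepaths = [f"/tmp/random_split/shard_{i}/data-00000-of-00001.arrow" for i in range(16)]
dataset = load_dataset(
    "arrow",
    data_files={split: shard_filepaths},
    streaming=True,
    split=split,
)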
I'm able to get fast iteration speeds when iterating over the dataset without shuffling.
When I shuffle the dataset, the iteration speed is reduced by ~1000x.
It's very possible the way I'm loading the dataset shards is not appropriate; if so, please advise!
Thanks for the help
Steps to reproduce the bug
Here's full code to reproduce the issue:
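Generate random data and save one dataset (arrow file) per shard -- a condensed sketch; shapes, shard count, and paths are illustrative:

import time
from pathlib import Path

import torch
from datasets import Dataset, load_dataset

def save_random_shard(shard_idx: int, save_dir: str, examples_per_shard: int = 512) -> str:
    # Each shard is written to disk as its own dataset (a single arrow file)
    dataset = Dataset.from_dict({
        'inputs': [torch.randn(128).tolist() for _ in range(examples_per_shard)],
        'indices': [torch.randint(0, 10000, (2, 20000)).tolist() for _ in range(examples_per_shard)],
        'values': [torch.randn(20000).tolist() for _ in range(examples_per_shard)],
    })
    shard_write_path = Path(save_dir) / f"shard_{shard_idx}"
    dataset.save_to_disk(str(shard_write_path))
    return str(shard_write_path)

split_save_dir = "/tmp/random_split"
shard_dirs = [save_random_shard(i, split_save_dir) for i in range(16)]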
Load the dataset as IterableDataset:
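Continuing the sketch, the shards are streamed back with the arrow builder:

split = "train"
shard_filepaths = [d + "/data-00000-of-00001.arrow" for d in shard_dirs]
dataset = load_dataset(
    "arrow",
    data_files={split: shard_filepaths},
    streaming=True,
    split=split,
)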
Observe the iterations/second when iterating over the dataset directly, and when applying shuffling before iterating:
Without shuffling, this gives ~1500 iterations/second
When shuffling, this gives ~3 iterations/second:
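The timing loop is the same in both cases; only the shuffle line differs (a sketch, with the buffer size kept small):

# With shuffling (comment this line out for the unshuffled measurement)
dataset = dataset.shuffle(buffer_size=100, seed=42)

start_time = time.time()
for count, item in enumerate(dataset):
    if count > 0 and count % 100 == 0:
        elapsed = time.time() - start_time
        print(f"Processed {count} items at {count / elapsed:.2f} iterations/second")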
Expected behavior
Iterations per second should be barely affected by shuffling, especially with a small buffer size
Environment info
Datasets version: 2.21.0
Python version: 3.10
OS: Ubuntu 22.04