huggingface / lerobot

🤗 LeRobot: Making AI for Robotics more accessible with end-to-end learning
Apache License 2.0

Improve LeRobotDataset #440

Open alexander-soare opened 1 month ago

alexander-soare commented 1 month ago

We need to fulfil the following requirements:

| Requirement | Current status (HF Datasets) | Proposal (NumPy memmaps) |
|---|---|---|
| Fast data loading | ❌ Slow due to conversion from parquet format to numpy arrays | ✅ Provides a much more direct path |
| Fast data adding/editing (for an online training buffer) | ❌ Not mutable in place | ✅ Can be mutated in place with regular slice assignment |
| Agnostic to training framework (use numpy instead of torch tensors) | | |
| No duplication between data folder and cache folder | ? | ? |
| Have more transparency on the contents of the dataset without having to download it | ? | ? |
| Be able to push to the Hugging Face hub | ? | |
| Be able to download a subset of the dataset from the hub (maybe per episode, or a slice of frames) | ? | ? |
| Be able to stream the data either to a visualization tool or for training | ? | ? |
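The proposal column says a memmap can be mutated in place with regular slice assignment, which is what an online training buffer needs. Below is a minimal sketch of that idea, assuming one fixed-capacity memmap file per feature and a hypothetical `episode_bounds` list recording each episode's `[start, stop)` frame range; neither the layout nor the helper `add_episode` is taken from LeRobot's code.

```python
import tempfile
from pathlib import Path

import numpy as np

# Illustrative sizes and file name (assumptions, not LeRobot's layout).
capacity, feature_dim = 1_000, 6
path = Path(tempfile.mkdtemp()) / "action.memmap"

# Pre-allocate a fixed-capacity buffer on disk ("w+" creates or overwrites the file).
buffer = np.memmap(path, dtype=np.float32, mode="w+", shape=(capacity, feature_dim))

# Hypothetical bookkeeping: [start, stop) frame range of each episode in the buffer.
episode_bounds: list[tuple[int, int]] = []
write_idx = 0


def add_episode(frames: np.ndarray) -> None:
    """Append an episode in place with plain slice assignment (existing data is untouched)."""
    global write_idx
    stop = write_idx + len(frames)
    if stop > capacity:
        raise ValueError("buffer is full")
    buffer[write_idx:stop] = frames
    episode_bounds.append((write_idx, stop))
    write_idx = stop


add_episode(np.full((50, feature_dim), 1.0, dtype=np.float32))
add_episode(np.full((30, feature_dim), 2.0, dtype=np.float32))

# Edit existing frames in place, e.g. overwrite the first 5 frames of episode 1.
start, _ = episode_bounds[1]
buffer[start : start + 5] = 0.0
buffer.flush()

# Read back through a read-only memmap: a temporal chunk and a whole episode are both
# plain slices, and the results are NumPy arrays, independent of the training framework.
reader = np.memmap(path, dtype=np.float32, mode="r", shape=(capacity, feature_dim))
chunk = np.asarray(reader[45:55])
start, stop = episode_bounds[1]
episode_1 = np.asarray(reader[start:stop])
```

Per-episode download or access then reduces to looking up a frame range and slicing, which is one reason a flat array plus an index table is attractive here.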
alexander-soare commented 1 month ago

This snippet shows that iterating over a numpy memmap is 32x faster than iterating over an equivalent HF Datasets dataset. In this example we access the data at random indices (as we would in a training loop) and take slices, as LeRobot data loading often does when it needs temporal chunks. Even with sequential access and no slicing, there would still be a large difference in speed.

```python
import os
import time
from contextlib import contextmanager
from pathlib import Path

import numpy as np
from datasets import Dataset, Features, Sequence, Value
from tqdm import tqdm


def _make_memmap_safe(**kwargs) -> np.memmap:
    """Make a numpy memmap with checks on available disk space first.

    Expected kwargs are: "filename", "dtype" (must be np.dtype), "mode" and "shape"

    For information on dtypes:
    https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
    """
    if kwargs["mode"].startswith("w"):
        required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"])  # bytes
        stats = os.statvfs(Path(kwargs["filename"]).parent)
        available_space = stats.f_bavail * stats.f_frsize  # bytes
        if required_space >= available_space * 0.8:
            raise RuntimeError(
                f"You're about to take up {required_space} of {available_space} bytes available. This "
                "exception has been raised to protect your storage device."
            )
    return np.memmap(**kwargs)


@contextmanager
def print_time(desc: str = ""):
    """Print the wall-clock time spent inside the `with` block."""
    start = time.perf_counter()
    yield
    print(f"{desc}: {time.perf_counter() - start:.6f}")


dataset_size = 100000
feature_dim = 6
dtype = np.dtype("float32")
shape = (dataset_size, feature_dim)
feats = np.random.normal(size=shape).astype(dtype)

# Create the memmap on the first run; reopen it on later runs.
np_memmap_path = Path("/tmp/np_memmap")
if np_memmap_path.exists():
    np_memmap = _make_memmap_safe(filename=np_memmap_path, mode="readwrite", dtype=dtype, shape=shape)
else:
    np_memmap = _make_memmap_safe(filename=np_memmap_path, mode="write", dtype=dtype, shape=shape)
    np_memmap[:] = feats
    np_memmap.flush()

# Same data as an HF Dataset with a fixed-length sequence feature, saved to disk once.
hf_dataset_path = Path("/tmp/hf_dataset")
if hf_dataset_path.exists():
    hf_dataset = Dataset.load_from_disk(hf_dataset_path)
else:
    hf_dataset = Dataset.from_dict(
        {"data": feats},
        features=Features({"data": Sequence(Value(dtype=str(feats.dtype)), length=feature_dim)}),
    )
    hf_dataset.save_to_disk(hf_dataset_path)

slice_size = 10
np.random.seed(0)
# Random start indices, as a shuffled training loop would produce.
indices = [int(i) for i in np.random.permutation(len(hf_dataset) - slice_size + 1)]

with print_time("Iterate hf_dataset"):  # 3.2 seconds
    for i in tqdm(indices):
        _ = hf_dataset[i : i + slice_size] if slice_size > 1 else hf_dataset[i]

with print_time("Iterate np_memmap"):  # 0.1 seconds
    # Bind the loop variable to `i` so the memmap is read at the same random indices.
    for i in tqdm(indices):
        _ = np_memmap[i : i + slice_size] if slice_size > 1 else np_memmap[i]
```

cc @lhoestq

lhoestq commented 1 month ago

Thanks for the table summary!

I feel like you are mixing storage format (e.g. parquet, mp4) and usage format (e.g. arrow, mp4/images/numpy).

Storage formats are for the Hub

Usage formats are for your local project / lib

What about defining a separate storage and usage format for your case?
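One way to read the storage/usage split: keep a compact file for distribution and decode it once into a fast local format. The sketch below is only an illustration under stated assumptions: a compressed NumPy `.npz` archive stands in for a Hub storage format such as parquet, and one uncompressed `.npy` file per key, reopened with `mmap_mode="r"`, stands in for the local usage format. The helper `to_usage_format` and all file names are hypothetical, not a LeRobot or Datasets API.

```python
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical locations: the "storage" file would be what gets pushed to / pulled from the Hub.
root = Path(tempfile.mkdtemp())
storage_path = root / "episode_data.npz"
usage_dir = root / "cache"


def to_usage_format(storage_path: Path, usage_dir: Path) -> dict[str, np.memmap]:
    """Decode the storage file once into one .npy file per key, then reopen them as read-only memmaps."""
    usage_dir.mkdir(parents=True, exist_ok=True)
    with np.load(storage_path) as archive:  # decompression happens here, only once
        for key in archive.files:
            arr = archive[key]
            out = np.lib.format.open_memmap(
                str(usage_dir / f"{key}.npy"), mode="w+", dtype=arr.dtype, shape=arr.shape
            )
            out[:] = arr
            out.flush()
    return {p.stem: np.load(p, mmap_mode="r") for p in sorted(usage_dir.glob("*.npy"))}


# Illustrative data: 10 episodes of 100 frames with a 6-dim action.
rng = np.random.default_rng(0)
np.savez_compressed(
    storage_path,
    action=rng.normal(size=(1000, 6)).astype(np.float32),
    episode_index=np.repeat(np.arange(10), 100),
)

data = to_usage_format(storage_path, usage_dir)
chunk = data["action"][100:110]  # fast local slicing, no per-access decoding
```

The design choice this illustrates is that the decoding cost is paid once at conversion time, while training only ever touches the memory-mapped usage files.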


As for the benchmark, I'll take a look.