Fast, accurate streaming of training data from cloud storage
We built StreamingDataset to make training on large datasets from cloud storage as fast, cheap, and scalable as possible.
It’s specially designed for multi-node, distributed training of large models, maximizing correctness guarantees, performance, and ease of use. Now you can train efficiently anywhere, independent of where your training data lives. Just stream in the data you need, when you need it. To learn more about why we built StreamingDataset, read our announcement blog.
StreamingDataset is compatible with any data type, including images, text, video, and multimodal data.
With support for major cloud storage providers (AWS, OCI, GCS, Azure, Databricks, and any S3-compatible object store such as Cloudflare R2, CoreWeave, Backblaze B2, etc.), and designed as a drop-in replacement for your PyTorch IterableDataset class, StreamingDataset integrates seamlessly into your existing training workflows.
Streaming can be installed with `pip`:

```
pip install mosaicml-streaming
```
Convert your raw dataset into one of our supported streaming formats:
```python
import numpy as np
from PIL import Image
from streaming import MDSWriter

# Local or remote directory in which to store the compressed output files
data_dir = 'path-to-dataset'

# A dictionary mapping input fields to their data types
columns = {
    'image': 'jpeg',
    'class': 'int'
}

# Shard compression, if any
compression = 'zstd'

# Save the samples as shards using MDSWriter
with MDSWriter(out=data_dir, columns=columns, compression=compression) as out:
    for i in range(10000):
        sample = {
            'image': Image.fromarray(np.random.randint(0, 256, (32, 32, 3), np.uint8)),
            'class': np.random.randint(10),
        }
        out.write(sample)
```
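`MDSWriter` writes the samples into a series of shard files rather than one large file. As a rough illustration of how samples might be grouped, assuming a simple rule where a new shard starts once the next sample would push the current shard past a byte limit (MDSWriter's `size_limit` argument plays this role; the real byte accounting, including headers and compression, differs), here is a hypothetical `pack_into_shards` helper:

```python
from typing import List


def pack_into_shards(sample_sizes: List[int], size_limit: int) -> List[List[int]]:
    """Hypothetical helper: group sample indices into shards of at most size_limit bytes.

    A new shard is started when adding the next sample would exceed size_limit.
    A single sample larger than size_limit still gets a shard of its own.
    """
    shards: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    for index, size in enumerate(sample_sizes):
        if current and current_bytes + size > size_limit:
            shards.append(current)  # close the full shard
            current, current_bytes = [], 0
        current.append(index)
        current_bytes += size
    if current:
        shards.append(current)  # flush the last, partially filled shard
    return shards


# Five samples of 40, 30, 50, 10, and 70 bytes with a 100-byte limit
print(pack_into_shards([40, 30, 50, 10, 70], size_limit=100))  # [[0, 1], [2, 3], [4]]
```

Keeping shards bounded in size means a reader only ever has to download a modest chunk to reach any given sample.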
Upload your streaming dataset to the cloud storage of your choice (AWS, OCI, or GCP). Below is one example of uploading a directory to an S3 bucket using the AWS CLI.

```
$ aws s3 cp --recursive path-to-dataset s3://my-bucket/path-to-dataset
```
Then create a `StreamingDataset` that streams from the remote location, and wrap it in a PyTorch `DataLoader`:

```python
from torch.utils.data import DataLoader
from streaming import StreamingDataset

# Remote path where full dataset is persistently stored
remote = 's3://my-bucket/path-to-dataset'

# Local working dir where dataset is cached during operation
local = '/tmp/path-to-dataset'

# Create streaming dataset
dataset = StreamingDataset(local=local, remote=remote, shuffle=True)

# Let's see what is in sample #1337...
sample = dataset[1337]
img = sample['image']
cls = sample['class']

# Create PyTorch DataLoader
dataloader = DataLoader(dataset)
```
Getting started guides, examples, API references, and other useful information can be found in our docs.
We have end-to-end tutorials for training a model on:
We also have starter code for the following popular datasets, which can be found in the `streaming` directory:
Dataset | Task | Read | Write |
---|---|---|---|
LAION-400M | Text and image | Read | Write |
WebVid | Text and video | Read | Write |
C4 | Text | Read | Write |
EnWiki | Text | Read | Write |
Pile | Text | Read | Write |
ADE20K | Image segmentation | Read | Write |
CIFAR10 | Image classification | Read | Write |
COCO | Image classification | Read | Write |
ImageNet | Image classification | Read | Write |
To start training on these datasets, first convert the raw data into MDS format using the corresponding script from the `convert` directory. For example:

```
$ python -m streaming.multimodal.convert.webvid --in <CSV file> --out <MDS output directory>
```
Then import the matching dataset class to start training:

```python
from streaming.multimodal import StreamingInsideWebVid

dataset = StreamingInsideWebVid(local=local, remote=remote, shuffle=True)
```
Easily experiment with dataset mixtures using `Stream`. Dataset sampling can be controlled in relative terms (`proportion`) or in absolute terms (`repeat` or a number of samples). During streaming, the different datasets are streamed, shuffled, and mixed seamlessly just-in-time.
```python
from streaming import Stream, StreamingDataset

# Mix C4, GitHub code, and internal datasets
streams = [
    Stream(remote='s3://datasets/c4', proportion=0.4),
    Stream(remote='s3://datasets/github', proportion=0.1),
    Stream(remote='gs://datasets/my_internal', proportion=0.5),
]

dataset = StreamingDataset(
    streams=streams,
    samples_per_epoch=100_000_000,  # total samples drawn per epoch across all streams
)
```
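The `proportion` values act as relative weights over the samples drawn each epoch. A minimal sketch of how such weights could turn into per-stream sample counts, assuming a largest-remainder rounding rule (the library's own rounding and balancing may differ); `samples_per_stream` is a hypothetical helper, not part of the streaming API:

```python
from typing import List


def samples_per_stream(proportions: List[float], epoch_size: int) -> List[int]:
    """Hypothetical helper: split epoch_size samples across streams by relative weight.

    Weights are normalized, so they need not sum to exactly 1. Samples left over
    after rounding down go to the streams with the largest fractional remainders,
    so the counts always add up to epoch_size.
    """
    total = sum(proportions)
    exact = [p / total * epoch_size for p in proportions]
    counts = [int(x) for x in exact]  # round down first
    leftover = epoch_size - sum(counts)
    # Largest-remainder method: hand out what is left, biggest fractions first
    by_remainder = sorted(range(len(exact)), key=lambda i: exact[i] - counts[i], reverse=True)
    for i in by_remainder[:leftover]:
        counts[i] += 1
    return counts


print(samples_per_stream([0.4, 0.1, 0.5], 100_000_000))  # [40000000, 10000000, 50000000]
print(samples_per_stream([1, 1, 1], 10))  # [4, 3, 3]
```

For the mixture above, 100,000,000 samples per epoch split 0.4/0.1/0.5 works out to 40, 10, and 50 million samples respectively.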
A unique feature of our solution: samples are in the same order regardless of the number of GPUs, nodes, or CPU workers. This makes it easier to:
See the figure below: training a model on 1, 8, 16, 32, or 64 GPUs yields the exact same loss curve (up to the limitations of floating-point math!).
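One simple way to get this property is sketched below under stated assumptions: a single flat shuffle over all sample IDs using Python's `random` module (standing in for the library's own shuffling algorithm), a sample count that divides evenly by the world size, and a strided split across ranks. Because the global order depends only on the seed, reading the per-rank slices back step by step reproduces the same sequence for any GPU count. All function names here are hypothetical.

```python
import random
from typing import List


def global_order(num_samples: int, seed: int) -> List[int]:
    """Hypothetical helper: shuffle sample IDs with a fixed seed, independent of cluster size."""
    order = list(range(num_samples))
    random.Random(seed).shuffle(order)
    return order


def rank_slice(order: List[int], rank: int, world_size: int) -> List[int]:
    """Hypothetical helper: the samples one rank sees, every world_size-th entry from its rank."""
    return order[rank::world_size]


def consumed_sequence(num_samples: int, seed: int, world_size: int) -> List[int]:
    """Hypothetical helper: order in which samples are consumed, step by step across all ranks."""
    order = global_order(num_samples, seed)
    slices = [rank_slice(order, r, world_size) for r in range(world_size)]
    steps = num_samples // world_size  # assumes num_samples divides evenly
    return [slices[r][step] for step in range(steps) for r in range(world_size)]


# The consumed sequence is identical on 1, 8, 16, 32, or 64 "GPUs"
print(all(consumed_sequence(64, 42, w) == global_order(64, 42) for w in (1, 8, 16, 32, 64)))  # True
```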
It can be expensive, and annoying, to wait for your job to resume while your dataloader spins after a hardware failure or loss spike. Thanks to our deterministic sample ordering, StreamingDataset lets you resume training in seconds, not hours, in the middle of a long training run.
Minimizing resumption latency can save thousands of dollars in egress fees and idle GPU compute time compared to existing solutions.
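Deterministic ordering is what makes fast resumption possible: if the order for an epoch depends only on a seed and the epoch number, a restarted job can jump straight to where it left off from a simple sample counter. A minimal sketch of that idea, assuming a per-epoch shuffle seeded from `(seed, epoch)` with Python's `random` module; `epoch_order` and `remaining_samples` are hypothetical helpers, not the StreamingDataset checkpointing API.

```python
import random
from typing import List


def epoch_order(num_samples: int, seed: int, epoch: int) -> List[int]:
    """Hypothetical helper: the sample order for one epoch, derived only from (seed, epoch)."""
    order = list(range(num_samples))
    random.Random(f"{seed}:{epoch}").shuffle(order)  # string seeds are deterministic across runs
    return order


def remaining_samples(num_samples: int, seed: int, epoch: int, samples_seen: int) -> List[int]:
    """Hypothetical helper: samples still to be served after a restart.

    The order is recomputed from the seed, so skipping ahead only slices a list of
    integers; none of the already-consumed samples need to be downloaded or decoded.
    """
    return epoch_order(num_samples, seed, epoch)[samples_seen:]


full = epoch_order(1000, seed=7, epoch=3)
resumed = remaining_samples(1000, seed=7, epoch=3, samples_seen=250)
print(full[:250] + resumed == full)  # True
```

The only state a checkpoint needs in this sketch is the seed, the epoch, and how many samples were consumed.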
Our MDS format cuts extraneous work to the bone, resulting in ultra-low sample latency and higher throughput compared to alternatives for workloads bottlenecked by the dataloader.
Tool | Throughput |
---|---|
StreamingDataset | ~19000 img/sec |
ImageFolder | ~18000 img/sec |
WebDataset | ~16000 img/sec |
Results shown are from ImageNet + ResNet-50 training, collected over 5 repetitions, with the data already cached locally after the first epoch.
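Part of why per-sample work stays small is that a shard can be indexed directly. The toy format below (a sample count, then a table of byte offsets, then the concatenated sample bytes) is a simplified illustration in the spirit of MDS, not the actual MDS on-disk specification; `write_shard` and `read_sample` are hypothetical. Reading sample `i` takes two integer lookups and one slice, with no scanning or parsing of earlier samples.

```python
import struct
from typing import List


def write_shard(samples: List[bytes]) -> bytes:
    """Hypothetical toy format: uint32 count, (count + 1) uint32 offsets, then sample bytes."""
    n = len(samples)
    header_size = 4 + 4 * (n + 1)
    offsets = [header_size]  # offsets are absolute byte positions within the shard
    for s in samples:
        offsets.append(offsets[-1] + len(s))
    header = struct.pack(f"<I{n + 1}I", n, *offsets)  # little-endian unsigned 32-bit ints
    return header + b"".join(samples)


def read_sample(shard: bytes, i: int) -> bytes:
    """Fetch sample i from the toy shard using only the offset table."""
    (n,) = struct.unpack_from("<I", shard, 0)
    if not 0 <= i < n:
        raise IndexError(i)
    begin, end = struct.unpack_from("<2I", shard, 4 + 4 * i)  # offsets[i], offsets[i + 1]
    return shard[begin:end]


shard = write_shard([b"cat", b"", b"zebra"])
print(read_sample(shard, 2))  # b'zebra'
```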
Model convergence from using StreamingDataset is just as good as using local disk, thanks to our shuffling algorithm.
Below are results from ImageNet + ResNet-50 training, collected over 5 repetitions.
Tool | Top-1 Accuracy |
---|---|
StreamingDataset | 76.51% +/- 0.09 |
ImageFolder | 76.57% +/- 0.10 |
WebDataset | 76.23% +/- 0.17 |
StreamingDataset shuffles across all samples assigned to a node, whereas alternative solutions only shuffle samples in a smaller pool (within a single process). Shuffling across a wider pool spreads out adjacent samples more. In addition, our shuffling algorithm minimizes dropped samples. We have found both of these shuffling features advantageous for model convergence.
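The effect of pool size can be seen in a toy model that ignores shards, nodes, and workers entirely: shuffle a sequence only within fixed-size pools, then measure how far apart originally adjacent samples end up. `pool_shuffle` and `mean_neighbor_gap` are hypothetical helpers, and the two pool sizes simply stand in for "a per-process buffer" versus "everything assigned to a node"; they are not library parameters.

```python
import random
from statistics import mean
from typing import List


def pool_shuffle(samples: List[int], pool_size: int, seed: int) -> List[int]:
    """Hypothetical helper: shuffle only within consecutive pools of pool_size samples."""
    rng = random.Random(seed)
    out: List[int] = []
    for start in range(0, len(samples), pool_size):
        pool = samples[start:start + pool_size]  # a copy, so the input is not mutated
        rng.shuffle(pool)
        out.extend(pool)
    return out


def mean_neighbor_gap(order: List[int]) -> float:
    """Average distance in `order` between samples s and s + 1 (assumes IDs 0..n-1)."""
    position = {sample: pos for pos, sample in enumerate(order)}
    return mean(abs(position[s + 1] - position[s]) for s in range(len(order) - 1))


samples = list(range(10_000))
small = pool_shuffle(samples, pool_size=100, seed=0)     # narrow, per-process-style pool
wide = pool_shuffle(samples, pool_size=10_000, seed=0)   # one pool over all samples
print(mean_neighbor_gap(small), mean_neighbor_gap(wide))
```

With 10,000 samples, pools of 100 leave originally adjacent samples only a few dozen positions apart on average, while a single pool over everything spreads them thousands of positions apart.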
Access the data you need when you need it.
Even if a sample isn’t downloaded yet, you can access `dataset[i]` to get sample `i`. The download kicks off immediately and the result is returned when it’s done, similar to a map-style PyTorch dataset with samples numbered sequentially and accessible in any order.
```python
dataset = StreamingDataset(...)
sample = dataset[19543]
```
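Random access like this needs a way to find which shard holds sample `i`. A minimal sketch, assuming each shard's sample count is known up front (for example from an index file) and using a hypothetical `fetch_shard` callback in place of a real download, keeps cumulative counts and binary-searches them. `LazyShardIndex` is a hypothetical class, not the library's implementation.

```python
import bisect
from itertools import accumulate
from typing import Callable, Dict, List, Tuple


class LazyShardIndex:
    """Hypothetical: map a global sample index to (shard, offset) and fetch shards on first use."""

    def __init__(self, samples_per_shard: List[int], fetch_shard: Callable[[int], List[str]]):
        self.ends = list(accumulate(samples_per_shard))  # ends[k]: one past shard k's last index
        self.fetch_shard = fetch_shard  # hypothetical downloader: shard id -> list of samples
        self.cache: Dict[int, List[str]] = {}

    def locate(self, index: int) -> Tuple[int, int]:
        total = self.ends[-1] if self.ends else 0
        if not 0 <= index < total:
            raise IndexError(index)
        shard = bisect.bisect_right(self.ends, index)  # first shard whose end exceeds index
        start = self.ends[shard - 1] if shard > 0 else 0
        return shard, index - start

    def __getitem__(self, index: int) -> str:
        shard, offset = self.locate(index)
        if shard not in self.cache:  # not local yet: download now, then serve
            self.cache[shard] = self.fetch_shard(shard)
        return self.cache[shard][offset]


# Usage with a fake downloader that records which shards it fetched
sizes = [3, 2, 4]
fetched: List[int] = []


def fake_fetch(shard: int) -> List[str]:
    fetched.append(shard)
    return [f"shard{shard}-sample{j}" for j in range(sizes[shard])]


index = LazyShardIndex(sizes, fake_fetch)
print(index[4])  # shard1-sample1
print(fetched)   # [1]
```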
StreamingDataset will happily iterate over any number of samples. You do not have to permanently drop samples so that the dataset divides evenly over a baked-in number of devices. Instead, each epoch a different selection of samples is repeated (none are dropped) so that each device processes the same count.
```python
dataset = StreamingDataset(...)
dl = DataLoader(dataset, num_workers=...)
```
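A minimal sketch of the repeat-instead-of-drop idea, assuming a flat shuffle with Python's `random` module and a simple strided split across devices (not the library's actual partitioning scheme): shuffle all sample IDs for the epoch, then top up the list by repeating a few of them until the total divides evenly. Because the shuffle is re-seeded per epoch, a different subset is repeated each time. `balanced_epoch` is a hypothetical helper.

```python
import random
from typing import List


def balanced_epoch(num_samples: int, num_devices: int, seed: int, epoch: int) -> List[List[int]]:
    """Hypothetical helper: give every device the same number of samples without dropping any."""
    per_device = -(-num_samples // num_devices)  # ceiling division
    padding = per_device * num_devices - num_samples
    order = list(range(num_samples))
    random.Random(f"{seed}:{epoch}").shuffle(order)  # new order every epoch
    # Repeat the first `padding` samples of this epoch's shuffled order (wrapping if
    # needed), so a different selection of samples is repeated each epoch.
    order += [order[i % num_samples] for i in range(padding)]
    # Strided split: device d gets positions d, d + num_devices, d + 2 * num_devices, ...
    return [order[d::num_devices] for d in range(num_devices)]


parts = balanced_epoch(10, num_devices=4, seed=0, epoch=0)
print([len(p) for p in parts])  # [3, 3, 3, 3]
```

Here 10 samples over 4 devices need 12 slots, so two samples are seen twice this epoch and none are dropped.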
Dynamically delete least recently used shards in order to keep disk usage under a specified limit. This is enabled by setting the StreamingDataset argument `cache_limit`. See the shuffling guide for more details.
```python
dataset = StreamingDataset(
    cache_limit='100gb',
    ...
)
```
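Conceptually, least-recently-used eviction under a byte limit can be sketched as below, assuming each shard's size is known once it is downloaded. `LRUShardCache` is a hypothetical class for illustration: it only does the bookkeeping (it does not touch files), and it takes a plain byte count rather than a string such as `'100gb'`.

```python
from collections import OrderedDict
from typing import List


class LRUShardCache:
    """Hypothetical: track local shard sizes and evict least recently used shards past a limit."""

    def __init__(self, limit_bytes: int):
        self.limit_bytes = limit_bytes
        self.shards: "OrderedDict[str, int]" = OrderedDict()  # shard name -> size in bytes
        self.used_bytes = 0

    def touch(self, name: str, size: int) -> List[str]:
        """Record an access (adding the shard if new) and return the names of evicted shards."""
        if name in self.shards:
            self.shards.move_to_end(name)  # mark as most recently used
            return []
        self.shards[name] = size
        self.used_bytes += size
        evicted: List[str] = []
        # Evict from the least recently used end, but never the shard just added.
        while self.used_bytes > self.limit_bytes and len(self.shards) > 1:
            old_name, old_size = self.shards.popitem(last=False)
            self.used_bytes -= old_size
            evicted.append(old_name)
        return evicted


cache = LRUShardCache(limit_bytes=100)
cache.touch("a", 40)
cache.touch("b", 40)
cache.touch("a", 40)         # "a" becomes most recently used
print(cache.touch("c", 40))  # ['b']
```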
Here are some projects and experiments that used StreamingDataset. Got something to add? Email mcomm@databricks.com or join our Community Slack.
We welcome any contributions, pull requests, or issues.
To start contributing, see our Contributing page.
P.S.: We're hiring!
If you like this project, give us a star ⭐ and check out our other projects:
```bibtex
@misc{mosaicml2022streaming,
    author = {The Mosaic ML Team},
    title = {streaming},
    year = {2022},
    howpublished = {\url{https://github.com/mosaicml/streaming/}},
}
```