Open andrewkho opened 2 days ago
Some comments / discussions from earlier:
I was perfectly happy with datapipes; they provided me with simple building blocks that allowed me to optimize heavy-weight processes. I don't understand the need to kill them with no replacement but a promise of a better solution which addresses a completely different problem.
Thanks @knoopx for the comment, is there something particular that you are doing with datapipes that wouldn't be possible with this proposal?
@andrewkho I mostly use iterable-style datapipes, I like the simplicity and being able to easily chain them together and defer execution. I use them for all sorts of things, not just for ML stuff. Iterable datapipes feel like python-esque observables/streams/deferables/futures/promises to me. the problems the proposal tries to solve are novel, and I'm pretty sure I could accomplish the same things but imho this new api looks like a step backwards in developer experience and i'm not sure it will solve all the existing pitfalls (like "debuggability"), after all parallelism is intrinsically a hard problem (plus python gotchas) and adding more lower-level abstractions won't make it easier for regular users. just hoping an alternative higher-level api comes later, after you figure out all the necessary building blocks.
This RFC was re-created due to a problem with the original. Summary of comments from previous issue below.
🚀 The feature
TL;DR - We want to lean into modular Multi-Threading/Multi-Processing instead of the current monolithic Multi-Processing, and steer users away from the monolithic Dataset parallelism approach towards composable DataSources, and composable IterableDatasets for pre-proc operations, with parallelism configured within each operation. This will enable multi-threaded dataloading (with NoGIL support), auto-tunable parallelism, torch.compilable and GPU enabled preproc operations, more efficient loading of mixed-modalities, and composable dataloading and pre-proc graphs.
Motivation, pitch
Working name for the project: Polylithic (non-monolithic)
Where it will live: torchdata
Multimodal DataLoading is here and torch.utils.data doesn't support it well
Multi-Modal LLMs are here. Tasks like fine-tuning, alignment, and distillation will require multi-modal dataloading for our users. LLM training often requires reading from 10s-100s of multi-modal datasets, tokenizing them, and packing them into a "token-buffer" where tokens from individual datasets are shuffled and combined into training examples for the model.
Audio, Image, and Video datasets may also require heavy-weight decoding operations to be performed before tokenization, and the difference in the data sizes between text, image, and video may be orders of magnitude. GPU decoding of images and video is an option for users as well, and libraries like Nvidia DALI will compile the entire pre-proc pipeline into GPU operations, minimizing the overhead of transfers between CPU and GPU memory.
torch.utils.data's Dataset and DataLoader abstractions are extremely popular with users; however, they are not well equipped to handle multi-modal dataloading and accelerated pre-proc because of the monolithic, black-box way in which they treat parallelism with multiprocessing; i.e. running GPU pre-proc under multiprocessing is not currently realistic. While the abstractions are extremely flexible and very easy to experiment with, users are often required to write bespoke classes to create pre-proc pipelines, handle data sharding, and combine multiple datasets. Optimizing is also a challenge because of the lack of control over parallelism.
Existing Context and definitions
torch.utils.data contains the following abstractions today:
"Monolithic" parallelism
Currently users have a single lever to control parallelism, num_workers. When num_workers > 0, the DataLoader creates background processes, each of which holds a copy of the entire Dataset object in its memory, treating it as a "monolithic" object to be parallelized.
Consider the scenario in the figure below, where a user has defined an iterable dataset which combines two text datasets and one image dataset. There is no parallelism in this example.
Now consider the common case where only image decoding and tokenization are the bottleneck, causing GPU starvation. With today's tooling, users simply increase the DataLoader's num_workers to a value > 1. The image below depicts how this is done today, by treating the entire IterableDataset as a monolith that is forked/spawned to another process.
Pain-points with Monolithic Parallelism for Multi-Modal LLM training
Multimodal data loading may require different levels of parallelism for different modalities, e.g. text tokenization may require only a single worker, while image decoding may benefit from 4+. The "monolithic" approach needlessly parallelizes operators that don't need it, increasing memory and CPU utilization for things like token buffers. Tuning parallelism for performance is difficult as there is only one knob (num_workers) available.
Enabling GPU-PreProc pipelines (see Nvidia-DALI) may improve total training throughput for many users, however combining multiprocessing (eg to parallelize blob-fetching) and GPU PreProc (eg for image decoding / cropping) in the same Dataset is not currently possible.
Tensor and Pipeline parallelism offer opportunities for more efficient and more resilient/correct dataloading, however the current torch.utils.data.DataLoader is not well equipped to take advantage of this.
As we gradually move to a NoGIL world and multi-threading becomes a viable method to parallelize, the current monolithic approach requires the entire Dataset (dataloading and preproc) and its dependencies to be thread-safe, which may cause problems with adoption.
We also suffer from the usual multi-processing pain points:
A granular parallelism approach
To fix the monolithic parallelism problem, we want to introduce abstractions and tooling that expose more granular parallelism controls to users. This implies a solution where users construct their dataloading and pre-proc pipelines by defining and stitching together datasource and pre-proc nodes into a graph, in a similar fashion to tf.data and datapipes, with data passing between the nodes. The root of the graph is the node which produces batches that are passed to the model. The leaves are data-sources which produce data by reading from local disk, remote storage, or e.g. random number generators. Intermediate nodes may transform data, perform pre-fetching, combine data from multiple nodes, perform "enrichments" by e.g. fetching images from blob stores, perform decoding, schedule GPU operations etc.
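To make the graph idea concrete, here is a minimal sketch in plain Python (standard library only). It is not the proposed API: every name in it (`source`, `parallel_map`, `interleave`, `batch`, `tokenize`, `decode_image`) is hypothetical, and threads stand in for whatever parallelism a real node would use. The point it illustrates is that each node carries its own parallelism setting, so text tokenization can run with one worker while image decoding gets four.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Iterable, Iterator, List


def source(items: Iterable[Any]) -> Iterator[Any]:
    """Leaf node: yields raw records (stands in for disk or remote reads)."""
    yield from items


def parallel_map(upstream: Iterable[Any], fn: Callable[[Any], Any],
                 num_workers: int) -> Iterator[Any]:
    """Intermediate node with its own parallelism knob.

    Results come back in input order. Executor.map consumes the whole
    upstream eagerly, so this sketch has no backpressure.
    """
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        yield from pool.map(fn, upstream)


def interleave(*upstreams: Iterable[Any]) -> Iterator[Any]:
    """Intermediate node: round-robin over several upstream nodes."""
    iterators = [iter(u) for u in upstreams]
    while iterators:
        for it in list(iterators):
            try:
                yield next(it)
            except StopIteration:
                iterators.remove(it)


def batch(upstream: Iterable[Any], batch_size: int) -> Iterator[List[Any]]:
    """Root node: groups items into lists of at most batch_size."""
    buf: List[Any] = []
    for item in upstream:
        buf.append(item)
        if len(buf) == batch_size:
            yield buf
            buf = []
    if buf:
        yield buf


def tokenize(text: str) -> List[str]:
    """Stand-in tokenizer (whitespace split); an assumption for illustration."""
    return text.split()


def decode_image(blob: bytes) -> int:
    """Stand-in for heavy image decoding; just returns the byte count."""
    return len(blob)


# Text gets one worker, image decoding gets four: parallelism is set per node.
text_node = parallel_map(source(["a b", "c d e"]), tokenize, num_workers=1)
image_node = parallel_map(source([b"xx", b"yyy"]), decode_image, num_workers=4)
batches = list(batch(interleave(text_node, image_node), batch_size=2))
print(batches)
```

Only the decode stage is widened here; the sources, the interleaving, and the batcher stay single-threaded, which is the contrast with forking the whole pipeline per worker.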
Requirements and Constraints
To adequately support Multi Modal LLM training for PyTorch users, address the above pain points, and give us the best chance for wide-adoption, we want our solution to meet the following requirements and constraints:
How will we achieve this/what will we build? Plan of Record
We will introduce a new base class with the working name PolylithicNode, i.e. class PolylithicNode(torch.utils.data.IterableDataset). Nodes in the graph will be instances of subclasses of PolylithicNode. Nodes will define a .iterator() method instead of overriding __iter__(). This is inspired by nn.Module's implementation, where users define .forward() instead of __call__. This will allow PolylithicNode to instantiate user-defined iterators and wrap them, insert queues for pipeline-parallelism, and measure latency. For backwards compatibility, we'll provide a wrapper which takes an existing IterableDataset. Users can compose their datasets by composing PolylithicNodes (i.e. through iter() and next()).
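A rough sketch of the .iterator() pattern, under stated assumptions: the RFC only gives the working name and the idea, so everything below beyond the name PolylithicNode and the .iterator() method is hypothetical (the `seconds_waiting` and `items_yielded` counters, and the `IterableWrapper` and `Mapper` subclasses). To stay runnable without PyTorch, the base class here is a plain Python class rather than a subclass of torch.utils.data.IterableDataset, and only latency measurement is shown, not queue insertion.

```python
import time
from typing import Any, Callable, Iterable, Iterator


class PolylithicNode:
    """Hypothetical base class: users implement iterator(), not __iter__()."""

    def __init__(self) -> None:
        self.seconds_waiting = 0.0  # total time spent inside next() calls
        self.items_yielded = 0

    def iterator(self) -> Iterator[Any]:
        raise NotImplementedError

    def __iter__(self) -> Iterator[Any]:
        # The base class owns __iter__, so it can wrap the user's iterator.
        # Here it only times each next() call (which includes upstream time).
        it = self.iterator()
        while True:
            start = time.perf_counter()
            try:
                item = next(it)
            except StopIteration:
                return
            self.seconds_waiting += time.perf_counter() - start
            self.items_yielded += 1
            yield item


class IterableWrapper(PolylithicNode):
    """Hypothetical backwards-compatibility wrapper around any iterable."""

    def __init__(self, iterable: Iterable[Any]) -> None:
        super().__init__()
        self.iterable = iterable

    def iterator(self) -> Iterator[Any]:
        return iter(self.iterable)


class Mapper(PolylithicNode):
    """Hypothetical transform node: applies fn to each upstream item."""

    def __init__(self, upstream: Iterable[Any], fn: Callable[[Any], Any]) -> None:
        super().__init__()
        self.upstream = upstream
        self.fn = fn

    def iterator(self) -> Iterator[Any]:
        for item in self.upstream:
            yield self.fn(item)


# Nodes compose through plain iteration, as the text above describes.
squares = Mapper(IterableWrapper(range(5)), lambda x: x * x)
result = list(squares)
print(result, squares.items_yielded)
```

Because every node is iterated through the base class's __iter__, the same hook could in principle hand the user's iterator to a background thread and place a queue in between, without any change to the subclasses.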
Example of composing iterable datasets to create a multimodal dataloader. [Note that we are open to ideas on syntactical sugar]
More complex diagram
What about DataPipes and DL v2?
DataPipes and DL v2 were designed to address issues like composability, and there is a lot of value in what they've built; however, their parallelism and sharding structure is still based on a monolithic approach (e.g. plug a datapipe into DL v1, or DL v2 + multiprocess reading service). They required migration/rewrite of datasets, often with no improvement in performance; identifying dataloading-preproc bottlenecks was a challenge; and shuffling/sharding pain points weren't adequately addressed.
The proposed approach improves upon DataPipes + DLv2 in the following ways:
We want to maintain the composable aspects of datapipes, the eager-execution, and continue our partnerships with storage and cloud providers (AWS, Azure, GCP) where they provide high-performance clients, share customer pain points, and provide recommended solutions and examples to their users.
Alternatives
No response
Additional context
No response