[RFC] Polylithic: Enabling multi-threaded DataLoading through non-monolithic parallelism #1334

Open andrewkho opened 2 days ago

andrewkho commented 2 days ago

This RFC was re-created due to a problem with the original. A summary of comments from the previous issue is below.

🚀 The feature

TL;DR - We want to lean into modular multi-threading/multi-processing instead of the current monolithic multi-processing, and steer users away from the monolithic Dataset parallelism approach towards composable DataSources and composable IterableDatasets for pre-proc operations, with parallelism configured within each operation. This will enable multi-threaded dataloading (with NoGIL support), auto-tunable parallelism, torch.compile-able and GPU-enabled pre-proc operations, more efficient loading of mixed modalities, and composable dataloading and pre-proc graphs.

Motivation, pitch

Working name for the project: Polylithic (non-monolithic)
Where it will live: torchdata

Multimodal DataLoading is here and torch.utils.data doesn't support it well

Multi-Modal LLMs are here. Tasks like fine-tuning, alignment, and distillation will require multi-modal dataloading for our users. LLM training often requires reading from 10s-100s of multi-modal datasets, tokenizing them, and packing them into a "token buffer" where tokens from individual datasets are shuffled and combined into training examples for the model.

Audio, Image, and Video datasets may also require heavy-weight decoding operations to be performed before tokenization, and the difference in the data sizes between text, image, and video may be orders of magnitude. GPU decoding of images and video is an option for users as well, and libraries like Nvidia DALI will compile the entire pre-proc pipeline into GPU operations, minimizing the overhead of transfers between CPU and GPU memory.

torch.utils.data's Dataset and DataLoader abstractions are extremely popular with users; however, they are not well equipped to handle multimodal dataloading and accelerated pre-proc, because of the monolithic, black-box way in which they treat parallelism with multiprocessing; e.g. running GPU pre-proc under multiprocessing is not currently realistic. While the abstractions are extremely flexible and very easy to experiment with, users are often required to write bespoke classes to create pre-proc pipelines, handle data sharding, and combine multiple datasets. Optimizing is also a challenge because of the lack of control over parallelism.

Existing Context and definitions

torch.utils.data provides the Dataset, Sampler, and DataLoader abstractions today:

# Example usage of DataLoader, Sampler, and Dataset, with multiprocess parallelism
dl = torch.utils.data.DataLoader(my_dataset, sampler=maybe_my_sampler, batch_size=batch_size, num_workers=multiprocessing_num_workers)
for batch in dl:
  ...  # model forward/backward

"Monolithic" parallelism

Currently users have a single lever to control parallelism: num_workers. When num_workers > 0, the DataLoader creates background processes, each of which holds a copy of the entire Dataset object in process memory, treating it as a "monolithic" object to be parallelized.

Consider the scenario in the figure below, where a user has defined an iterable dataset which combines two text datasets and one image dataset. There is no parallelism in this example.


Now consider the common case where only image decoding and tokenization are the bottleneck, causing GPU starvation. With today's tooling, users simply increase the DataLoader's num_workers (e.g. num_workers > 1). The image below depicts how this is done today, by treating the entire IterableDataset as a monolith that is forked/spawned into background processes.


Pain-points with Monolithic Parallelism for Multi-Modal LLM training

Multimodal data loading may require different levels of parallelism for different modalities, e.g. text tokenization may require only a single worker, while image decoding may benefit from 4+. The "monolithic" approach needlessly parallelizes operators that don't need it, increasing memory and CPU utilization for things like token buffers. Tuning parallelism for performance is difficult, as there is only one knob (num_workers) available.
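A back-of-the-envelope count makes the over-parallelization concrete. Assuming, purely for illustration, that text tokenization needs 1 worker, image decoding needs 4, and the token buffer needs 1, a single num_workers knob must be set to 4 to satisfy image decoding. Because each worker holds a full copy of the pipeline, that yields 12 stage instances versus 6 with per-operation parallelism. The stage names and numbers below are hypothetical, not measurements.

```python
# Hypothetical per-stage worker needs; illustrative assumptions, not measurements.
STAGE_NEEDS = {"text_tokenize": 1, "image_decode": 4, "token_buffer": 1}


def monolithic_instances(needs: dict) -> int:
    # A single num_workers knob must satisfy the most demanding stage,
    # and every worker holds a full copy of every stage.
    num_workers = max(needs.values())
    return num_workers * len(needs)


def granular_instances(needs: dict) -> int:
    # Per-operation parallelism gives each stage only the workers it needs.
    return sum(needs.values())


print(monolithic_instances(STAGE_NEEDS))  # 12
print(granular_instances(STAGE_NEEDS))    # 6
```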

Enabling GPU pre-proc pipelines (see Nvidia DALI) may improve total training throughput for many users; however, combining multiprocessing (e.g. to parallelize blob fetching) and GPU pre-proc (e.g. for image decoding/cropping) in the same Dataset is not currently possible.

Tensor and Pipeline parallelism offer opportunities for more efficient and more resilient/correct dataloading; however, the current torch.utils.data is not well equipped to take advantage of them.

As we gradually move to a NoGIL world and multi-threading becomes a viable method to parallelize, the current monolithic approach requires the entire Dataset (dataloading and preproc) and its dependencies to be thread-safe, which may cause problems with adoption.
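To illustrate why per-operation threading eases the thread-safety burden: calling next() on the same Python generator from two threads at once raises ValueError ("generator already executing"), so a shared source must be serialized. The minimal sketch below assumes a hypothetical LockedIterator helper (not part of any library). Only advancing the source is locked, while the UDF runs concurrently, so only the UDF itself needs to be thread-safe rather than the whole Dataset.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class LockedIterator:
    """Serializes next() on a shared iterator so several threads can pull from it safely."""

    def __init__(self, iterable):
        self._it = iter(iterable)
        self._lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        # Only the (cheap) advance of the source is serialized.
        with self._lock:
            return next(self._it)


def worker(shared, udf, out, out_lock):
    for item in shared:      # StopIteration from __next__ ends the loop
        result = udf(item)   # the expensive UDF runs outside the source lock
        with out_lock:
            out.append(result)


source = LockedIterator(range(100))
results, results_lock = [], threading.Lock()
with ThreadPoolExecutor(max_workers=4) as pool:
    for _ in range(4):
        pool.submit(worker, source, lambda x: x * x, results, results_lock)

print(sorted(results) == [x * x for x in range(100)])  # True
```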

We also suffer from the usual multi-processing pain points:

A granular parallelism approach

To fix the monolithic parallelism problem, we want to introduce abstractions and tooling that expose more granular parallelism controls to users. This implies a solution where users construct their dataloading and pre-proc pipelines by defining and stitching together data-source and pre-proc nodes into a graph, in a similar fashion to datapipes, with data passing between the nodes. The root of the graph is the node that produces the batches passed to the model. The leaves are data sources, which produce data by reading from local disk or remote storage, or from e.g. random number generators. Intermediate nodes may transform data, perform pre-fetching, combine data from multiple nodes, perform "enrichments" (e.g. fetching images from blob stores), perform decoding, schedule GPU operations, etc.
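A toy version of such a graph can be sketched with plain Python generators: leaves produce data, intermediate nodes map or combine it, and the root emits batches. The function names (source, mapper, round_robin, batcher) are hypothetical and there is no parallelism here. The point is only the graph shape, where each node pulls from its upstream nodes via iteration.

```python
import itertools
from typing import Callable, Iterable, Iterator, List


def source(start: int) -> Iterator[int]:
    # Leaf node: produces data (a counter standing in for disk/remote reads).
    yield from itertools.count(start)


def mapper(upstream: Iterable, udf: Callable) -> Iterator:
    # Intermediate node: transforms each item.
    for item in upstream:
        yield udf(item)


def round_robin(*upstreams: Iterable) -> Iterator:
    # Intermediate node: combines data from several upstream nodes.
    iters = [iter(u) for u in upstreams]
    for it in itertools.cycle(iters):
        try:
            yield next(it)
        except StopIteration:
            return  # stop as soon as any upstream is exhausted


def batcher(upstream: Iterable, batch_size: int) -> Iterator[List]:
    # Root node: groups items into batches for the model.
    it = iter(upstream)
    while True:
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            return
        yield batch


text = mapper(source(0), lambda x: f"text-{x}")
images = mapper(source(100), lambda x: f"img-{x}")
root = batcher(round_robin(text, images), batch_size=4)
print(next(root))  # ['text-0', 'img-100', 'text-1', 'img-101']
```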

Requirements and Constraints

To adequately support Multi Modal LLM training for PyTorch users, address the above pain points, and give us the best chance for wide-adoption, we want our solution to meet the following requirements and constraints:

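import itertools
import random
from typing import Iterable, List, Optional

class DatasetSampler:
  def __init__(self, sources: List[Iterable], sampling_weights: Optional[List[float]] = None):
    self.sources = sources
    self.sampling_weights = sampling_weights  # None means uniform sampling

  def __iter__(self):
    # Cycle each source forever so that no source is ever exhausted
    self.base_iters = [itertools.cycle(iter(x)) for x in self.sources]
    n = len(self.base_iters)
    while True:
      # Pick a source index according to the sampling weights, then yield its next item
      ds_idx = random.choices(range(n), weights=self.sampling_weights)[0]
      yield next(self.base_iters[ds_idx])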
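The DatasetSampler above samples with a non-seeded random draw, so its interleaving order changes from run to run. A hypothetical variant, SeededDatasetSampler (our illustration, not part of the proposal), creates a private random.Random(seed) on every call to __iter__, so each pass over the sampler replays the same order. This can help when debugging, or when resuming a run needs to reproduce the same sample order.

```python
import itertools
import random
from typing import Iterable, List, Optional


class SeededDatasetSampler:
    """Weighted interleaving of several iterables with a private, seeded RNG (hypothetical)."""

    def __init__(self, sources: List[Iterable], weights: Optional[List[float]] = None, seed: int = 0):
        self.sources = sources
        self.weights = weights
        self.seed = seed

    def __iter__(self):
        rng = random.Random(self.seed)  # fresh RNG per pass -> same order every pass
        base_iters = [itertools.cycle(src) for src in self.sources]
        indices = range(len(base_iters))
        while True:
            ds_idx = rng.choices(indices, weights=self.weights)[0]
            yield next(base_iters[ds_idx])


sampler = SeededDatasetSampler([["a1", "a2"], ["b1", "b2", "b3"]], weights=[0.25, 0.75], seed=42)
first = list(itertools.islice(sampler, 8))
again = list(itertools.islice(sampler, 8))
print(first == again)  # True
```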

How will we achieve this/what will we build? Plan of Record

We will introduce a new base class with the working name PolylithicNode, which subclasses IterableDataset: class PolylithicNode(torch.utils.data.IterableDataset). Nodes in the graph will be instances of subclasses of PolylithicNode. Nodes will define an .iterator() method instead of overriding __iter__(). This is inspired by nn.Module's implementation, where users define .forward() instead of __call__(). This will allow PolylithicNode to instantiate user-defined iterators and wrap them, insert queues for pipeline parallelism, and measure latency. For backwards compatibility, we'll provide a wrapper which takes an existing IterableDataset. Users can compose their datasets by composing PolylithicNodes (i.e. through iter() and next()).
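The nn.Module-style split described above (users implement iterator(), the base class owns __iter__()) can be sketched in a few lines. This is a hypothetical illustration: MiniNode and FromIterable are our names, not the proposed torchdata API, and the only wrapping shown is per-item latency measurement. The proposed PolylithicNode would also subclass IterableDataset and handle queues, state_dict, and autotuning.

```python
import time
from typing import Iterable, Iterator


class MiniNode:
    """Hypothetical stand-in for PolylithicNode: subclasses implement iterator()."""

    def iterator(self) -> Iterator:
        raise NotImplementedError

    def __iter__(self) -> Iterator:
        # Like nn.Module.__call__ wrapping forward(), __iter__ wraps iterator()
        # so the base class can add instrumentation around user code.
        self.num_items = 0
        self.total_seconds = 0.0
        it = self.iterator()
        while True:
            start = time.perf_counter()
            try:
                item = next(it)
            except StopIteration:
                return
            self.total_seconds += time.perf_counter() - start
            self.num_items += 1
            yield item


class FromIterable(MiniNode):
    # Backwards-compatibility wrapper: any existing iterable becomes a node.
    def __init__(self, iterable: Iterable):
        self.iterable = iterable

    def iterator(self) -> Iterator:
        return iter(self.iterable)


class Squares(MiniNode):
    def __init__(self, n: int):
        self.n = n

    def iterator(self) -> Iterator[int]:
        for i in range(self.n):
            yield i * i


node = Squares(5)
print(list(node))      # [0, 1, 4, 9, 16]
print(node.num_items)  # 5
```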

Example of composing iterable datasets to create a multimodal dataloader. [Note that we are open to ideas on syntactic sugar]

import json
import random
from typing import List

import torch
from torchdata.polylithic.nodes import PolylithicNode, Batcher, Mapper, MultiThreadedMapper, PinMemory, Prefetcher, AcceleratedMapper # Note that all of these classes subclass PolylithicNode

# Note: PolylithicNode is an abstract class which provides common code for state_dict, graph traversal, autotuning,
#   error propagation, etc.
# class PolylithicNode(torch.utils.data.IterableDataset):
#   def __iter__(self):  # PolylithicNode is still an IterableDataset
#     ...

# Some existing IterableDataset, perhaps generated through e.g. HuggingFace
class MyIterableDataset(torch.utils.data.IterableDataset):
  def __init__(self, json_l_file):
    super().__init__()
    self.json_l_file = json_l_file

  def __iter__(self):
    while True:  # Loop over the file forever
      with open(self.json_l_file, "r") as f:
        for line in f:
          yield json.loads(line)

# Define a Token Packer
class MyTokenPacker(PolylithicNode):
  def __init__(self, tokens_per_sample: int, sources: List[PolylithicNode], weights: List[float]):
    super().__init__()
    self.n = tokens_per_sample
    self.sources = sources
    self.weights = weights

  def iterator(self):
    self.source_iters = [iter(src) for src in self.sources]
    sample = []
    while True:
      # Keep drawing tokens from randomly chosen (weighted) sources until a full sample is buffered
      while len(sample) < self.n:
        src_idx = random.choices(range(len(self.sources)), weights=self.weights)[0]
        tokens = next(self.source_iters[src_idx])["tokens"]
        sample.extend(tokens)
      # Emit exactly tokens_per_sample tokens and carry the remainder into the next sample
      yield sample[:self.n]
      sample = sample[self.n:]

# Set up Tokenizer UDFs (Tokenizer and DecodeAndTokenize are placeholders for user-provided classes)
def tokenize(data):
  data["tokens"] = Tokenizer()(data["text"])
  return data

def tokenize_img_and_text(data):
  data["tokens"] = DecodeAndTokenize()(data["image"]) + Tokenizer()(data["caption"])
  return data

# Set up text reader
text_src = PolylithicNode.from_iterable(MyIterableDataset("text_data.jsonl"))
text_src = MultiThreadedMapper(text_src, udf=tokenize, num_workers="AUTOTUNE")

# Set up Text and Image dataset, with GPU Decoding 
img_src = PolylithicNode.from_iterable(MyIterableDataset("img_caption_data.jsonl"))
img_src = Mapper(img_src, udf=GpuImageDecoder(...)) # single threaded in main process
img_src = MultiThreadedMapper(img_src, udf=tokenize_img_and_text, num_workers="AUTOTUNE")
# Rest of pipeline
node = MyTokenPacker(tokens_per_sample, [img_src, text_src], [0.25, 0.75])
node = Batcher(node, batch_size)
node = PinMemory(node)
node = Prefetcher(node, 2)

for tokens in node:
  ...  # model forward/backward
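MultiThreadedMapper and Prefetcher above are proposed APIs. As a rough illustration of what an order-preserving multi-threaded map could look like, here is a minimal sketch built only on the standard library's concurrent.futures.ThreadPoolExecutor. multi_threaded_map is a hypothetical function, not the planned torchdata implementation, and it takes a fixed num_workers rather than autotuning.

```python
import collections
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, Iterator


def multi_threaded_map(source: Iterable, udf: Callable, num_workers: int = 4,
                       max_in_flight: int = 16) -> Iterator:
    """Apply udf to items of source on a thread pool, yielding results in input order."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        pending = collections.deque()
        for item in source:
            pending.append(pool.submit(udf, item))
            # Bound the number of submitted-but-unconsumed items to cap memory use.
            if len(pending) >= max_in_flight:
                yield pending.popleft().result()
        # Drain whatever is still in flight once the source is exhausted.
        while pending:
            yield pending.popleft().result()


print(list(multi_threaded_map(range(10), lambda x: x + 1, num_workers=3)))
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```

Bounding the number of in-flight items keeps memory in check, and Future.result() re-raises any exception from the UDF in the consuming thread, which is one simple form of error propagation.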

More complex diagram


What about DataPipes and DL v2?

DataPipes and DL v2 were designed to address issues like composability, and there is a lot of value in what they've built; however, their parallelism and sharding structure is still based on a monolithic approach (e.g. plugging a datapipe into DL v1, or DL v2 + the multiprocess reading service). They required migrating/rewriting datasets, often with no improvement in performance; identifying dataloading/pre-proc bottlenecks was a challenge; and shuffling/sharding pain points weren't adequately addressed.

The proposed approach improves upon DataPipes + DLv2 in the following ways:

We want to maintain the composable aspects of datapipes, the eager-execution, and continue our partnerships with storage and cloud providers (AWS, Azure, GCP) where they provide high-performance clients, share customer pain points, and provide recommended solutions and examples to their users.


Additional context

andrewkho commented 2 days ago

Some comments / discussions from earlier:

knoopx commented 22 hours ago

I was perfectly happy with datapipes; they provided me with simple building blocks that allowed me to optimize heavy-weight processes. I don't understand the need to kill them with no replacement but a promise of a better solution which addresses a completely different problem.

andrewkho commented 15 hours ago

Thanks @knoopx for the comment, is there something particular that you are doing with datapipes that wouldn't be possible with this proposal?

knoopx commented 7 hours ago

@andrewkho I mostly use iterable-style datapipes. I like the simplicity and being able to easily chain them together and defer execution. I use them for all sorts of things, not just for ML stuff. Iterable datapipes feel like Python-esque observables/streams/deferrables/futures/promises to me. The problems the proposal tries to solve are novel, and I'm pretty sure I could accomplish the same things, but imho this new API looks like a step backwards in developer experience, and I'm not sure it will solve all the existing pitfalls (like "debuggability"). After all, parallelism is intrinsically a hard problem (plus Python gotchas), and adding more lower-level abstractions won't make it easier for regular users. Just hoping an alternative higher-level API comes later, after you figure out all the necessary building blocks.