NVIDIA / DALI

A GPU-accelerated library containing highly optimized building blocks and an execution engine for data processing to accelerate deep learning training and inference applications.
https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html
Apache License 2.0

How to split web dataset shards in parallel external source with worker id in DALI? #4892

Open chenyaofo opened 1 year ago

chenyaofo commented 1 year ago

Describe the question.

I am trying to write an external source that loads webdatasets from S3, and I want to use the parallel external source to accelerate the loading.

Usually, the shards of a webdataset are split based on both the rank id and the (multi-processing) worker id, to ensure there is no overlap among the processes (an example implementation is torchdata.datapipes.iter.ShardingFilter https://pytorch.org/data/main/generated/torchdata.datapipes.iter.ShardingFilter.html).

So it is necessary to know the worker id to achieve this. In PyTorch, this can be achieved using the API torch.utils.data.get_worker_info https://pytorch.org/docs/stable/data.html#torch.utils.data.get_worker_info .
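
For reference, this is roughly the pattern I mean on the PyTorch side (just a minimal sketch; ShardedWebDataset and read_shard are placeholder names):

from torch.utils.data import IterableDataset, get_worker_info

class ShardedWebDataset(IterableDataset):
    def __init__(self, shard_urls):
        self.shard_urls = shard_urls

    def read_shard(self, url):
        # placeholder: stream samples from a single .tar shard (e.g. via webdataset)
        yield from ()

    def __iter__(self):
        info = get_worker_info()
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        # each dataloader worker touches only every num_workers-th shard
        for url in self.shard_urls[worker_id::num_workers]:
            yield from self.read_shard(url)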

However, in DALI, there doesn't seem to be a similar API available to obtain the worker's ID. Despite searching through the documentation and issues, I have been unable to find a solution.


JanuszL commented 1 year ago

Hi @chenyaofo,

Thank you for reaching out. I'm sorry, but the world configuration is not part of the External Source environment. In this case, you need to provide it on your own, just like the batch size is provided to the ExternalInputCallable __init__ function in this example.
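
For illustration, a minimal sketch of what passing that configuration yourself could look like (the class and the dummy data below are only placeholders):

import numpy as np

class ExternalInputCallable:
    def __init__(self, batch_size, shard_id, num_shards):
        # DALI does not inject shard_id / num_shards; they are passed in
        # explicitly here, the same way the docs example passes batch_size
        self.batch_size = batch_size
        self.shard_id = shard_id
        self.num_shards = num_shards

    def __call__(self, sample_info):
        # use self.shard_id / self.num_shards to decide which part of the
        # dataset this pipeline is responsible for; dummy data for illustration
        return np.full((2, 2), sample_info.idx_in_epoch, dtype=np.float32)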

stiepan commented 1 year ago

Hi @chenyaofo,

Can you tell us more about the loading pattern you would like to see? By splitting the webdatasets, do you mean that each worker process should have its subset of archives to read? Then, when each worker iterates over its archives, each should contribute a fixed number of samples to form the batch?

Below I assume that's the case.

Answering your question about the worker id: there is currently no API to learn it. The main use case we had in mind was an external source with a callback that produces samples based on the index of a sample within the batch and epoch. There is no way of knowing beforehand which worker will be tasked with producing a given sample in a batch: workers pick up new work whenever they are idle. This way, when some samples are easy and others are difficult to load/prepare, while one worker processes a heavy sample, the rest can pick up the remaining work and overlap. Splitting the dataset based on the worker id under such circumstances would have no benefit. The disadvantage of these assumptions is that they are mostly suitable for map-style datasets, where the cost of randomly accessing a sample is negligible.

I think we may need to consider adding an API that allows you to clearly express worker <-> sample affinity. In the meantime, though, I think we can work around the limitation with iterables.

Iterables have inner state and thus are not parallelized between multiple workers - each iterable gets a single worker that iterates over it. We can use this to get the sharding between workers described in the first section. The idea is to have one external source for each shard. The remaining trick is to merge the batches produced by the different external sources into a single batch - that can be done with conditional execution.

from nvidia.dali import fn, pipeline_def, types
import numpy as np
import dill

num_shards = 4
mini_batch_size = 5
batch_size = mini_batch_size * num_shards

class ShardedSource:

    def __init__(self, dataset_size, num_shards, shard_id, mini_batch_size, total_batch_size):
        self.dataset_size = dataset_size
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.mini_batch_size = mini_batch_size
        self.total_batch_size = total_batch_size
        self.shard_size = dataset_size // num_shards
        self.shard_begin = self.shard_size * shard_id
        self.shard_end = self.shard_size * (shard_id + 1)
        self.dummy_output = np.full((0, 0), 0, dtype=np.float32)
        self.reset()

    def reset(self):
        self.i = self.shard_begin

    def padded_to_full_batch(self, mini_batch):
        assert len(mini_batch) == self.mini_batch_size
        offset = self.shard_id * self.mini_batch_size
        padded_batch = [self.dummy_output for _ in range(self.total_batch_size)]
        for i, sample in enumerate(mini_batch):
            padded_batch[i + offset] = sample
        return padded_batch

    def __len__(self):
        return self.shard_end - self.shard_begin

    def __iter__(self):
        self.reset()
        return self

    def __next__(self):
        print(self, self.i)
        if self.i + self.mini_batch_size > self.shard_end:
            raise StopIteration
        batch = [np.full((3, 3), idx, dtype=np.float32) for idx in range(self.i, self.i + self.mini_batch_size)]
        self.i += self.mini_batch_size
        return self.padded_to_full_batch(batch)

def run_external_sources(dataset_size, num_shards, mini_batch_size, batch_size):
    shards = []
    for shard_idx in range(num_shards):
        sharded_source = ShardedSource(dataset_size, num_shards, shard_idx, mini_batch_size, batch_size)
        ext_sharded_source = fn.external_source(sharded_source, batch=True, parallel=True)
        shards.append(ext_sharded_source)
    return shards

def merge_mini_batches(mini_batch_size, mini_batches):
    # sample_info.idx_in_batch goes over [0, 1, ..., num_shards * mini_batch_size - 1]
    # (where batch_size = num_shards * mini_batch_size);
    # for indices 0, 1, ..., mini_batch_size - 1 we pick the first mini-batch,
    # for mini_batch_size, ..., 2 * mini_batch_size - 1 the second, and so on
    mini_batch_idx = fn.external_source(
        lambda sample_info: np.array(
            sample_info.idx_in_batch // mini_batch_size,
            dtype=np.int32),
        batch=False)
    return _select_mini_batch(mini_batch_idx, mini_batches)

def _select_mini_batch(mini_batch_idx, mini_batches, i=0):
    if i >= len(mini_batches) - 1:
        return mini_batches[-1]
    if mini_batch_idx == i:
        return mini_batches[i]
    return _select_mini_batch(mini_batch_idx, mini_batches, i + 1)

@pipeline_def(
    # the size of all worker-shards combined
    batch_size=batch_size,
    # each multiprocessing worker will get its own iterator
    py_num_workers=num_shards,
    py_start_method="spawn",
    device_id=0,
    num_threads=4,
    enable_conditionals=True,
    py_callback_pickler=(dill, {'recurse': True}))
def pipeline():
    mini_batches = run_external_sources(200, num_shards, mini_batch_size, batch_size)
    full_batch = merge_mini_batches(mini_batch_size, mini_batches)
    return full_batch

p = pipeline()
p.build()
p.run()

chenyaofo commented 1 year ago

Hi, @stiepan, thanks for your patient reply.

By splitting the webdatasets, do you mean that each worker process should have its subset of archives to read? Then, when each worker iterates over its archives, each should contribute a fixed number of samples to form the batch?

Yes, it is what I want.

# for indices 0, 1, ..., mini_batch_size - 1 we pick the first mini-batch,
# for mini_batch_size, ..., 2 * mini_batch_size - 1 the second, and so on

I did not notice that sample_info.idx_in_batch follows the above rule. In this case, it is possible to dynamically calculate the corresponding subset of shards based on sample_info.idx_in_batch, as shown in your solution.

Furthermore, I noticed that you mention:

There is no way of knowing beforehand which worker will be tasked with producing a given sample in a batch: workers pick up new work whenever they are idle.

In this sense, a worker may be assigned different subsets of shards. For example, say I run with world_size=2, each rank with a parallel external source (py_num_workers=2), so I have 4 processes loading data in total. Assume I have 4 webdataset shards in S3, namely 0000.tar, 0001.tar, 0002.tar and 0003.tar, and that num_shards=4, mini_batch_size=1, batch_size=1*4=4. The first time, a worker gets sample_info.idx_in_batch=0, so based on your solution it starts to download the first shard 0000.tar from S3. The next time, it may get sample_info.idx_in_batch=1, so it starts to download the second shard 0001.tar from S3. This would be very inefficient, since the worker keeps switching between shards. Ideally, a worker should deal with a fixed subset of shards instead of a varying one.

It would be better to provide an API to get the worker id, as PyTorch does. With that, it would be easy to pre-compute the subset of shards when initializing the external source.

stiepan commented 1 year ago

Hi @chenyaofo,

Let me clarify: the statement

There is no way of knowing beforehand which worker will be tasked with producing a given sample in a batch: workers pick up new work whenever they are idle.

refers to a single external source instance with a callback as a source.

I described that case as the rationale for why we did not expose the worker id - it cannot be reliably used to do any between-worker sharding in the model we assumed.

An iterable passed as a source to ES works differently - a single instance will use a single worker. That was meant just as a way to run the iterable asynchronously with respect to the main process, i.e. the single external source running in the single worker is expected to produce the whole batch. But we can work around it to match your case and introduce worker-level sharding. Instead of checking worker_id in the initialization, we just instantiate as many external sources as there are workers and pass each of them a consecutive integer. You can think of that integer as a worker id. (It is guaranteed that DALI will assign a different worker to each iterable as long as there are enough workers.)

Then, each external source iterates in its own process over its own archive. The workaround we need concerns the batch size: let us say a single pipeline uses two multiprocessing workers and we want the two external sources to produce, in every iteration, 4 samples each, so the total batch size of the pipeline is 4 + 4 = 8. However, DALI requires that each external source produce the full batch, so both ES are expected to return 8 samples. The workaround is that each ES loads 4 samples from its archive and pads the batch with dummy empty samples up to 8 samples. For example, the first ES produces [sample0, sample1, sample2, sample3, dummy, dummy, dummy, dummy] and the second [dummy, dummy, dummy, dummy, sample4, sample5, sample6, sample7].

The merge_mini_batches step just composes the single batch [sample0, sample1, sample2, sample3, sample4, sample5, sample6, sample7] by picking the actual samples and dropping the dummy ones.

Btw, sample_info.idx_in_batch has no special meaning here; if we returned it from the helper ES directly, in the presented example it would just be [0, 1, 2, 3, 4, 5, 6, 7]. When we divide it by the expected number of samples produced by each ES (i.e. 4 in our example), we get [0, 0, 0, 0, 1, 1, 1, 1], so for each sample we know which external source to pick it from.
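
A quick way to see that mapping:

mini_batch_size = 4
print([idx // mini_batch_size for idx in range(8)])  # -> [0, 0, 0, 0, 1, 1, 1, 1]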

It's a bit tricky, but in general it should be performant and correct. In the meantime, I will add a task for us to look into similar use cases and maybe provide something more out of the box.

chenyaofo commented 1 year ago

@stiepan Thank you, I get the idea. The current parallel external source API is more suitable for map-style datasets, so the solution for iterable-style datasets looks a little bit complex and tricky. I hope there will be a more straightforward solution.

@JanuszL Thank you. I think my question has been solved.

stiepan commented 1 year ago

Hi @chenyaofo,

The current parallel external source API is more suitable for map-style datasets, so the solution for iterable-style datasets looks a little bit complex and tricky.

Exactly.

Thanks for reporting the issue. With your use case in mind, I will look into simplifying the iterables splitting.

In the meantime, I hope you can use the workaround solution. Please let us know if you face any problems running it.

jrcavani commented 12 months ago

Glad I was able to find this discussion. One stated advantage of WebDataset is PB-scale parallel iterable-style streaming from the cloud. The original implementation supports this by dividing the data into (mostly) equal parts and treating the data parts as "coarse indices". Because the individual workers - split by distributed rank and by multiprocessing worker - cannot communicate with each other, the work splitting has to be pre-determined. Each rank/mp worker gets its own portion of the data parts and contributes to the dataloader iterator.

Currently both the original WebDataset implementation and TorchData support this splitting scheme by knowing the rank/mp worker ID in advance. DALI has configurable shard_id and num_shards, but they can only map to rank/world_size for now. If some iterable interface (similar to the non-parallel, basic version of ES) were exposed so that the user had access to the worker_id (any unique integer on the rank) and could execute a splitting function, that would make it work.

jrcavani commented 12 months ago

@stiepan https://github.com/NVIDIA/DALI/issues/4892#issuecomment-1574835267 this is a really genius solution. I think the code as-is addresses the problem within the multiprocessing worker scheme, e.g. 4 mp workers means num_shards=4, shard_id in [0,1,2,3]. To add the distributed sharding element, e.g. for world_size=8 with 4 mp workers each, that means num_shards=8*4, shard_id in range(8*4). More care needs to be taken to split the source data parts by this aggregated sharding scheme and to take the right aggregated shard_id. If there is more time I can come back and put in proper code, but the sketch below shows roughly what I mean.
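
Just as a sketch (it reuses ShardedSource from the snippet above; rank and world_size are assumed to come from the distributed setup, e.g. torch.distributed):

world_size = 8                 # number of distributed ranks
py_num_workers = 4             # mp workers per rank
rank = 0                       # this rank's id (from the distributed setup)
num_global_shards = world_size * py_num_workers

class GloballyShardedSource(ShardedSource):
    def __init__(self, dataset_size, local_worker_idx, mini_batch_size, total_batch_size):
        global_shard_id = rank * py_num_workers + local_worker_idx
        # slice the dataset by the global shard id ...
        super().__init__(dataset_size, num_global_shards, global_shard_id,
                         mini_batch_size, total_batch_size)
        # ... but pad within this rank's batch at the local worker's position
        self.shard_id = local_worker_idx

Each rank's pipeline would then instantiate py_num_workers of these (local_worker_idx in range(py_num_workers)) and merge them exactly as run_external_sources/merge_mini_batches do above.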

Iterables have inner state and thus are not parallelized between multiple workers - each iterable gets a single worker that iterates over it.

Could you clarify what DALI does if it sees that the ES source has __iter__ or __call__, together with the parallel flag? I'm assuming:

  1. If the ES source has __iter__, it's treated as an iterable, and DALI expects batch output (batch=True required). With parallel=True, it will only use one Python mp worker (no matter how many one specifies at pipeline_def). However, if there is more than one ES, they get round-robin assigned to mp workers?
  2. If the ES source has __call__, it's treated as a map-style dataset, and one has the choice of batch=True/False. With parallel=True, it will simply be copied to each Python mp worker?
stiepan commented 12 months ago

Hi @jrcavani,

You are exactly right. Iterables must work in batches: a single iterator works in a single mp process (if parallel=True), and multiple iterators are assigned to the available py_num_workers in a round-robin fashion. Callables (functions and classes implementing the __call__ method) can work with both batch=True and batch=False. In either case, the callable is copied to all the available workers. In batch mode this means that a few consecutive batches (up to ExternalSource's prefetch_queue_depth) can be computed in parallel, and in sample mode each sample of each prefetched batch is processed separately. For callables there is no affinity between workers and pieces of work; a worker just picks up the next sample or batch whenever it is idle.
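
To make the distinction concrete, here is a minimal sketch with toy data (not a complete example, just to show the two source kinds side by side):

import numpy as np
from nvidia.dali import fn, pipeline_def

# Callable source: driven by sample_info; it is copied to every worker and any
# idle worker may compute any sample.
def callable_source(sample_info):
    return np.full((2, 2), sample_info.idx_in_epoch, dtype=np.float32)

# Iterable source: stateful; a single instance runs in a single worker and
# must yield full batches.
class IterableSource:
    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.i = 0

    def __iter__(self):
        self.i = 0
        return self

    def __next__(self):
        batch = [np.full((2, 2), self.i + j, dtype=np.float32) for j in range(self.batch_size)]
        self.i += self.batch_size
        return batch

@pipeline_def(batch_size=8, num_threads=2, device_id=0,
              py_num_workers=2, py_start_method="spawn")
def toy_pipeline():
    a = fn.external_source(source=callable_source, batch=False, parallel=True)
    b = fn.external_source(source=IterableSource(8), batch=True, parallel=True)
    return a, b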

If I understand it correctly, you are right about the world size as well - you have to include the mp workers in the world_size and leave a few "ranks" for each data-loading pipeline.

What I suggested as a workaround here should be performant, but the padding and combining of the batches into a single one is fairly laborious and, frankly, unnecessary. We'd like to get rid of that at some point, so any background on how you would like to use the external source is very valuable input.

jrcavani commented 12 months ago

Iterables must work in batches: a single iterator works in a single mp process (if parallel=True), and multiple iterators are assigned to the available py_num_workers in a round-robin fashion.

One advantage of iterables returning individual samples is that you can combine multiple ESes' samples into a batch, increasing randomness - although likely only marginally, assuming the data parts themselves are shuffled and samples are window-shuffled inside the iterables.

For callables there is no affinity between workers and pieces of work; a worker just picks up the next sample or batch whenever it is idle.

Just to play devil's advocate: if callables could be round-robin assigned (like iterables as ES) instead of copied, we could play the same one-ES-per-mp-worker trick and just use sample_info as a counter. PyTorch's DataLoader uses this pattern: https://github.com/pytorch/pytorch/blob/main/torch/utils/data/dataloader.py#L283 https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/fetch.py#L24C39

If I understand it correctly, you are right about the world size as well - you have to include the mp workers in the world_size and leave a few "ranks" for each data-loading pipeline.

WebDataset has a graceful pipeline construction that handles distributed splitting by reaching into the distributed module. To make it work with this pattern, I could do:

from itertools import islice

import numpy as np
import webdataset as wds

def split_by_worker_wrap(worker_id, num_workers):
    def split_by_worker(src):
        if num_workers > 1:
            for s in islice(src, worker_id, None, num_workers):
                yield s
        else:
            for s in src:
                yield s

    return split_by_worker

# inside the iterable source's __init__ (worker_id, num_workers, data_source_path
# and per_worker_batch_size come from the constructor):
        self.dataset = wds.DataPipeline(
            wds.SimpleShardList(urls=data_source_path),
            wds.detshuffle(),
            wds.split_by_node,
            split_by_worker_wrap(worker_id, num_workers),
            wds.tarfile_samples,
            wds.shuffle(32768),
            wds.rename(img="jpg", label="cls"),
            wds.map_dict(img=lambda x: np.frombuffer(x, dtype=np.uint8), label=int),
            wds.to_tuple("img", "label"),
            wds.batched(per_worker_batch_size, partial=False)
        )

And much of the worker shard definitions you gave can remain the same.
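
For example, roughly (only a sketch: it assumes the wds pipeline above yields one per-worker batch of (images, labels) per iteration, takes the consecutive integer passed at construction as worker_id, and leaves out the labels for brevity):

class WdsIterableSource:
    def __init__(self, wds_pipeline, worker_id, per_worker_batch_size, total_batch_size):
        self.wds_pipeline = wds_pipeline
        self.worker_id = worker_id
        self.per_worker_batch_size = per_worker_batch_size
        self.total_batch_size = total_batch_size
        self.dummy = np.full((0,), 0, dtype=np.uint8)

    def __iter__(self):
        self.it = iter(self.wds_pipeline)
        return self

    def __next__(self):
        imgs, labels = next(self.it)            # one per-worker batch from wds
        offset = self.worker_id * self.per_worker_batch_size
        padded = [self.dummy] * self.total_batch_size
        for i, img in enumerate(imgs):
            padded[i + offset] = np.asarray(img)
        return padded                           # labels would be padded and merged analogously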

We'd like to get rid of that at some point, so any background on how you would like to use the external source is very valuable input.

With TB/PB-scale data loading, one would want to avoid downloading first and instead simply stream the data as training goes forward. So iterable-style sources are more popular than map-style ones for this use case. In order to load data in parallel, the data needs to be partitioned, and splitting the data parts without an external centralized queue needs all this finagling.

Having an external queue solves a lot of this problem, though, because it separates the data source fetching from the augmentation pipeline. The ES configuration would be much simpler. However, it is extra work to bring up a performant external queue and feed data through it. One would probably benefit from Ray Data.