NVIDIA / DALI

A GPU-accelerated library containing highly optimized building blocks and an execution engine for data processing to accelerate deep learning training and inference applications.
https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html
Apache License 2.0

Multiple sharding policies in plugin.jax.data_iterator #5535

Open sali1997s opened 3 months ago

sali1997s commented 3 months ago

Describe the question.

Is there a way to set the JAX sharding separately for each output of plugin.jax.data_iterator? For example, I have a pipeline that has 2 outputs. I want the first output to be PartitionSpec('batch', 'model') and the second to be PartitionSpec('batch', None) or PartitionSpec('batch').
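To make the difference between the two requested layouts concrete, here is a minimal pure-Python sketch that computes the per-device block shape of an array under a PartitionSpec-like tuple. The helper `shard_shape`, the 4 x 2 mesh and the (64, 4) output shape are hypothetical illustrations, not DALI or JAX API. The sketch also simplifies JAX semantics: it supports only one mesh axis name per dimension and only evenly divisible dimensions. It follows the JAX convention that a dimension mapped to None, or not listed in the spec at all, is replicated, which is why PartitionSpec('batch') and PartitionSpec('batch', None) give the same layout.

```python
from typing import Dict, Optional, Sequence, Tuple


def shard_shape(
    global_shape: Tuple[int, ...],
    mesh_shape: Dict[str, int],
    spec: Sequence[Optional[str]],
) -> Tuple[int, ...]:
    """Per-device block shape of an array laid out by a PartitionSpec-like tuple.

    Simplified model: a dimension mapped to a mesh axis name is split evenly
    across that axis; a dimension mapped to None (or not covered by the spec)
    is replicated, so each device holds it whole.
    """
    result = []
    for dim, size in enumerate(global_shape):
        axis = spec[dim] if dim < len(spec) else None
        if axis is None:
            result.append(size)  # replicated along the remaining mesh axes
            continue
        parts = mesh_shape[axis]
        if size % parts != 0:
            raise ValueError(
                f"dim {dim} of size {size} is not divisible by mesh axis {axis!r} ({parts})"
            )
        result.append(size // parts)  # evenly split across the named mesh axis
    return tuple(result)


mesh = {"batch": 4, "model": 2}  # hypothetical 4 x 2 device mesh
print(shard_shape((64, 4), mesh, ("batch", "model")))  # (16, 2)
print(shard_shape((64, 4), mesh, ("batch", None)))     # (16, 4)
print(shard_shape((64, 4), mesh, ("batch",)))          # (16, 4)
```

So with ('batch', 'model') each device would hold a quarter of the batch and half of the features, while ('batch', None) keeps all features on every device of a mesh row.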


awolant commented 3 months ago

Hello @sali1997s

Thanks for the question. Unfortunately, something like this is currently not supported. This enhancement is on our TODO list for the JAX integration.

Could you tell us more about your use case and how you would need this to work? In particular, how would you map this onto devices? Do you need both CPU and GPU? I am asking because, with DALI pipelines running on a particular GPU, there are some design and performance considerations for this feature, and we would like input from users to influence these decisions. Thanks!

sali1997s commented 3 months ago

Thank you for answering, @awolant! Sorry, after thinking about my task more deeply, I came to the conclusion that partitioning the data over the batch dimension fully covers my needs. I thought I needed more control over partitioning, but I don't need it currently.

However, I've found that the data loader works only for data-parallel training; it currently doesn't support model parallelism. Here is a minimal reproducible example. If I change device_mesh to mesh_utils.create_device_mesh((4, 2)), it fails.

I also have a question about the interaction between @data_iterator (the size parameter) and external source. When the number of samples is divisible by the shard size, it works as expected. Otherwise, it fails with WARNING:root:DALI iterator does not support resetting while epoch is not finished. Ignoring... and doesn't proceed to the second epoch. Is there anything I can do about it?

from nvidia.dali.plugin.jax import data_iterator
from jax.experimental import mesh_utils
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import numpy as np

GLOBAL_BATCH_SIZE = 64

class DataSorceCallable:
    def __init__(self, batch_size, seed, shard_id, num_shards):
        self.rng = np.random.default_rng(seed=seed)
        self.batch_size = batch_size

        self.files = np.random.rand(GLOBAL_BATCH_SIZE * 10, 4).astype(np.float32)

        self.shard_id = shard_id
        self.num_shards = num_shards

        self.shard_size = len(self.files) // num_shards
        self.shard_offset = self.shard_size * shard_id

        # If the shard size is not divisible by the batch size, the last incomplete batch
        # will be omitted.
        self.full_iterations = self.shard_size // batch_size 
        # print(self.full_iterations, self.shard_size, batch_size, len(self.files))
        self.perm = None
        self.last_seen_epoch = (
            None  # so that we don't have to recompute the `self.perm` for every sample
        )
    def __call__(self, sample_info):
        if sample_info.iteration >= self.full_iterations:
            raise StopIteration()
        if self.last_seen_epoch != sample_info.epoch_idx:
            self.last_seen_epoch = sample_info.epoch_idx
            self.perm = np.random.default_rng(seed=42 + sample_info.epoch_idx).permutation( 
                len(self.files)
            )

        sample_idx = self.perm[sample_info.idx_in_epoch + self.shard_offset]

        return self.files[sample_idx, :]

if __name__ == "__main__":
    # Data parallel only: changing this to mesh_utils.create_device_mesh((4, 2)) makes the example fail.
    device_mesh = mesh_utils.create_device_mesh((8, 1))
    mesh = Mesh(device_mesh, axis_names=("batch", "model"))
    sharding = NamedSharding(mesh, PartitionSpec("batch"))

    @data_iterator(output_map=["out"], sharding=sharding, size=GLOBAL_BATCH_SIZE * 10, prepare_first_batch=False)
    def callable_pipeline(num_shards, shard_id):
        out, = fn.external_source(
            source=DataSorceCallable(GLOBAL_BATCH_SIZE//num_shards, num_shards=num_shards, shard_id=shard_id, seed=42),
            num_outputs=1,
            batch=False,
            # parallel=True,
            dtype=[types.FLOAT],
        )
        return out.gpu()

    dataloader = callable_pipeline(batch_size = GLOBAL_BATCH_SIZE)

    for el in dataloader:
        print(el['out'].sharding)
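To illustrate what switching from the (8, 1) mesh to the (4, 2) mesh means for the data layout with PartitionSpec("batch"), here is a minimal pure-Python sketch. It maps each position of a row-major mesh to the range of the global batch that position should hold. The helper `batch_slices` is hypothetical. It describes only the layout JAX expects for this spec; it does not describe how data_iterator assigns shards internally, and mesh positions are not necessarily physical device ids, since mesh_utils may reorder devices.

```python
from typing import Dict, Tuple


def batch_slices(global_batch: int, n_batch: int, n_model: int) -> Dict[int, Tuple[int, int]]:
    """Map each flat position of a row-major (n_batch, n_model) mesh to the
    [start, stop) range of the global batch it holds under PartitionSpec("batch").

    The 'model' axis is not named in the spec, so all positions in the same
    mesh row hold an identical copy of that row's slice.
    """
    if global_batch % n_batch != 0:
        raise ValueError("global batch must be divisible by the 'batch' mesh axis")
    per_row = global_batch // n_batch
    slices: Dict[int, Tuple[int, int]] = {}
    for flat in range(n_batch * n_model):
        row = flat // n_model  # index along the 'batch' mesh axis
        slices[flat] = (row * per_row, (row + 1) * per_row)
    return slices


for n_batch, n_model in [(8, 1), (4, 2)]:
    slices = batch_slices(64, n_batch, n_model)
    print((n_batch, n_model), "distinct slices:", len(set(slices.values())))
```

With (8, 1), each of the 8 positions gets its own 8-sample slice. With (4, 2), there are only 4 distinct 16-sample slices, each of which must appear on 2 devices.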

awolant commented 2 months ago

Thanks for the reproduction. This is definitely a feature that could be added to enhance DALI's JAX support. For the first version of this integration layer, we focused only on the most common and simple cases.

As for your question about external source, unfortunately there is currently no way to do something like this. As I said, for this first version we wanted it to work in the most common and stable case.

In your use case, how would you expect this to work? I am asking just to get feedback about possible improvements to the JAX integration. Would you like the missing samples to be filled or duplicated somehow?
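The two behaviours under discussion, dropping the incomplete final batch and filling it by duplicating samples, can be sketched independently of DALI. The hypothetical helper `epoch_indices` below returns one epoch's sample indices with a length divisible by the global batch size. In "drop" mode, the length of the result is also a natural value to pass as data_iterator's size, so that the iterator and the source agree on where an epoch ends. Whether this avoids the reset warning quoted earlier is an assumption, not something confirmed in this thread.

```python
from typing import List


def epoch_indices(num_samples: int, global_batch: int, mode: str = "drop") -> List[int]:
    """Sample indices for one epoch, with a length divisible by global_batch.

    mode="drop": discard the final incomplete batch.
    mode="pad":  fill the final batch by repeating samples from the start.
    """
    if mode == "drop":
        usable = (num_samples // global_batch) * global_batch
        return list(range(usable))
    if mode == "pad":
        padded = -(-num_samples // global_batch) * global_batch  # ceiling division
        return [i % num_samples for i in range(padded)]
    raise ValueError(f"unknown mode: {mode!r}")


# 650 samples with a global batch of 64: 10 full batches plus 10 leftover samples.
print(len(epoch_indices(650, 64, "drop")))  # 640
print(len(epoch_indices(650, 64, "pad")))   # 704
```

Dropping keeps every sample unique but skips a few samples each epoch (with a per-epoch permutation, different ones each time). Padding sees every sample but slightly over-weights the duplicated ones.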