NVIDIA / DALI

A GPU-accelerated library containing highly optimized building blocks and an execution engine for data processing to accelerate deep learning training and inference applications.
https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html
Apache License 2.0

Asynchronous prefetching #5583

Open · quanvuong opened this issue 4 months ago

quanvuong commented 4 months ago

Is this a new feature, an improvement, or a change to existing functionality?

New Feature

How would you describe the priority of this feature request

Nice to have (e.g. Adoption is possible, the feature will enhance the use-case even more).

Please provide a clear description of problem this feature solves

Instantiating the DALI data loader for JAX takes a long time because of prefetching.

Feature Description

It would be nice to have an asynchronous prefetching feature, so that I can interleave jitting the model with the prefetching operations.

Describe your ideal solution

Have two functions:

def start_async_prefetch(self):

def block_till_ready_async_prefetch(self):

With these two functions, I can control when prefetching happens.
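
Roughly, I imagine using them like this; the method names are just the proposal (not an existing DALI API), and build_iterator / train_step are placeholders for my own code:

import jax

# Hypothetical API sketch: start_async_prefetch / block_till_ready_async_prefetch
# are the proposed names, not something DALI exposes today.
iterator = build_iterator()                 # usual DALI JAX iterator construction (placeholder)
iterator.start_async_prefetch()             # kick off prefetching in the background

train_step = jax.jit(train_step)            # jit the model while prefetching runs

iterator.block_till_ready_async_prefetch()  # wait until the prefetch queue is filled
batch = next(iterator)
train_step(batch)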

Describe any alternatives you have considered

No response

Additional context

No response

Check for duplicates

JanuszL commented 3 months ago

Hi @quanvuong,

Thank you for reaching out. Currently, the exact functionality you are asking for is not fully exposed. What you can try out is:

quanvuong commented 3 months ago

That does seem to speed things up by quite a bit.

Instantiating the data iterator is still quite slow because I'm using the integration with JAX, which requires starting the worker pool with "spawn" (on an 8xH100 node, starting all the worker pools can take 20 minutes).

Do you have any advice on how to improve the speed here?

JanuszL commented 3 months ago

Hi @quanvuong,

(on an 8xH100 node, starting all the worker pools can take 20 minutes)

This is surprising and not expected. Can you share self-contained repro code we can run on our end for debugging?

quanvuong commented 3 months ago

I'm working on the self-contained repro; in the meantime, I have narrowed it down to the lines below as the slow operations.

Specifically, going from s0 to s1 takes 80 seconds (in nvidia/dali/_multiproc/pool.py):

    def _start_processes(self, mp, start_method, write_sockets):
        try:
            import time  # timing instrumentation added for debugging
            s0 = time.time()
            for process in self._processes:
                process.start()
            s1 = time.time()
            task_queues = [
                worker_context.dedicated_task_queue
                for worker_context in self._workers_contexts
                if worker_context.dedicated_task_queue is not None
            ]
            if self._general_task_queue is not None:
                task_queues.append(self._general_task_queue)
            self._observer = Observer(mp, self._processes, task_queues, self._result_queue)
            s3 = time.time()
            if start_method != "fork":
                # NOTE when making any changes here, make sure to reflect them in the worker
                # process, so that it sets received handles to objects in the same order
                self._send_queue_handles(write_sockets)
                self._send_shm_handles(write_sockets)
            s4 = time.time()
            self._sync_initialized_workers()
            s5 = time.time()
            print(f"_start_processes: s1-s0 {s1-s0}")
            print(f"_start_processes: s3-s1 {s3-s1}")
            print(f"_start_processes: s4-s3 {s4-s3}")
            print(f"_start_processes: s5-s4 {s5-s4}")
        except:  # noqa: E722
            if self._observer is not None:
                self._observer.close()
                self._observer = None
            else:
                for proc in self._processes:
                    if proc.is_alive():
                        proc.terminate()
                for proc in self._processes:
                    if proc.pid is not None:
                        proc.join()
            raise

stiepan commented 3 months ago

Hi @quanvuong,

Because you are referring to nvidia/dali/_multiproc/pool.py, I assume you are using the parallel external source in your iterator. It's a bit of a guess without more details on the code you wrote, but the place you pointed to will take more time when the callback/source you are passing to the external source is heavy. At that point, the Python multiprocessing package passes the serialized callbacks to the workers, so the bigger the serialized object is, the longer it will take to start the processes.

If that's the case, you can check https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/parallel_external_source.html#Serialization-and-heavy-setup to try to work around it by making sure that the serialized object is lighter.
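
For illustration, a pattern along these lines keeps the pickled payload small (this Cb is just a sketch, not your actual callable):

import numpy as np

class Cb:
    """Sketch of a parallel external source callable that defers heavy setup.

    Only cheap, picklable configuration lives in __init__, so the object that is
    serialized and sent to the spawned workers stays small; the heavy state is
    built lazily on the first call inside each worker process."""

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path  # cheap, picklable configuration only
        self.index = None                 # heavy state created lazily

    def _heavy_setup(self):
        # e.g. read an index file, open handles, build lookup tables, ...
        self.index = list(range(1024))

    def __call__(self, sample_info):
        if self.index is None:
            self._heavy_setup()
        return np.full((2, 2), sample_info.idx_in_epoch, dtype=np.int32)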

quanvuong commented 3 months ago

Yes, I am using the parallel external source in my iterator (with the JAX integration).

I have moved the heavy setup to get_state as recommended, and that reduces the time taken to instantiate the data iterator by 10-15%, but it is still quite slow. Any advice? Is there a profiler that I can use?

We are running on 8 H100 nodes, and instantiating the data iterator takes more than 10 minutes (about 2 minutes per GPU).

JanuszL commented 3 months ago

Hi @quanvuong,

Can you bisect which part of the external source callback is slow (start with a callback that just calls np.ones, for example, and then gradually extend it, adding more logic/imports)?
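
For example, something as trivial as this can be the starting point of the bisection (TrivialCb is just an illustrative name):

import numpy as np

class TrivialCb:
    """Stand-in external source callback that does no real work; if start-up is
    still slow with this, the callback itself is not the bottleneck."""

    def __call__(self, sample_info):
        return np.ones((256, 256, 3), dtype=np.uint8)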

stiepan commented 3 months ago

Hi @quanvuong,

To make sure the serialization is no longer the main factor contributing to the start-up time, you could use a custom pickler that wraps pickle and provides you with some more information, such as the size of the callback once it is serialized:

import pickle
from nvidia.dali import pipeline_def, fn

from source import Cb  # Cb is your external-source callable, defined in a separate module

class PeekPickler:

    @classmethod
    def loads(cls, payload):
        return pickle.loads(payload)

    @classmethod
    def dumps(cls, obj):
        payload = pickle.dumps(obj)
        print("The payload size for the callback: ", obj, len(payload))
        return payload

@pipeline_def(
    batch_size=4,
    device_id=0,
    num_threads=4,
    py_num_workers=4,
    py_start_method="spawn",
    py_callback_pickler=PeekPickler
)
def pipeline():
    return fn.external_source(Cb(1024), parallel=True, batch=False)

Another thing that may contribute to the total start-up time of the workers (although I would expect it to show up as s5-s4 in the snippet you provided) is the time Python takes to set up all the imports and globals in the worker processes. Note that the main entrypoint file will be loaded and set up in each worker process, including recursively processing its imports. It may help to define the callback/source in a separate file and make sure any heavy setup in the entrypoint file (along with any imports not needed for the callback) is protected with an if __name__ == "__main__": guard.
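
For illustration, the layout could look roughly like this (file names and the jax import are just examples):

# source.py -- defines only the callback and the imports it needs, so spawning a
# worker does not pull in the heavy training-side modules.
import numpy as np

class Cb:
    def __init__(self, size):
        self.size = size

    def __call__(self, sample_info):
        return np.ones(self.size, dtype=np.float32)

# main.py -- entrypoint; heavy imports and setup are guarded so they do not run
# again when the worker processes re-import this module under "spawn".
from nvidia.dali import pipeline_def, fn
from source import Cb

@pipeline_def(
    batch_size=4,
    device_id=0,
    num_threads=4,
    py_num_workers=4,
    py_start_method="spawn",
)
def pipeline():
    return fn.external_source(Cb(1024), parallel=True, batch=False)

if __name__ == "__main__":
    import jax  # example of a heavy import that only the main process needs
    pipe = pipeline()
    pipe.build()
    pipe.run()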