Open quanvuong opened 4 months ago
Hi @quanvuong,
Thank you for reaching out. Currently, the exact functionality you are asking for is not fully exposed. What you can try out is:
- pass prepare_first_batch=False so there is no prefetching when the iterator is created,
- call iterator._schedule_runs(release_outputs=False) to start prefetching asynchronously,
- and then just call next() to wait for the data.
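For example, a rough sketch of that flow could look like the following (the pipeline and iterator construction below are only illustrative, and _schedule_runs is an internal API, so treat this as a workaround rather than a stable interface):

import numpy as np
from nvidia.dali import pipeline_def, fn
from nvidia.dali.plugin.jax import DALIGenericIterator

@pipeline_def(batch_size=4, num_threads=2, device_id=0)
def pipeline():
    # stand-in source; replace with your real external source
    return fn.external_source(
        source=lambda info: np.full((8,), info.idx_in_epoch, dtype=np.int32), batch=False
    )

# prepare_first_batch=False -> no prefetching when the iterator is created
it = DALIGenericIterator([pipeline()], output_map=["data"], size=16, prepare_first_batch=False)

# ... do other start-up work here (e.g. jit the model) ...

it._schedule_runs(release_outputs=False)  # start prefetching asynchronously
batch = next(it)  # blocks only until the prefetched data is ready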
Please let us know if that works for your use case.

That does seem to speed things up by quite a bit.
Instantiating the data iterator is still quite slow because I'm using the integration with JAX, which requires starting the worker pool with "spawn" (on an 8 H100 node, starting all the worker pools can take 20 minutes).
Do you have any advice on how to improve the speed here?
Hi @quanvuong,
(on an 8 H100 node, starting all the worker pools can take 20 minutes).
This is surprising and not expected. Can you share self-contained repro code we can run on our end for debugging?
I'm working on the self-contained repro; in the meantime, I have narrowed it down to these lines as the slow operation.
Specifically, going from s0 to s1 takes 80 seconds (in nvidia/dali/_multiproc/pool.py):
def _start_processes(self, mp, start_method, write_sockets):
    try:
        import time

        s0 = time.time()
        for process in self._processes:
            process.start()
        s1 = time.time()
        task_queues = [
            worker_context.dedicated_task_queue
            for worker_context in self._workers_contexts
            if worker_context.dedicated_task_queue is not None
        ]
        if self._general_task_queue is not None:
            task_queues.append(self._general_task_queue)
        self._observer = Observer(mp, self._processes, task_queues, self._result_queue)
        s3 = time.time()
        if start_method != "fork":
            # NOTE when making any changes here, make sure to reflect them in the worker
            # process, so that it sets received handles to objects in the same order
            self._send_queue_handles(write_sockets)
            self._send_shm_handles(write_sockets)
        s4 = time.time()
        self._sync_initialized_workers()
        s5 = time.time()
        print(f"_start_processes: s1-s0 {s1-s0}")
        print(f"_start_processes: s3-s1 {s3-s1}")
        print(f"_start_processes: s4-s3 {s4-s3}")
        print(f"_start_processes: s5-s4 {s5-s4}")
    except:  # noqa: E722
        if self._observer is not None:
            self._observer.close()
            self._observer = None
        else:
            for proc in self._processes:
                if proc.is_alive():
                    proc.terminate()
            for proc in self._processes:
                if proc.pid is not None:
                    proc.join()
        raise
Hi @quanvuong,
Because you are referring to nvidia/dali/_multiproc/pool.py, I assume you are using parallel external source in your iterator. It's a bit of a guess without more details on the code you wrote, but the place you pointed to will take more time when the callback/source you are passing to external source is heavy. At that point, the Python multiprocessing package passes serialized callbacks to the workers, so the bigger the serialized object is, the longer it will take to start the processes.
If that's the case, you can check https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/parallel_external_source.html#Serialization-and-heavy-setup to try to get around it by making sure that the serialized object is lighter.
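The gist of that pattern, sketched here with illustrative names, is to keep only cheap state in the object that gets pickled and to defer the heavy setup until the callback first runs inside the worker:

import numpy as np

class LightCallback:
    def __init__(self, dataset_path):
        # only small, picklable state lives here; this is what gets serialized
        # and sent to every worker process
        self.dataset_path = dataset_path
        self.data = None

    def __call__(self, sample_info):
        if self.data is None:
            # heavy setup happens lazily, once per worker, after the workers have started
            self.data = np.load(self.dataset_path, mmap_mode="r")
        return self.data[sample_info.idx_in_epoch % len(self.data)]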
Yes, I am using parallel external source in my iterator (with the JAX integration).
I have moved the heavy setup to get_state as recommended, and that reduces the time taken to instantiate the data iterator by 10-15%, but it is still quite slow. Any advice? Is there a profiler that I can use?
We are running on 8 H100 nodes, and instantiating the data iterator takes more than 10 minutes (about 2 minutes per GPU).
Hi @quanvuong,
Can you bisect which part of the external source callback is slow (start with a callback that just calls np.ones, for example, and then gradually extend it, adding more logic/imports)?
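For example, a trivial baseline callback to start from could be as simple as:

import numpy as np

def trivial_cb(sample_info):
    # baseline: no heavy imports and no state -- if start-up is still slow with this,
    # the callback itself is not the bottleneck
    return np.ones((224, 224, 3), dtype=np.uint8)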
Hi @quanvuong,
To make sure that serialization is no longer the main factor contributing to the start-up time, you could use a custom pickler that wraps pickle and provides you with some more information, such as the size of the callback once it is serialized.
import pickle

from nvidia.dali import pipeline_def, fn
from source import Cb


class PeekPickler:
    @classmethod
    def loads(cls, payload):
        return pickle.loads(payload)

    @classmethod
    def dumps(cls, obj):
        payload = pickle.dumps(obj)
        print("The payload size for the callback: ", obj, len(payload))
        return payload


@pipeline_def(
    batch_size=4,
    device_id=0,
    num_threads=4,
    py_num_workers=4,
    py_start_method="spawn",
    py_callback_pickler=PeekPickler,
)
def pipeline():
    return fn.external_source(Cb(1024), parallel=True, batch=False)
Another thing that may contribute to the total start-up time of the workers (although I would expect it to show up as s5-s4 in the snippet you provided) is the time Python takes to set up all the imports and globals in the worker processes. Note that the main entrypoint file will be loaded and set up in each worker process, including recursively processing its imports. It may help to define the callback/source in a separate file and to make sure any heavy setup in the entrypoint file (along with any imports not needed for the callback) is protected with an if __name__ == "__main__": guard.
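For example (the file and function names below are only illustrative):

# callbacks.py -- a small module containing only what the workers need
import numpy as np

def my_source(sample_info):
    return np.ones((8,), dtype=np.float32)

# train.py -- the entrypoint; spawned workers re-import this module, so the
# heavy work is kept behind the __main__ guard
from nvidia.dali import pipeline_def, fn
from callbacks import my_source

@pipeline_def(batch_size=4, device_id=0, num_threads=4,
              py_num_workers=4, py_start_method="spawn")
def pipeline():
    return fn.external_source(source=my_source, parallel=True, batch=False)

if __name__ == "__main__":
    p = pipeline()
    p.build()
    p.run()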
Is this a new feature, an improvement, or a change to existing functionality?
New Feature
How would you describe the priority of this feature request
Nice to have (e.g. Adoption is possible, the feature will enhance the use-case even more).
Please provide a clear description of problem this feature solves
Instantiating the DALI dataloader for JAX takes a long time because of prefetching.
Feature Description
It would be nice to have an asynchronous prefetching feature, so that I can interleave jitting the model with the prefetching operations.
Describe your ideal solution
Have two functions:
def start_async_prefetch(self):
def block_till_ready_async_prefetch(self):
With these two functions, I can control when prefetching happens.
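A rough sketch of how I imagine using the requested API (the constructor and both method names below are hypothetical):

import jax
import jax.numpy as jnp

iterator = build_dali_jax_iterator(prepare_first_batch=False)  # placeholder constructor
iterator.start_async_prefetch()  # requested: kick off prefetching and return immediately

# overlap jit compilation with the prefetching running in the background
train_step = jax.jit(lambda x: jnp.sum(x * 2.0))
_ = train_step(jnp.zeros((4, 8)))  # triggers compilation

iterator.block_till_ready_async_prefetch()  # requested: wait until the first batch is ready
batch = next(iterator)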
Describe any alternatives you have considered
No response
Additional context
No response
Check for duplicates