kvablack opened this issue 2 months ago
Hello @kvablack, thank you for creating the issue. The idea looks very interesting. Could you share some of your performance results? Also, could you tell us more about the pipeline you are trying to build? Is it very complex?
Sure thing. I just tested it with 8 GPUs:
The pipeline is not very complex. It loads from a parallel `ExternalSource` and does image decoding. I suspect the overhead comes from creating the workers (6 workers per GPU). I followed the recommendation in the documentation and put the heavy setup in the `__setstate__` method of my `ExternalSource` source. However, there is some data that I need to send to each worker, and I just measured it at 84 MB when pickled.
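The `__setstate__` approach described above can be sketched as follows, assuming the source object is serialized with Python's standard `pickle` module when it is sent to each worker. `LazySource`, `_setup`, and `pickled_size_mb` are hypothetical names, and the `__call__` that would actually produce samples is omitted. The sketch only shows how to keep derived state out of the pickled payload and how to measure that payload; state the workers genuinely need (like the 84 MB mentioned here) still has to be pickled.

```python
import pickle


class LazySource:
    """Hypothetical external-source object that defers heavy setup to __setstate__."""

    def __init__(self, file_list):
        # Small state that every worker needs.
        self.file_list = list(file_list)
        self._setup()

    def _setup(self):
        # Placeholder for expensive per-process setup (e.g. building an index).
        self.index = {name: i for i, name in enumerate(self.file_list)}

    def __getstate__(self):
        # Ship only the lightweight state; everything built in _setup() is dropped.
        return {"file_list": self.file_list}

    def __setstate__(self, state):
        # Runs in the receiving process after unpickling; rebuilds derived state there.
        self.file_list = state["file_list"]
        self._setup()


def pickled_size_mb(obj):
    """Size of the pickled payload in megabytes (10**6 bytes)."""
    return len(pickle.dumps(obj)) / 1e6


if __name__ == "__main__":
    src = LazySource([f"img_{i}.jpg" for i in range(1000)])
    print(f"pickled size: {pickled_size_mb(src):.3f} MB")
```

Because `__getstate__` drops `index`, only `file_list` crosses the process boundary, and each worker pays the `_setup()` cost locally.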
Is this a new feature, an improvement, or a change to existing functionality?
Improvement
How would you describe the priority of this feature request?
Should have (e.g. Adoption is possible, but the performance shortcomings make the solution inferior).
Please provide a clear description of the problem this feature solves
I'm using DALI with the JAX plugin. From what I can tell, every plugin builds its pipelines (one per GPU) one after another.
Each pipeline contains a fairly heavy external source, so this process can take a very long time, especially with 8 GPUs.
Feature Description
I recently wrote my own code to bypass the JAX plugin, and I parallelized the building process with a thread pool.
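A minimal sketch of building pipelines concurrently with a thread pool, assuming each pipeline object exposes a `build()` method as DALI's `Pipeline` does. `StandInPipeline` and `build_pipelines_in_parallel` are hypothetical names. The stand-in only sleeps, which releases the GIL, so its timing says nothing about how much real DALI builds would overlap; the speedup reported in this issue was measured on the author's own code.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class StandInPipeline:
    """Hypothetical stand-in for a DALI pipeline; only mimics build()."""

    def __init__(self, device_id, build_seconds=0.2):
        self.device_id = device_id
        self.build_seconds = build_seconds
        self.built = False

    def build(self):
        # Simulates expensive setup, e.g. starting external-source worker processes.
        time.sleep(self.build_seconds)
        self.built = True


def build_pipelines_in_parallel(pipelines, max_workers=None):
    """Call build() on every pipeline concurrently, by default one thread per pipeline."""
    max_workers = max_workers or max(1, len(pipelines))
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # list() drains the iterator so an exception raised in any build() surfaces here.
        list(pool.map(lambda p: p.build(), pipelines))
    return pipelines


if __name__ == "__main__":
    pipes = [StandInPipeline(device_id=i) for i in range(8)]
    start = time.perf_counter()
    build_pipelines_in_parallel(pipes)
    print(f"built {len(pipes)} pipelines in {time.perf_counter() - start:.2f}s")
```

Built sequentially, the eight stand-ins would take about 1.6 s; built in parallel, they take roughly one `build_seconds`.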
This sped up initialization significantly. Similarly, I parallelized running the pipelines.
This also modestly increased my maximum throughput, although that is a simplification of my actual code: after each run I also do other things, like copying buffers into JAX memory. So I'm not sure whether this would apply to every plugin.
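Running the pipelines in parallel can be sketched the same way, assuming each built pipeline exposes a `run()` method returning that iteration's outputs, as DALI's `Pipeline.run()` does. `StandInPipeline` and `run_pipelines_in_parallel` are hypothetical. Any per-run post-processing, such as copying outputs into JAX arrays, would go inside the per-pipeline task and is not shown.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class StandInPipeline:
    """Hypothetical stand-in for a built DALI pipeline; only mimics run()."""

    def __init__(self, device_id, run_seconds=0.1):
        self.device_id = device_id
        self.run_seconds = run_seconds
        self.iteration = 0

    def run(self):
        # Simulates one iteration of loading and decoding on this pipeline's GPU.
        time.sleep(self.run_seconds)
        self.iteration += 1
        return (self.device_id, self.iteration)


def run_pipelines_in_parallel(pool, pipelines):
    """Run one iteration on every pipeline concurrently; results keep pipeline order."""
    return list(pool.map(lambda p: p.run(), pipelines))


if __name__ == "__main__":
    pipes = [StandInPipeline(device_id=i) for i in range(8)]
    # One pool is reused across iterations instead of being recreated per step.
    with ThreadPoolExecutor(max_workers=len(pipes)) as pool:
        for _ in range(3):
            outputs = run_pipelines_in_parallel(pool, pipes)
    print(outputs)
```

Each call waits for every pipeline before returning, so no pipeline is ever run by two threads at once.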
It would be nice if this could be upstreamed into DALI, especially the building part!
Describe your ideal solution
See above
Describe any alternatives you have considered
No response
Additional context
No response
Check for duplicates