ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[core][experimental] Support broadcast NCCL ops in accelerated DAG #45308

Open stephanie-wang opened 4 months ago

stephanie-wang commented 4 months ago

Description

When the same GPU tensor is sent to multiple readers, we should use ncclBroadcast under the hood to reduce transfer time.
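A rough, bandwidth-only model shows why this can help. It assumes the sender's outgoing link is the bottleneck, and that a pipelined (ring or chain) broadcast sends each byte out of the sender roughly once. Point-to-point sends push one full copy per reader. Latency terms and topology effects are ignored, and the function name `estimated_send_time` is hypothetical, not part of Ray or NCCL.

```python
def estimated_send_time(size_bytes: float, num_readers: int,
                        link_bandwidth: float, use_broadcast: bool) -> float:
    """Bandwidth-only estimate of how long the sender's link is busy.

    Hypothetical model: point-to-point sends push one full copy per reader
    through the sender's link, while a pipelined broadcast pushes roughly
    one copy regardless of how many readers there are.
    """
    if num_readers < 1:
        raise ValueError("need at least one reader")
    copies = 1 if use_broadcast else num_readers
    return copies * size_bytes / link_bandwidth


# Example: a 1 GiB tensor sent to 4 readers over a 100 GB/s link.
size = 1 << 30
p2p = estimated_send_time(size, 4, 100e9, use_broadcast=False)
bcast = estimated_send_time(size, 4, 100e9, use_broadcast=True)
print(f"p2p ~{p2p * 1e3:.1f} ms, broadcast ~{bcast * 1e3:.1f} ms")
```

In this model the point-to-point cost grows linearly with the number of readers, while the broadcast cost stays flat. Real gains depend on message size, interconnect, and NCCL's chosen algorithm.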

Use case

No response

rkooo567 commented 1 month ago

Marking this as P2 for now (until we need more performance features).

jeffreyjeffreywang commented 1 week ago

I'm interested in this task and would love to work on it.

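from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

# sender, receiver1, receiver2 are actor handles; shape and dtype describe the tensor.
with InputNode() as inp:
    dag = sender.send.bind(shape, dtype, inp)
    dag = dag.with_type_hint(TorchTensorType(shape, dtype))
    # Two different actors read the same tensor produced by sender.send.
    o1 = receiver1.recv.bind(dag)
    o2 = receiver2.recv.bind(dag)
    dag = MultiOutputNode([o1, o2])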

This should call ncclBroadcast under the hood as there are multiple receivers getting a tensor sent by a single sender. @stephanie-wang, is my understanding accurate?
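One way to find such cases at compile time might be to group the readers of each tensor-producing node and flag any node that is read by more than one distinct actor. The sketch below works on a plain edge list rather than Ray's internal DAG classes. The `Edge` type and `find_broadcast_candidates` function are hypothetical and not part of Ray.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Edge:
    """A data dependency: reader_actor consumes the output of writer_node."""
    writer_node: str
    writer_actor: str
    reader_actor: str
    is_gpu_tensor: bool  # hypothetical flag, e.g. set when a TorchTensorType hint is present


def find_broadcast_candidates(edges):
    """Return {writer_node: sorted reader actors} for GPU tensors read by 2+ other actors."""
    readers = defaultdict(set)
    for e in edges:
        # Only GPU tensor edges that cross actors would need an NCCL transfer.
        if e.is_gpu_tensor and e.reader_actor != e.writer_actor:
            readers[e.writer_node].add(e.reader_actor)
    return {node: sorted(actors) for node, actors in readers.items() if len(actors) > 1}


# Mirrors the DAG above: sender.send feeds receiver1.recv and receiver2.recv.
edges = [
    Edge("sender.send", "sender", "receiver1", True),
    Edge("sender.send", "sender", "receiver2", True),
]
print(find_broadcast_candidates(edges))  # {'sender.send': ['receiver1', 'receiver2']}
```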

rkooo567 commented 1 week ago

Actually, I'd like to hold off on this one until we have a clearer design. Broadcast doesn't always work well when the downstream tasks don't all run at the same time. For example:

with InputNode() as inp:
    dag = sender.send.bind(shape, dtype, inp)
    dag = dag.with_type_hint(TorchTensorType(shape, dtype))
    # receiver1 runs long_running before it reaches recv.
    dag2 = receiver1.long_running.bind(inp)
    o1 = receiver1.recv.bind(dag)
    o2 = receiver2.recv.bind(dag)
    dag = MultiOutputNode([o1, o2, dag2])

In this case, receiver1.recv starts only after receiver1.long_running finishes. Because a broadcast is a collective that every receiver must join, receiver2.recv could end up waiting on receiver1 as well, which differs from the current semantics. I think we need a more refined heuristic here instead of using broadcast in all cases.
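The timing difference can be illustrated with a toy model in which each actor runs its tasks one at a time, in DAG order. Under the current point-to-point semantics described above, each `recv` is modeled as finishing as soon as its own actor reaches it. Under a broadcast, modeled as a collective, no receiver finishes until every receiver has joined. This is a simplified, hypothetical model (it ignores transfer time and the sender's own ordering), not how Ray actually schedules tasks. The function name `recv_finish_times` is illustrative only.

```python
def recv_finish_times(schedules, durations, use_broadcast):
    """Toy model: when does each actor's "recv" task finish?

    schedules: {actor: [task names in execution order]}; each list must
        contain exactly one "recv" task.
    durations: {task name: seconds}; "recv" itself is treated as instantaneous.
    """
    # Time at which each actor reaches its recv, after running earlier tasks serially.
    start = {}
    for actor, tasks in schedules.items():
        idx = tasks.index("recv")
        start[actor] = sum((durations[t] for t in tasks[:idx]), 0.0)
    if use_broadcast:
        # Simplification: the collective completes only once every participant has entered it.
        done = max(start.values())
        return {actor: done for actor in start}
    # Point-to-point: each receiver is unblocked independently.
    return dict(start)


# receiver1 runs long_running before recv; receiver2 only runs recv.
schedules = {"receiver1": ["long_running", "recv"], "receiver2": ["recv"]}
durations = {"long_running": 10.0}
print(recv_finish_times(schedules, durations, use_broadcast=False))
# {'receiver1': 10.0, 'receiver2': 0.0}
print(recv_finish_times(schedules, durations, use_broadcast=True))
# {'receiver1': 10.0, 'receiver2': 10.0}
```

In this model, switching to broadcast delays receiver2.recv by the full duration of long_running, which is the change in semantics described in the comment.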

jeffreyjeffreywang commented 1 week ago

Got it, that makes a lot of sense. Do we have an estimated timeline for completing the design? Please let me know if there's anything I can help with!

rkooo567 commented 1 week ago

I think the follow-ups on multi-output refs (multiple ray.get calls and skipping deserialization) are good candidates!

rkooo567 commented 1 week ago

For this issue, if you'd like to take it, I think we need two follow-ups:

jeffreyjeffreywang commented 1 week ago

Sounds good, I'll get started on the deserialization problem and get back to this issue after doing a bit more research.