ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[core][adag] aDAGs with multiple outputs should allow getting them one at a time #46908

Closed: stephanie-wang closed this issue 4 days ago

stephanie-wang commented 1 month ago

Description

Right now, if you use MultiOutputNode to wrap a DAG's outputs, you get back a single CompiledDAGRef that points to a list of all of the outputs. That means that even if you only need one of the outputs, you have to get and deserialize all of them at the same time. This API is also a bit strange compared to Ray's usual API, which lets you get outputs individually if a task has multiple returns.

We should update the aDAG API to return one CompiledDAGRef per output, something like this:

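from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode

# a and b are actor handles created earlier from @ray.remote classes.
with InputNode() as inp:
    dag = MultiOutputNode([a.foo.bind(inp), b.foo.bind(inp)])

dag = dag.experimental_compile()
# Proposed: returns two refs, one per output in MultiOutputNode.
ref1, ref2 = dag.execute(1)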

Use case

One use case is in vLLM tensor parallelism. Since all shards return the same results, we only need to get results from one of the workers.
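To make the cost difference concrete, here is a toy Python model. It is not Ray code: `ListRef`, `OutputRef`, and `deserialize` are hypothetical names. It counts how many outputs get deserialized when the caller needs only one of them, as in the tensor-parallel case where every shard returns the same result.

```python
import pickle

# Toy model only: these classes are hypothetical stand-ins, not Ray's
# CompiledDAGRef. They count deserializations to compare one ref for
# all outputs against one ref per output.
DESERIALIZE_CALLS = 0


def deserialize(blob: bytes):
    global DESERIALIZE_CALLS
    DESERIALIZE_CALLS += 1
    return pickle.loads(blob)


class ListRef:
    """One ref for the whole MultiOutputNode: get() decodes every output."""

    def __init__(self, blobs):
        self._blobs = blobs

    def get(self):
        return [deserialize(b) for b in self._blobs]


class OutputRef:
    """One ref per output: get() decodes only this output."""

    def __init__(self, blob):
        self._blob = blob

    def get(self):
        return deserialize(self._blob)


# Pretend two tensor-parallel workers returned identical results.
blobs = [pickle.dumps([1.0, 2.0, 3.0]) for _ in range(2)]

first = ListRef(blobs).get()[0]  # decodes both outputs to use one
calls_with_list_ref = DESERIALIZE_CALLS

DESERIALIZE_CALLS = 0
ref1, ref2 = (OutputRef(b) for b in blobs)
first_again = ref1.get()  # ref2 is never decoded
calls_with_per_output_refs = DESERIALIZE_CALLS

print(calls_with_list_ref, calls_with_per_output_refs)  # 2 1
```

With one ref per output, the cost of fetching a single shard's result no longer scales with the number of workers.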

stephanie-wang commented 1 month ago

Marking this for the beta release since it requires an API change.

rkooo567 commented 1 month ago

Is https://github.com/ray-project/ray/issues/46909 necessary if we support this? Maybe users can just call ray.get(ref1) and delete ref2 to avoid deserializing it?

jeffreyjeffreywang commented 1 month ago

Hi @stephanie-wang, I'm currently working on resolving this issue. Here's my repro for it. Please let me know if my assumptions are correct. I also plan to handle the cases that involve asyncio. Thank you! 😄

import ray
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode
from ray.experimental.compiled_dag_ref import CompiledDAGRef

@ray.remote
class Foo:
    def __init__(self):
        self.value = 0

    def increment(self, num):
        self.value += num
        return self.value

@ray.remote
class Bar:
    def __init__(self):
        self.value = 0

    def decrement(self, num):
        self.value -= num
        return self.value

if __name__ == '__main__':
    foo = Foo.remote()
    bar = Bar.remote()

    with InputNode() as inp:
        dag = MultiOutputNode([foo.increment.bind(inp), bar.decrement.bind(inp)])

    dag = dag.experimental_compile()

    refs = dag.execute(1)
    assert isinstance(refs, list)
    for ref in refs:
        assert isinstance(ref, CompiledDAGRef)

    ref1, ref2 = dag.execute(1)

    assert ref1.get() == 2
    # A CompiledDAGRef can only be fetched once; a second .get() should raise.
    try:
        ref1.get()
    except ValueError:
        print('Got expected ValueError: .get() was already called on this CompiledDAGRef.')

    assert ref2.get() == -2
    dag.teardown()

stephanie-wang commented 1 month ago

Hey @jeffreyjeffreywang, thanks for the contribution! I'll assign this issue to you. Yes, your repro looks good to me!

stephanie-wang commented 1 month ago

Re @rkooo567's question (is #46909 necessary if we support this, or can users just call ray.get(ref1) and delete ref2 to avoid deserialization?):

Yes, exactly. Once we address this issue, #46909 should not require any further API change. Deleting ref2 will acquire the shared-memory buffer and then immediately release it, but right now I believe we also deserialize the value in between. Ideally we should skip that step and make a single C++ call that acquires and immediately releases the buffer. So #46909 is just an implementation optimization; no API change is needed.
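As a rough illustration of that optimization (not Ray's actual code; `SharedBuffer`, `ToyRef`, and their methods are hypothetical), a ref that is deleted without being read can acquire and immediately release its slot, so it never pays for deserialization. The sketch also mimics the once-only `.get()` behavior that the repro script checks for.

```python
import pickle


class SharedBuffer:
    """Hypothetical stand-in for a shared-memory channel slot."""

    def __init__(self, value):
        self.blob = pickle.dumps(value)
        self.acquired = 0
        self.released = 0

    def acquire(self) -> bytes:
        self.acquired += 1
        return self.blob

    def release(self) -> None:
        self.released += 1


class ToyRef:
    """Hypothetical ref: get() works once; deleting it unread skips decoding."""

    deserializations = 0

    def __init__(self, buf: SharedBuffer):
        self._buf = buf
        self._consumed = False

    def get(self):
        if self._consumed:
            raise ValueError("get() can only be called once on this ref")
        self._consumed = True
        blob = self._buf.acquire()
        ToyRef.deserializations += 1
        value = pickle.loads(blob)
        self._buf.release()
        return value

    def __del__(self):
        # Optimized path: release the slot without deserializing its contents.
        if not self._consumed:
            self._consumed = True
            self._buf.acquire()
            self._buf.release()


buf1, buf2 = SharedBuffer(2), SharedBuffer(-2)
ref1, ref2 = ToyRef(buf1), ToyRef(buf2)
print(ref1.get())  # 2
del ref2  # CPython's reference counting runs __del__ right here
print(buf2.acquired, buf2.released, ToyRef.deserializations)  # 1 1 1
```

Relying on `__del__` here leans on CPython's reference counting; a real implementation would release the buffer from its own ref-cleanup path rather than from a Python finalizer.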

stephanie-wang commented 1 month ago

@jeffreyjeffreywang how is this going? Anything we can help with?

jeffreyjeffreywang commented 1 month ago

Hi @stephanie-wang, thank you so much for asking! I was able to figure out the synchronous pieces and am now working on the asynchronous portion. I'll publish a PR as soon as the fix and additional unit tests are done. I noticed it's marked as P0. Is there a tight timeline for this issue?

rkooo567 commented 1 month ago

No! For aDAG tasks, P0 means we want to finish it by the end of September, so this is not on a tight deadline!

jeffreyjeffreywang commented 1 month ago

Here is a repro script for the asynchronous case, and I was able to fix it. Now moving on to unit tests. 😄

import asyncio
import time

import ray
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode
from ray.experimental.compiled_dag_ref import CompiledDAGFuture

@ray.remote
class AsyncFoo:
    def __init__(self):
        self.value = 0

    def increment(self, num):
        self.value += num
        time.sleep(1)
        return self.value

@ray.remote
class AsyncBar:
    def __init__(self):
        self.value = 0

    def decrement(self, num):
        self.value -= num
        time.sleep(1)
        return self.value

if __name__ == '__main__':
    foo = AsyncFoo.remote()
    bar = AsyncBar.remote()

    with InputNode() as inp:
        dag = MultiOutputNode([foo.increment.bind(inp), bar.decrement.bind(inp)])

    dag = dag.experimental_compile(enable_asyncio=True)
    loop = asyncio.get_event_loop()

    async def run():
        futs = await dag.execute_async(1)
        assert len(futs) == 2
        for fut in futs:
            assert isinstance(fut, CompiledDAGFuture)

        result = await futs[0]
        assert result == 1

        result = await futs[1]
        assert result == -1

        fut1 = await dag.execute_async(1)
        fut2 = await dag.execute_async(1)

        start = time.time()
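        result = await fut2[0]  # Waits for both submitted steps (both execute_async calls).
        end = time.time()

        assert result == 3
        assert end - start < 2.1  # Two ~1 s steps should finish in about 2 s.
        # A CompiledDAGFuture can only be awaited once; a second await should raise.
        try:
            await fut2[0]
        except ValueError:
            print("Got expected ValueError: fut2[0] was already awaited.")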

        start = time.time()
        result = await fut2[1]
        end = time.time()

        assert result == -3
        # This step's other output and the earlier step's outputs were already read and
        # buffered, so awaiting them shouldn't wait on the DAG again; the await time should be negligible.
        assert end - start < 0.1

        start = time.time()
        result = await fut1[0]
        end = time.time()

        assert result == 2
        assert end - start < 0.1

    loop.run_until_complete(run())
    dag.teardown()
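The timing assertions above rely on results being read in step order and buffered per output. Here is a minimal asyncio model of that behavior, under stated assumptions: `ToyDAG` and its methods are hypothetical, execution latency is simulated with `asyncio.sleep`, and unlike the real `execute_async` (which is itself awaited), the toy `execute_async` returns its per-output awaitables directly.

```python
import asyncio


class ToyDAG:
    """Hypothetical model, not Ray's implementation.

    Outputs are read off the (simulated) channels one step at a time, in
    order. Awaiting any output of step k reads every step up to k and
    buffers all of its outputs, so later awaits on a different output,
    or on an earlier step, are served from the buffer without waiting.
    """

    def __init__(self, num_outputs: int, step_seconds: float):
        self.num_outputs = num_outputs
        self.step_seconds = step_seconds
        self.submitted = 0  # steps submitted via execute_async()
        self.read = 0  # steps already read off the channels
        self.buffer = {}  # (step, output_index) -> value
        self.lock = asyncio.Lock()

    def _compute(self, step: int, index: int) -> int:
        # Mimics foo.increment(1) / bar.decrement(1) from the repro script.
        return step + 1 if index == 0 else -(step + 1)

    async def _read_next_step(self) -> None:
        await asyncio.sleep(self.step_seconds)  # simulated execution latency
        for i in range(self.num_outputs):
            self.buffer[(self.read, i)] = self._compute(self.read, i)
        self.read += 1

    async def _get(self, step: int, index: int) -> int:
        async with self.lock:  # only one reader drains the channels at a time
            while self.read <= step:
                await self._read_next_step()
        return self.buffer.pop((step, index))

    def execute_async(self):
        step = self.submitted
        self.submitted += 1
        # One awaitable per output, like the proposed per-output futures.
        return [self._get(step, i) for i in range(self.num_outputs)]


async def main():
    dag = ToyDAG(num_outputs=2, step_seconds=0.01)
    fut1 = dag.execute_async()
    fut2 = dag.execute_async()
    b = await fut2[0]  # reads and buffers steps 0 and 1
    reads_after_first_await = dag.read
    c = await fut2[1]  # served from the buffer
    a = await fut1[0]  # earlier step, also buffered
    d = await fut1[1]
    return a, d, b, c, reads_after_first_await, dag.read


print(asyncio.run(main()))  # (1, -1, 2, -2, 2, 2)
```

The read counter stays at 2 after the first await, which is the property the `end - start < 0.1` checks in the repro measure indirectly through wall-clock time.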