vxgmichel / aiostream

Generator-based operators for asynchronous iteration
http://aiostream.readthedocs.io
GNU General Public License v3.0

How to handle stream splitting #98

Open parkerbjur opened 6 months ago

parkerbjur commented 6 months ago

Hey! This isn't really an issue with the library but more of a request for help, so sorry if this isn't the right place. I'm trying to write a transcriber that takes in multiple different tracks and emits transcriptions based on the track, so the input events would look something like:

@dataclass
class TrackChunk():
  track: str
  chunk: bytes 

I've already made a pipable operator that can take a stream of bytes and output transcriptions, but I'm not sure of the best way to separate these streams and pass them to the appropriate transcriber. (I also have a NewTrack event, but I wasn't sure if that would be a distraction.)

This seems to be a pattern I run into a lot: I have a stream that needs to be separated into multiple streams, and a handler needs to be created whenever a new stream shows up. I would love some insight into how to handle this. I can also provide any additional information or clarification. Thanks!

@pipable_operator
async def transcribe(source: AsyncIterable[bytes], ws: WebSocketClientProtocol):
  asyncio.create_task(ws.send(source))

  return (
    stream.iterate(ws.reader())
    | pipe.map(json.loads)
    | pipe.filter(lambda data: data["type"] == "Results")
    | pipe.until(lambda data: data["is_final"])
    | pipe.map(lambda data: data["channel"]["alternatives"][0]["transcript"])
    | pipe.accumulate(operator.add, initializer="")
  )

@dataclass(frozen=True)
class TranscribeTrack():
  track: str
  transcriber_config: dict

@dataclass(frozen=True)
class TrackChunkAdded():
  track: str
  chunk: bytes

Events = Union[TranscribeTrack, TrackChunkAdded]
async def transcription_worker(source: AsyncIterable[Events]):
    events = stream.iterate(source)

vxgmichel commented 6 months ago

Hmm interesting problem!

It turns out that aiostream can handle streams of streams using the advanced operators, so that's probably what you're after here. The missing part is the ability to split a stream into a stream of streams, where items are forwarded depending on a given predicate.

Here's an example of a split operator that would do just that:

from typing import AsyncIterable, TypeVar, Callable, AsyncIterator

from aiostream import pipable_operator, stream, pipe
from aiostream.core import streamcontext, Streamer
from aiostream.aiter_utils import AsyncExitStack
from anyio import create_memory_object_stream, BrokenResourceError
from anyio.abc import ObjectSendStream

T = TypeVar("T")
K = TypeVar("K")

@pipable_operator
async def split(
    source: AsyncIterable[T], key_function: Callable[[T], K], max_buffer_size: float = 0
) -> AsyncIterator[tuple[K, Streamer[T]]]:
    # One memory object stream per key, used to forward items to their substream
    mapping: dict[K, ObjectSendStream[T]] = {}
    async with AsyncExitStack() as stack:
        async with streamcontext(source) as source:
            async for chunk in source:
                key = key_function(chunk)
                # First time this key shows up: create a new substream and yield it
                if key not in mapping:
                    sender, receiver = create_memory_object_stream[T](
                        max_buffer_size=max_buffer_size
                    )
                    mapping[key] = await stack.enter_async_context(sender)
                    yield key, streamcontext(receiver)
                # Forward the item to its substream, ignoring substreams
                # that have been closed downstream
                try:
                    await mapping[key].send(chunk)
                except BrokenResourceError:
                    pass

Note how it uses a key function to tell where each produced item belongs. Here's an example of this operator being used:

@pytest.mark.asyncio
async def test_split():
    def is_even(x: int) -> bool:
        return x % 2 == 0

    def split_stream(
        key: bool, stream: Streamer[int], *_
    ) -> AsyncIterable[int | list[int]]:
        match key:
            case True:
                return stream | pipe.accumulate(initializer=0) | pipe.takelast(1)
            case False:
                return stream[:3] | pipe.list() | pipe.takelast(1)

    xs = (
        stream.range(10, interval=0.1)
        | split.pipe(is_even)
        | pipe.starmap(split_stream)
        | pipe.flatten()
        | pipe.list()
    )
    assert await xs == [[1, 3, 5], 20]

Here the key function is simply whether the item is even or odd. Then starmap can be used to apply specific stream operations depending on this predicate. For the sake of this example, the even numbers are summed together while the first 3 odd numbers are gathered into a list. Then both results are produced using the advanced flatten operator.

Here's a diagram of the corresponding pipeline:

graph TD;
    A(range) --> B(split);
    B --> C(starmap);
    C --> D(accumulate);
    D --> E(takelast);
    C --> F(take);
    F --> G(list);
    G --> H(takelast);
    E --> I(flatten);
    H --> I;
    I --> J(list);  
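
Applied to your original track example, a rough (untested) sketch could look like this. Note that ws_for is just a placeholder for however you create or look up the websocket for a given track, and track_chunks stands for your source stream of TrackChunk items:

def transcribe_track(track: str, chunks: Streamer[TrackChunk], *_) -> AsyncIterable[str]:
    # Keep only the raw bytes of this track and feed them to your transcribe operator
    return (
        chunks
        | pipe.map(lambda item: item.chunk)
        | transcribe.pipe(ws_for(track))  # ws_for is a placeholder
    )

transcriptions = (
    track_chunks                           # source stream of TrackChunk items
    | split.pipe(lambda item: item.track)  # one substream per track
    | pipe.starmap(transcribe_track)
    | pipe.flatten()                       # merge the per-track transcriptions
)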

Does that correspond to your use case?

parkerbjur commented 6 months ago

Thanks! ya this is awesome.

reuben commented 1 month ago

What would it look like if I wanted to zip the resulting splits back together, rather than just flatten them as they come along? I've been trying to figure out how to achieve this but keep running into walls.

vxgmichel commented 1 month ago

@reuben

What would it look like if I wanted to zip the resulting splits back together, rather than just flatten them as they come along?

Interesting, I haven't thought about this use case before. As it turns out, this operator is not trivial to implement since it deals with a stream of streams. Here's a possible implementation:


from typing import AsyncIterable, AsyncIterator, Union, cast, TypeVar
from aiostream import pipable_operator
from aiostream.core import Streamer
from aiostream.manager import StreamerManager

T = TypeVar("T")

async def higherzip(
    source: AsyncIterable[AsyncIterable[T]],
    n: int = 2,
) -> AsyncIterator[tuple[T, ...]]:

    # Safe context
    async with StreamerManager[Union[AsyncIterable[T], T]]() as manager:

        main_streamer: Streamer[AsyncIterable[T] | T] | None = (
            await manager.enter_and_create_task(source)
        )
        substreamers: list[Streamer[AsyncIterable[T] | T]] = []
        current_item: dict[Streamer[AsyncIterable[T] | T], T] = {}

        # Loop over events
        while manager.tasks:

            # Wait for next event
            streamer, task = await manager.wait_single_event(list(manager.tasks))

            # Get result
            try:
                result = task.result()

            # End of stream
            except StopAsyncIteration:
                # Main streamer is finished
                if streamer is main_streamer:
                    return

                # A substreamer is finished
                else:
                    await manager.clean_streamer(streamer)
                    return

            # Process result
            else:

                # Setup a new source
                if streamer is main_streamer:

                    if len(substreamers) == n:
                        raise ValueError("Too many substreamers")

                    assert isinstance(result, AsyncIterable)
                    result = cast(AsyncIterable[T], result)
                    substreamers.append(await manager.enter_and_create_task(result))
                    manager.create_task(streamer)

                # Yield the result
                else:
                    result = cast(T, result)
                    assert streamer not in current_item
                    current_item[streamer] = result

                    if len(current_item) < n:
                        continue

                    item = tuple(cast(T, current_item[substreamer]) for substreamer in substreamers)
                    yield item
                    current_item = {}

                    # Re-schedule the substreamers
                    for substreamer in substreamers:
                        manager.create_task(substreamer)

I'm not even sure what to call it though, or how it would fit within the existing operators. For instance, should zip be built on top of it, the same way map is built on top of flatmap and concatmap? I have to think about it :thinking:

reuben commented 1 month ago

Thanks! It didn't occur to me to key the incoming results on the streamer itself. I slapped a @pipable_operator on your higherzip and it worked for me. In the meantime, I ended up cobbling together this modification of your split function:

import asyncio
from typing import Any, AsyncIterable, AsyncIterator, TypeVar

import aiostream
from aiostream import pipable_operator
from aiostream.aiter_utils import AsyncExitStack
from aiostream.core import Streamer, streamcontext
from anyio import BrokenResourceError, create_memory_object_stream
from anyio.abc import ObjectSendStream

T = TypeVar("T")
K = TypeVar("K")

@pipable_operator
async def tee(
    source: AsyncIterable[T], n: int, max_buffer_size: float = 0
) -> AsyncIterator[tuple[Streamer[T], ...]]:
    mapping: dict[int, ObjectSendStream[T]] = {}
    async with AsyncExitStack() as stack:
        async with streamcontext(source) as source:
            receivers = []
            for key in range(n):
                sender, receiver = create_memory_object_stream[T](
                    max_buffer_size=max_buffer_size
                )
                mapping[key] = await stack.enter_async_context(sender)
                receivers.append(streamcontext(receiver))
            yield tuple(receivers)  # single yield of multiple streams

            async for chunk in source:
                for key in range(n):
                    try:
                        await mapping[key].send(chunk)
                    except BrokenResourceError:
                        pass

async def main():
    def into_zipped_stream(
        stream_one: Streamer[int],
        stream_two: Streamer[int],
        *_: Any,
    ):
        return aiostream.stream.zip(stream_one | aiostream.pipe.spaceout(1), stream_two)

    async def merge_entries(entries: tuple[int, int], *_):
        entry_one, entry_two = entries
        return {
            "one": entry_one,
            "two": entry_two,
        }

    pipeline = aiostream.stream.iterate(range(10))
    pipeline = (
        pipeline
        | tee.pipe(2)
        | aiostream.pipe.starmap(into_zipped_stream)
        | aiostream.pipe.flatten()
        | aiostream.pipe.map(merge_entries)
    )

    async with pipeline.stream() as stream:
        async for item in stream:
            print(item)

asyncio.run(main())

This also works and I can zip the results in the map function with the existing zip function, but as soon as I try to do something like this:

    @pipable_operator
    async def unbatch(source: AsyncIterable[list[T]]) -> AsyncIterator[T]:
        async with streamcontext(source) as streamer:
            async for batch in streamer:
                for item in batch:
                    yield item

    def into_zipped_stream(
        stream_one: Streamer[int],
        stream_two: Streamer[int],
        *_: Any,
    ):
        return aiostream.stream.zip(
            stream_one | aiostream.pipe.chunks(8) | unbatch,
            stream_two | aiostream.pipe.chunks(4) | unbatch
        )

…the pipeline hangs. If the batch sizes match in the two sub-streams, it doesn't hang. If I batch before the tee and unbatch after the merge, it also doesn't hang. But neither of those options allows me to use different batch sizes.

What I'm trying to do is to split my stream into many, then each sub-stream gets batched and processed by different subprocessors (with different scaling characteristics), then finally the results are zipped back together into a full entry containing all results from all subprocessors. I didn't include the processing function in the snippet above but you can imagine an aiostream.pipe.map call between chunks and unbatch on each sub-stream.