Open parkerbjur opened 6 months ago
Hmm, interesting problem!
It turns out that aiostream can handle streams of streams using the advanced operators, so that's probably what you're after here. The missing part is the ability to split a stream into a stream of streams where items are forwarded depending on a given predicate. Here's an example of a `split` operator that would do just that:
```python
from typing import AsyncIterable, TypeVar, Callable, AsyncIterator

from aiostream import pipable_operator, stream, pipe
from aiostream.core import streamcontext, Streamer
from aiostream.aiter_utils import AsyncExitStack
from anyio import create_memory_object_stream, BrokenResourceError
from anyio.abc import ObjectSendStream

T = TypeVar("T")
K = TypeVar("K")


@pipable_operator
async def split(
    source: AsyncIterable[T], key_function: Callable[[T], K], max_buffer_size: float = 0
) -> AsyncIterator[tuple[K, Streamer[T]]]:
    mapping: dict[K, ObjectSendStream[T]] = {}
    async with AsyncExitStack() as stack:
        async with streamcontext(source) as source:
            async for chunk in source:
                key = key_function(chunk)
                if key not in mapping:
                    sender, receiver = create_memory_object_stream[T](
                        max_buffer_size=max_buffer_size
                    )
                    mapping[key] = await stack.enter_async_context(sender)
                    yield key, streamcontext(receiver)
                try:
                    await mapping[key].send(chunk)
                except BrokenResourceError:
                    pass
```
Note how it uses a key function to tell where each produced item belongs. Here's an example of this operator being used:
```python
import pytest


@pytest.mark.asyncio
async def test_split():
    def is_even(x: int) -> bool:
        return x % 2 == 0

    def split_stream(
        key: bool, stream: Streamer[int], *_
    ) -> AsyncIterable[int | list[int]]:
        match key:
            case True:
                return stream | pipe.accumulate(initializer=0) | pipe.takelast(1)
            case False:
                return stream[:3] | pipe.list() | pipe.takelast(1)

    xs = (
        stream.range(10, interval=0.1)
        | split.pipe(is_even)
        | pipe.starmap(split_stream)
        | pipe.flatten()
        | pipe.list()
    )
    assert await xs == [[1, 3, 5], 20]
```
Here the key function is simply whether the item is even or odd. Then `starmap` can be used to apply specific stream operations depending on this predicate. For the sake of this example, the even numbers are summed together while the first 3 odd numbers are gathered into a list. Then both results are produced using the advanced `flatten` operator.
Here's a diagram of the corresponding pipeline:
```mermaid
graph TD;
  A(range) --> B(split);
  B --> C(starmap);
  C --> D(accumulate);
  D --> E(takelast);
  C --> F(take);
  F --> G(list);
  G --> H(takelast);
  E --> I(flatten);
  H --> I;
  I --> J(list);
```
Does that correspond to your use case?
Thanks! Yeah, this is awesome.
What would it look like if I wanted to zip the resulting splits back together, rather than just flatten them as they come along? I've been trying to figure out how to achieve this but keep running into walls.
@reuben

> What would it look like if I wanted to zip the resulting splits back together, rather than just flatten them as they come along?
Interesting, I haven't thought about this use case before. As it turns out, this operator is not trivial to implement since it deals with a stream of streams. Here's a possible implementation:
```python
from typing import AsyncIterable, AsyncIterator, Union, cast, TypeVar

from aiostream import pipable_operator
from aiostream.core import Streamer
from aiostream.manager import StreamerManager

T = TypeVar("T")


async def higherzip(
    source: AsyncIterable[AsyncIterable[T]],
    n: int = 2,
) -> AsyncIterator[tuple[T, ...]]:
    # Safe context
    async with StreamerManager[Union[AsyncIterable[T], T]]() as manager:
        main_streamer: Streamer[AsyncIterable[T] | T] | None = (
            await manager.enter_and_create_task(source)
        )
        substreamers: list[Streamer[AsyncIterable[T] | T]] = []
        current_item: dict[Streamer[AsyncIterable[T] | T], T] = {}

        # Loop over events
        while manager.tasks:

            # Wait for next event
            streamer, task = await manager.wait_single_event(list(manager.tasks))

            # Get result
            try:
                result = task.result()

            # End of stream
            except StopAsyncIteration:
                # Main streamer is finished
                if streamer is main_streamer:
                    return
                # A substreamer is finished
                else:
                    await manager.clean_streamer(streamer)
                    return

            # Process result
            else:
                # Setup a new source
                if streamer is main_streamer:
                    if len(substreamers) == n:
                        raise ValueError("Too many substreamers")
                    assert isinstance(result, AsyncIterable)
                    result = cast(AsyncIterable[T], result)
                    substreamers.append(await manager.enter_and_create_task(result))
                    manager.create_task(streamer)

                # Yield the result
                else:
                    result = cast(T, result)
                    assert streamer not in current_item
                    current_item[streamer] = result
                    if len(current_item) < n:
                        continue
                    item = tuple(cast(T, current_item[substreamer]) for substreamer in substreamers)
                    yield item
                    current_item = {}
                    # Re-schedule the substreamers
                    for substreamer in substreamers:
                        manager.create_task(substreamer)
```
I'm not even sure what to call it though, or how it would fit within the existing operators. For instance, should `zip` be built on top of it, the same way `map` is built on top of `flatmap` and `concatmap`? I have to think about it :thinking:
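For illustration, here is a rough sketch (not from the original comment, and only lightly checked) of how `higherzip` might be combined with the `split` operator above, assuming it gets the same `@pipable_operator` treatment as the other operators:

```python
# Illustrative sketch only: assumes the `split` operator and the `higherzip`
# implementation above are defined in the current module.
import asyncio

from aiostream import pipable_operator, pipe, stream

higherzip = pipable_operator(higherzip)  # equivalent to decorating the definition


async def example() -> None:
    xs = (
        stream.range(10, interval=0.1)
        | split.pipe(lambda x: x % 2 == 0)        # stream of (key, substream) pairs
        | pipe.starmap(lambda key, sub, *_: sub)  # keep only the substreams
        | higherzip.pipe(2)                       # pair up even and odd items
        | pipe.list()
    )
    print(await xs)  # pairs such as (0, 1), (2, 3), ...


asyncio.run(example())
```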
Thanks! It didn't occur to me to key the incoming results on the streamer itself. I slapped a `@pipable_operator` on your `higherzip` and it worked for me. In the meantime I ended up cobbling together this modification of your `split` function:
```python
import asyncio
from typing import Any, AsyncIterable, AsyncIterator, TypeVar

import aiostream
from aiostream import pipable_operator
from aiostream.aiter_utils import AsyncExitStack
from aiostream.core import Streamer, streamcontext
from anyio import BrokenResourceError, create_memory_object_stream
from anyio.abc import ObjectSendStream

T = TypeVar("T")
K = TypeVar("K")


@pipable_operator
async def tee(
    source: AsyncIterable[T], n: int, max_buffer_size: float = 0
) -> AsyncIterator[tuple[Streamer[T], ...]]:
    mapping: dict[int, ObjectSendStream[T]] = {}
    async with AsyncExitStack() as stack:
        async with streamcontext(source) as source:
            receivers = []
            for key in range(n):
                sender, receiver = create_memory_object_stream[T](
                    max_buffer_size=max_buffer_size
                )
                mapping[key] = await stack.enter_async_context(sender)
                receivers.append(streamcontext(receiver))
            yield tuple(receivers)  # single yield of multiple streams
            async for chunk in source:
                for key in range(n):
                    try:
                        await mapping[key].send(chunk)
                    except BrokenResourceError:
                        pass


async def main():
    def into_zipped_stream(
        stream_one: Streamer[int],
        stream_two: Streamer[int],
        *_: Any,
    ):
        return aiostream.stream.zip(
            stream_one | aiostream.pipe.spaceout(1), stream_two
        )

    async def merge_entries(entries: tuple[int, int], *_):
        entry_one, entry_two = entries
        return {
            "one": entry_one,
            "two": entry_two,
        }

    pipeline = aiostream.stream.iterate(range(10))
    pipeline = (
        pipeline
        | tee.pipe(2)
        | aiostream.pipe.starmap(into_zipped_stream)
        | aiostream.pipe.flatten()
        | aiostream.pipe.map(merge_entries)
    )
    async with pipeline.stream() as stream:
        async for item in stream:
            print(item)


asyncio.run(main())
```
This also works, and I can zip the results in the map function with the existing `zip` function, but as soon as I try to do something like this:
```python
@pipable_operator
async def unbatch(source: AsyncIterable[list[T]]) -> AsyncIterator[T]:
    async with streamcontext(source) as streamer:
        async for batch in streamer:
            for item in batch:
                yield item


def into_zipped_stream(
    stream_one: Streamer[int],
    stream_two: Streamer[int],
    *_: Any,
):
    return aiostream.stream.zip(
        stream_one | aiostream.pipe.chunks(8) | unbatch,
        stream_two | aiostream.pipe.chunks(4) | unbatch,
    )
```
…the pipeline hangs. If the batch sizes match in the two sub-streams, it doesn't hang. If I batch before the `tee` and unbatch after the merge, it also doesn't hang. But neither of those options allows me to use different batch sizes.

What I'm trying to do is split my stream into many, then each sub-stream gets batched and processed by a different subprocessor (with different scaling characteristics), and finally the results are zipped back together into a full entry containing all results from all subprocessors. I didn't include the processing functions in the snippet above, but you can imagine an `aiostream.pipe.map` call between `chunks` and `unbatch` on each sub-stream, as in the sketch below.
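For reference, a sketch (not from the original comment) of that kind of per-branch processing, with hypothetical `process_batch_*` functions standing in for the real subprocessors and matching chunk sizes on both branches, i.e. the configuration that does not hang:

```python
# Sketch only: assumes the imports plus the `tee` and `unbatch` operators from
# the snippets above. The process_batch_* functions are hypothetical stand-ins
# for the real subprocessors.
from typing import Any

import aiostream
from aiostream.core import Streamer


async def process_batch_one(batch: list[int], *_: Any) -> list[int]:
    return [x * 10 for x in batch]  # placeholder subprocessor #1


async def process_batch_two(batch: list[int], *_: Any) -> list[int]:
    return [x + 1 for x in batch]  # placeholder subprocessor #2


def into_zipped_stream(
    stream_one: Streamer[int],
    stream_two: Streamer[int],
    *_: Any,
):
    return aiostream.stream.zip(
        stream_one | aiostream.pipe.chunks(4) | aiostream.pipe.map(process_batch_one) | unbatch,
        stream_two | aiostream.pipe.chunks(4) | aiostream.pipe.map(process_batch_two) | unbatch,
    )
```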
Hey! This isn't really an issue with the library but more of a request for help, so sorry if this isn't the right place. I am trying to write a transcriber that will take in multiple different tracks and emit transcriptions based on the track, so the input events would look something like:

I've already made a pipable operator that can take a stream of bytes and output transcriptions, but I'm not sure of the best way to separate these streams and pass them to the appropriate transcriber. (I also have a NewTrack event but wasn't sure if that would be a distraction.)

This seems to be a pattern I run into a lot, where I have a stream that needs to be separated into multiple streams and a handler needs to be created when there is a new stream. Would love some insight into how to handle this. I can also provide any additional information or clarification. Thanks!
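A minimal sketch of how the `split` operator from the reply above could be applied to this use case; the `TrackAudio` event type and the placeholder `transcribe` operator are hypothetical stand-ins, not part of the original question:

```python
# Hypothetical sketch: route audio events to a per-track transcriber using the
# `split` operator defined in the reply above (assumed to be in scope).
from dataclasses import dataclass
from typing import AsyncIterable, AsyncIterator

from aiostream import pipable_operator, pipe, stream
from aiostream.core import Streamer, streamcontext


@dataclass
class TrackAudio:
    track_id: str
    data: bytes


@pipable_operator
async def transcribe(source: AsyncIterable[TrackAudio]) -> AsyncIterator[str]:
    # Placeholder transcriber: emit one "transcription" per audio chunk.
    async with streamcontext(source) as streamer:
        async for event in streamer:
            yield f"{event.track_id}: {len(event.data)} bytes"


def handle_track(track_id: str, substream: Streamer[TrackAudio], *_) -> AsyncIterable[str]:
    # One handler per track: each new track gets its own transcription pipeline.
    return substream | transcribe.pipe()


async def run(events: AsyncIterable[TrackAudio]) -> None:
    pipeline = (
        stream.iterate(events)
        | split.pipe(lambda event: event.track_id)  # one substream per track
        | pipe.starmap(handle_track)
        | pipe.flatten()
    )
    async with pipeline.stream() as transcriptions:
        async for text in transcriptions:
            print(text)
```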