sysid / sse-starlette

BSD 3-Clause "New" or "Revised" License

Custom async generators support #42

Closed: MaksimZayats closed this issue 1 year ago

MaksimZayats commented 1 year ago

Hi!

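The EventSourceResponse class does not work with custom async iterators, i.e. classes that implement __aiter__ and __anext__ (like the Stream class in the example below), because the content check linked below only accepts async generators.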
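For illustration, `inspect.isasyncgen` only returns True for objects produced by an `async def` function that contains `yield`. A class implementing `__aiter__`/`__anext__` is an async iterator but not an async generator, whereas both pass an `isinstance(..., AsyncIterable)` check. A minimal standalone sketch; `Ticker` and `ticker_gen` are hypothetical stand-ins for `Stream`, not part of sse-starlette:

```python
import inspect
from collections.abc import AsyncIterable


class Ticker:
    """Hypothetical class-based async iterator, similar in shape to Stream."""

    def __init__(self, n: int) -> None:
        self._i = 0
        self._n = n

    def __aiter__(self) -> "Ticker":
        return self

    async def __anext__(self) -> int:
        if self._i >= self._n:
            raise StopAsyncIteration
        self._i += 1
        return self._i


async def ticker_gen(n: int):
    # An async generator function: calling it returns an async generator object.
    for i in range(1, n + 1):
        yield i


cls_based = Ticker(3)
gen_based = ticker_gen(3)

print(inspect.isasyncgen(cls_based))         # False
print(inspect.isasyncgen(gen_based))         # True
print(isinstance(cls_based, AsyncIterable))  # True
print(isinstance(gen_based, AsyncIterable))  # True
```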

This can be fixed by adding one more check here: https://github.com/sysid/sse-starlette/blob/0b863dc1b60a3d0ca6f26790b5b907329f7161ff/sse_starlette/sse.py#L155

Like this:

if inspect.isasyncgen(content) or isinstance(content, AsyncIterable):

Or even like this:

if isinstance(content, AsyncIterable):
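Because every async generator also implements `__aiter__`, the shorter `isinstance(content, AsyncIterable)` check already covers the `inspect.isasyncgen` case. A sketch of the resulting dispatch, assuming plain synchronous iterables should still be accepted and wrapped; `ensure_async_iterable` is a hypothetical helper, not sse-starlette's actual code. For brevity it drains sync iterables on the event loop; a real implementation would more likely run them in a thread pool (Starlette provides `starlette.concurrency.iterate_in_threadpool` for that).

```python
import asyncio
from collections.abc import AsyncIterable, AsyncIterator, Iterable
from typing import Any, Union


async def _iterate_sync(items: Iterable[Any]) -> AsyncIterator[Any]:
    # Simplification: iterates the sync iterable directly on the event loop.
    # Blocking iterables should instead be consumed in a worker thread.
    for item in items:
        yield item


def ensure_async_iterable(
    content: Union[AsyncIterable[Any], Iterable[Any]],
) -> AsyncIterable[Any]:
    """Hypothetical helper: pass async-iterable content through unchanged."""
    # Matches async generators *and* class-based async iterators like Stream.
    if isinstance(content, AsyncIterable):
        return content
    return _iterate_sync(content)


async def collect(content: Union[AsyncIterable[Any], Iterable[Any]]) -> list:
    return [item async for item in ensure_async_iterable(content)]


async def agen():
    yield "a"
    yield "b"


print(asyncio.run(collect(agen())))      # ['a', 'b']
print(asyncio.run(collect(["x", "y"])))  # ['x', 'y']
```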

With either change, the code below works.

The code:

import asyncio

from fastapi import FastAPI, Depends
from sse_starlette import EventSourceResponse, ServerSentEvent
from starlette import status

class Stream:
    def __init__(self) -> None:
        self._queue = asyncio.Queue[ServerSentEvent]()

    def __aiter__(self) -> "Stream":
        return self

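    async def __anext__(self) -> ServerSentEvent:
        # Wait for the next queued event. This never raises StopAsyncIteration,
        # so the stream never ends on its own.
        return await self._queue.get()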

    async def asend(self, value: ServerSentEvent) -> None:
        await self._queue.put(value)

app = FastAPI()

# One shared Stream instance: every endpoint that depends on Stream receives it.
_stream = Stream()
app.dependency_overrides[Stream] = lambda: _stream

@app.get("/sse")
async def sse(stream: Stream = Depends()) -> EventSourceResponse:
    return EventSourceResponse(stream)

@app.post("/message", status_code=status.HTTP_201_CREATED)
async def send_message(message: str, stream: Stream = Depends()) -> None:
    await stream.asend(
        ServerSentEvent(data=message)
    )

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8080)
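To see the producer/consumer hand-off without running a server, here is a standalone sketch of the same Stream pattern. A small dataclass `Event` stands in for `ServerSentEvent` (an assumption made to keep the example dependency-free). Note that a single `asyncio.Queue` hands each item to exactly one consumer, so with several connected clients each message would reach only one of them; broadcasting would need something like one queue per subscriber.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Event:
    # Stand-in for sse_starlette.ServerSentEvent in this sketch.
    data: str


class Stream:
    def __init__(self) -> None:
        self._queue: "asyncio.Queue[Event]" = asyncio.Queue()

    def __aiter__(self) -> "Stream":
        return self

    async def __anext__(self) -> Event:
        # Never raises StopAsyncIteration: the stream stays open indefinitely.
        return await self._queue.get()

    async def asend(self, value: Event) -> None:
        await self._queue.put(value)


async def main() -> list:
    stream = Stream()
    received = []

    async def consumer() -> None:
        async for event in stream:
            received.append(event.data)
            if len(received) == 3:
                break  # stop manually, since the stream itself never ends

    task = asyncio.create_task(consumer())
    for text in ("one", "two", "three"):
        await stream.asend(Event(data=text))
    await task
    return received


print(asyncio.run(main()))  # ['one', 'two', 'three']
```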

sysid commented 1 year ago

Merged. Thank you!