sysid / sse-starlette

BSD 3-Clause "New" or "Revised" License

asyncio.InvalidStateError on async generator when connection is closed with FastAPI #48

Closed. DurandA closed this issue 1 year ago.

DurandA commented 1 year ago

I am using a simple publish/subscribe pattern with FastAPI to broadcast data to clients over SSE:

```python
import asyncio
from fastapi import FastAPI, Request
from sse_starlette.sse import EventSourceResponse

class PubSub:
    def __init__(self):
        # Every subscriber awaits this same Future for the next message
        self.waiter = asyncio.Future()

    def publish(self, value):
        # Resolve the current Future with (value, next Future) and install the next one
        waiter, self.waiter = self.waiter, asyncio.Future()
        # set_result() raises InvalidStateError if waiter is already done or cancelled
        waiter.set_result((value, self.waiter))

    async def subscribe(self):
        waiter = self.waiter
        while True:
            # Awaits the shared Future directly: cancelling the task while it
            # waits here also cancels the shared Future
            value, waiter = await waiter
            yield value

    __aiter__ = subscribe

pubsub = PubSub()

async def ticker(pubsub):
    counter = 0
    while True:
        pubsub.publish(counter)
        counter += 1
        await asyncio.sleep(1)

app = FastAPI()

@app.on_event("startup")
async def on_startup():
    # Keep a reference so the background task is not garbage-collected
    app.state.ticker = asyncio.create_task(ticker(pubsub), name='my_task')

@app.get('/stream')
async def message_stream(request: Request):
    async def event_publisher():
        try:
            while True:
                async for event in pubsub:
                    yield dict(data=event)
        except asyncio.CancelledError as e:
            print(f"Disconnected from client (via refresh/close) {request.client}")
            # Do any other cleanup, if any
            raise e
    return EventSourceResponse(event_publisher())
```

However, the task "my_task" is somehow killed as soon as the first client disconnects:

```
Task exception was never retrieved
future: <Task finished name='my_task' coro=<ticker() done, defined at /home/duranda/devel/fastapi-pubsub/main.py:51> exception=InvalidStateError('invalid state')>
Traceback (most recent call last):
  File "/home/duranda/devel/fastapi-pubsub/main.py", line 54, in ticker
    pubsub.publish(counter)
  File "/home/duranda/devel/fastapi-pubsub/main.py", line 38, in publish
    waiter.set_result((value, self.waiter))
asyncio.exceptions.InvalidStateError: invalid state
```

I also tried other patterns, such as aioreactive's AsyncIteratorObserver, with the same result: the task feeding the async iterator ends up failing with an InvalidStateError.
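A minimal sketch of the same failure without FastAPI or sse-starlette (not from the original thread; the names shared and subscriber are hypothetical). It assumes, as the diagnosis later in this thread suggests, that a client disconnect amounts to cancelling the task that is waiting on the shared Future. The task awaits the Future directly, gets cancelled, and the following set_result() fails:

```python
import asyncio


async def main():
    loop = asyncio.get_running_loop()
    shared = loop.create_future()  # plays the role of PubSub.waiter

    async def subscriber():
        return await shared  # awaits the shared Future directly, like subscribe()

    task = asyncio.create_task(subscriber())
    await asyncio.sleep(0)  # let the subscriber start waiting on `shared`

    task.cancel()  # roughly what a client disconnect does to the generator's task
    await asyncio.gather(task, return_exceptions=True)
    print(shared.cancelled())  # True: the cancellation reached the shared Future

    try:
        shared.set_result(42)  # what the next publish() call does
    except asyncio.InvalidStateError as exc:
        return exc
    return None


err = asyncio.run(main())
print(type(err).__name__)  # InvalidStateError
```

Cancelling a task that is suspended on a Future cancels that Future as well. Because every subscriber shares the same Future, one disconnecting client poisons it for the publisher and for all other subscribers.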

sysid commented 1 year ago

@DurandA I'm not sure I understand your post correctly, but I don't see a direct connection to sse-starlette. If you can be a bit more concrete, please feel free to reopen.

DurandA commented 1 year ago

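The issue was caused by EventSourceResponse cancelling the task that iterates the async generator when the client disconnects. Because subscribe() awaits the shared asyncio.Future directly, cancelling that task also cancels the Future itself. The next publish() then calls set_result() on a cancelled Future, which raises InvalidStateError and kills the ticker task.

The fix was to "shield" the await on the iterator as follows: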

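```python
@app.get('/stream')
async def message_stream(request: Request):
    async def event_publisher():
        events = pubsub.__aiter__()  # named to avoid shadowing the built-in aiter()
        try:
            while True:
                # Run __anext__() in its own task and shield it, so cancelling
                # this generator does not propagate into the shared Future
                # that PubSub.subscribe() is awaiting
                task = asyncio.create_task(events.__anext__())
                event = await asyncio.shield(task)
                yield dict(data=event)
        except asyncio.CancelledError:
            print(f"Disconnected from client (via refresh/close) {request.client}")
            # Do any other cleanup, if any
            raise
    return EventSourceResponse(event_publisher())
```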
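An alternative placement for the same idea, not proposed in the thread: shield the await inside PubSub.subscribe() itself, so that no consumer can cancel the shared Future. The sketch below assumes the PubSub class from the original post, creating Futures with get_running_loop().create_future() instead of asyncio.Future(). It uses the hypothetical names first_value, doomed and survivor to simulate one client disconnecting while another stays connected:

```python
import asyncio


class PubSub:
    def __init__(self):
        self.waiter = asyncio.get_running_loop().create_future()

    def publish(self, value):
        loop = asyncio.get_running_loop()
        waiter, self.waiter = self.waiter, loop.create_future()
        waiter.set_result((value, self.waiter))

    async def subscribe(self):
        waiter = self.waiter
        while True:
            # shield() gives each subscriber its own outer Future; cancelling a
            # subscriber cancels only that wrapper, never the shared Future
            value, waiter = await asyncio.shield(waiter)
            yield value

    __aiter__ = subscribe


async def main():
    pubsub = PubSub()

    async def first_value():  # hypothetical consumer: returns the first event
        async for value in pubsub:
            return value

    doomed = asyncio.create_task(first_value())    # client that disconnects
    survivor = asyncio.create_task(first_value())  # client that stays connected
    await asyncio.sleep(0)  # both are now waiting on the same shared Future

    doomed.cancel()
    await asyncio.gather(doomed, return_exceptions=True)

    pubsub.publish("still works")  # without shield() this raises InvalidStateError
    return await survivor


result = asyncio.run(main())
print(result)  # still works
```

With the shield inside subscribe(), every consumer is protected, and a cancelled subscriber simply stops waiting. In the create_task() plus shield() version from the thread, the shielded inner task keeps waiting until the next publish() before it finishes.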

I suspect the task is cancelled here: https://github.com/sysid/sse-starlette/blob/30ef55c08a8b1512b625752be136a1ea67df6030/sse_starlette/sse.py#L229-L236
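The following plain-asyncio approximation illustrates why the generator sees a CancelledError at all. It models the "stream until the client disconnects" pattern: the stream and a disconnect watcher run concurrently, and whichever finishes first causes the other to be cancelled. This is an assumption-laden stand-in, not sse-starlette's actual code at the linked lines. The names run_until_first_done, stream and wait_for_disconnect are hypothetical, and the disconnect is simulated with a timer:

```python
import asyncio


async def run_until_first_done(*coros):
    """Run coroutines concurrently; when any finishes, cancel the rest.

    Hypothetical helper approximating the streaming pattern; the real
    sse-starlette implementation may be structured differently.
    """
    tasks = [asyncio.create_task(c) for c in coros]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()  # in this sketch, the cancellation the stream coroutine sees
    await asyncio.gather(*pending, return_exceptions=True)
    return done


async def main():
    log = []

    async def stream():
        try:
            while True:
                await asyncio.sleep(0.01)  # stands in for awaiting the next event
                log.append("sent")
        except asyncio.CancelledError:
            log.append("stream cancelled")
            raise

    async def wait_for_disconnect():
        await asyncio.sleep(0.035)  # simulated: the client goes away here

    await run_until_first_done(stream(), wait_for_disconnect())
    return log


log = asyncio.run(main())
print(log[-1])  # stream cancelled
```

If stream() were suspended on a shared Future instead of its own sleep() timer, that cancellation would propagate into the shared Future, which is exactly the failure described in this issue.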

sysid commented 1 year ago

Thanks for sharing your experience. I am glad that you could find a solution.

Is there anything that can be improved on sse-starlette's side?