DontPanicO / fastapi-distributed-websocket

A library to implement WebSockets for distributed systems, based on FastAPI.
MIT License
54 stars · 10 forks

Update `WebSocketProxy.__call__` to use `asyncio.TaskGroup` #16

Open · DontPanicO opened this issue 1 year ago

DontPanicO commented 1 year ago

Feature or Enhancement

Move from `asyncio.gather` to `asyncio.TaskGroup` in `WebSocketProxy.__call__`.

Pitch

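`asyncio.TaskGroup` (available only in Python >= 3.11) has more structured cancellation logic: if any task in the group raises, the remaining tasks are cancelled before the error is re-raised, whereas `asyncio.gather` (with the default `return_exceptions=False`) propagates the first exception but leaves the other tasks running. The Python docs describe `TaskGroup` as the newer alternative to `gather` with stronger safety guarantees, so it should be preferred whenever there is no specific reason to choose one over the other.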

Currently, the implementation is:

```python
import asyncio

import websockets
from fastapi import WebSocket


async def _forward(
    client: WebSocket, target: websockets.WebSocketClientProtocol
) -> None:
    ...

async def _reverse(
    client: WebSocket, target: websockets.WebSocketClientProtocol
) -> None:
    ...

class WebSocketProxy:
    ...

    async def __call__(self) -> None:
        async with websockets.connect(self._server_endpoint) as target:
            self._forward_task = asyncio.create_task(
                _forward(self._client, target)
            )
            self._reverse_task = asyncio.create_task(
                _reverse(self._client, target)
            )
            # Waits for both tasks; the first error is re-raised,
            # but the other task is not cancelled.
            await asyncio.gather(self._forward_task, self._reverse_task)
```
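With `gather` as written above, if `_reverse` raises, the error propagates out of `__call__` (closing `target`), but nothing in `__call__` cancels `_forward`. If interpreters older than 3.11 had to keep working, TaskGroup-like cleanup can be approximated by hand with `asyncio.wait(..., return_when=asyncio.FIRST_EXCEPTION)`. The sketch below assumes that approach; `run_both` and the coroutines inside `demo` are hypothetical, not part of the library.

```python
import asyncio
from typing import Any, Coroutine


async def run_both(*coros: Coroutine[Any, Any, None]) -> None:
    """Run coroutines concurrently; on the first failure, cancel the rest
    and re-raise it (a hand-rolled approximation of asyncio.TaskGroup)."""
    tasks = [asyncio.create_task(c) for c in coros]
    try:
        # Returns as soon as any task raises, or when all have finished.
        await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
    finally:
        # Cancel whatever is still running (also when we are cancelled
        # ourselves) and wait for it, so no task outlives this call.
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
    # Re-raise the first real failure; TaskGroup would instead raise an
    # ExceptionGroup holding every failure.
    for t in tasks:
        if not t.cancelled() and t.exception() is not None:
            raise t.exception()


async def demo() -> list[str]:
    events: list[str] = []

    async def forward() -> None:  # hypothetical stand-in for _forward
        try:
            await asyncio.sleep(10)
        except asyncio.CancelledError:
            events.append("forward cancelled")
            raise

    async def reverse() -> None:  # hypothetical stand-in for _reverse
        await asyncio.sleep(0.01)
        raise ConnectionError("target closed")

    try:
        await run_both(forward(), reverse())
    except ConnectionError as exc:
        events.append(f"caught: {exc}")
    return events


print(asyncio.run(demo()))  # ['forward cancelled', 'caught: target closed']
```

The amount of bookkeeping needed here (cancel, wait for the cancellations, pick the error to re-raise) is exactly what `TaskGroup` does for us.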

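With `asyncio.TaskGroup` it would look like:

```python
class WebSocketProxy:
    ...

    async def __call__(self) -> None:
        # Open the target connection first, so it outlives both tasks.
        async with websockets.connect(self._server_endpoint) as target:
            # On exit the group waits for both tasks; if one fails,
            # the other is cancelled and the error is re-raised.
            async with asyncio.TaskGroup() as tg:
                self._forward_task = tg.create_task(
                    _forward(self._client, target)
                )
                self._reverse_task = tg.create_task(
                    _reverse(self._client, target)
                )
```

Nesting the task group inside `websockets.connect` ensures that if any failure with `target` occurs, the sibling child task is properly cancelled and the error propagates to the parent; only then does `websockets.connect` close the connection. The order matters: if `websockets.connect` were entered inside the task group instead, its block would exit right after the two tasks are scheduled, closing `target` before they get to use it. Note that `TaskGroup` re-raises failures wrapped in an `ExceptionGroup`, so callers of `__call__` that catch specific exceptions would need `except*`.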