redis / redis-py

Redis Python client
MIT License
12.65k stars 2.52k forks source link

Async redis subscription inside websocket doesn't shut down properly #2523

Open wholmen opened 1 year ago

wholmen commented 1 year ago

Version: 4.4.0

Platform: Python 3.11.0, ubuntu

Description:

I am using async redis to subscribe to a topic within a FastAPI websocket connection. It works fine, but I cannot make it shut down properly, i.e. when the server shuts down, `await get_message(...)` does not stop.

I have tried to use SIGTERM, but it seems the reader function doesn't yield to SIGTERM: the server logs

INFO:     Waiting for background tasks to complete. (CTRL+C to force quit)

before handling the SIGTERM signal.

Does anyone have a suggestion for where to start looking?

@router.websocket("/ws")
async def websocket_labtest(websocket: WebSocket):
    """Relay messages published on a Redis channel to a websocket client.

    Subscribes to the channel ``channel:<name>``, where ``<name>`` is the
    ``__name__`` of ``events.LabtestAddedEvent``. Each published message is
    JSON-decoded and forwarded with ``websocket.send_json``. Forwarding stops
    when the client disconnects, when a PubSub error occurs, or when the task
    is cancelled (e.g. on server shutdown). The PubSub connection and the
    Redis client are closed on every exit path.
    """
    await websocket.accept()

    async_redis_client: aioredis.client.Redis = aioredis.from_url(
        url=settings.redis_url, port=settings.redis_port, password=settings.redis_password, decode_responses=True
    )
    channelname: str = f"channel:{events.LabtestAddedEvent.__name__}"
    redis_channel = async_redis_client.pubsub()

    async def reader(channel: aioredis.client.PubSub):
        """Forward pubsub messages to the websocket until disconnect, error or cancellation."""
        while True:
            try:
                # Give get_message its own finite timeout rather than wrapping it in an
                # external timeout: it then returns None when nothing arrives, so the
                # loop keeps yielding to the event loop and stays cancellable
                # (redis-py issue #2523).
                message = await channel.get_message(ignore_subscribe_messages=True, timeout=1.0)
                if message is not None:
                    response = json.loads(message["data"])
                    await websocket.send_json(response)
                await asyncio.sleep(0.01)
            except WebSocketDisconnect:
                # The client is gone; without breaking, the loop would keep running forever.
                print("hit websocket disconnect")
                break
            except asyncio.TimeoutError:
                # Defensive: keep polling if a read timeout surfaces as an exception.
                pass
            except asyncio.CancelledError:
                # Re-raise so cancellation (e.g. server shutdown) actually propagates;
                # swallowing it is what kept the server waiting on this task.
                print("Cancelled")
                raise
            except aioredis.PubSubError:
                print("Pubsub error")
                break

    try:
        async with redis_channel as p:
            await p.subscribe(channelname)
            await reader(p)
            await p.unsubscribe(channelname)
    finally:
        # Runs on normal exit, errors and cancellation alike: release the PubSub
        # connection and the client's connection pool created by from_url().
        try:
            await redis_channel.close()
        finally:
            await async_redis_client.close()
Andrew-Chen-Wang commented 1 year ago

The problem is the blocking read in `parse_response`. The current workaround is to pass a small `timeout` to `get_message`.

Still not working. Does anyone have a solution? This seems to make PubSub completely unusable if `asyncio.CancelledError` isn't raised.

tomer555 commented 1 year ago

+1

Andrew-Chen-Wang commented 1 year ago

This has been working for me. It is copy-pasted from a project, so it references project-specific names such as `readers`; take what you need:

import asyncio
from collections.abc import Callable, Coroutine
from typing import Literal

from redis.asyncio.client import PubSub

from app.utils.redis.subscriber.typing import ChatReadersT, ReadersT, ReaderT

async def reader(channel: PubSub, readers: ReadersT):
    """Forward PubSub messages to the readers registered for their channel.

    Loops until ``RedisConnectionError`` is raised or the task is cancelled.
    For each message:

    * The part of the channel name after the first ``:`` is used as the key
      into the mapping returned by ``readers.get()``. The key is converted to
      ``int`` when it contains no further ``:``.
    * The payload is decoded as JSON and must contain a ``data`` field; an
      optional ``id`` field is removed.
    * The ``data`` value is queued on every registered reader whose key differs
      from ``id`` (presumably so a sender does not receive its own message).

    NOTE(review): ``orjson``, ``RedisConnectionError`` and ``ChannelMessage``
    are not imported in this snippet; they must come from the surrounding project.
    """
    try:
        while True:
            # A finite timeout stops get_message from blocking indefinitely, so
            # the loop stays responsive to cancellation.
            # https://github.com/redis/redis-py/issues/2523
            message: ChannelMessage = await channel.get_message(
                ignore_subscribe_messages=True, timeout=0.5
            )
            if message is None:
                continue
            _channel = message["channel"].decode().split(":", 1)[1]
            if ":" not in _channel:
                _channel = int(_channel)
            wsr = readers.get().get(_channel, {}).items()
            data: dict = orjson.loads(message["data"])
            reader_id = data.pop("id", None)
            data = data["data"]
            # Plain loop: this is done for its side effect, not to build a list.
            for k, w in wsr:
                if k != reader_id:
                    w.messages.put_nowait(data)
    except RedisConnectionError:
        # Presumably raised once the PubSub connection is closed during shutdown;
        # treated as a normal stop.
        pass

class BaseRedisConnection:
    """Own a PubSub subscription and the task that runs its reader coroutine.

    Subclasses implement ``create_reader``. The Redis client must be assigned
    to ``self.r`` before ``start()`` is called.
    """

    def __init__(
        self,
        *,
        channel: str,
        include_wildcard: bool = True,
        reader: Callable,
        subscription_type: Literal["psubscribe", "subscribe"] = "psubscribe",
    ):
        """
        Handler for PubSub connection

        :param channel: PubSub channel name
        :param include_wildcard: Whether to make PubSub channel name include a wildcard.
        Applicable only to when subscription_type is "psubscribe"
        :param reader: An infinite loop callable that reads from the PubSub channel
        :param subscription_type: the type of PubSub subscription to use
        """
        self.reader_task: asyncio.Task | None = None
        # Redis client used by start(); it is not set here, so the caller or a
        # subclass must assign it before start() is called.
        self.r = None
        self.pubsub: PubSub | None = None
        self.reader = reader
        self.subscription_type = subscription_type
        if subscription_type == "psubscribe" and include_wildcard:
            self.channel = f"{channel}:*"
        else:
            self.channel = channel

    async def create_reader(self, pubsub: PubSub) -> Coroutine:
        """Return the (not yet started) reader coroutine for *pubsub*; must be overridden."""
        raise NotImplementedError("create_reader() must be implemented")

    async def start(self):
        """Subscribe to the channel, then run the reader task until it finishes.

        :raises RuntimeError: if ``self.r`` has not been assigned a Redis client.
        """
        if self.r is None:
            raise RuntimeError("self.r must be set to a Redis client before start()")
        self.pubsub = self.r.pubsub()
        # subscription_type names the PubSub method to call ("subscribe" or "psubscribe").
        await getattr(self.pubsub, self.subscription_type)(self.channel)
        self.reader_task = asyncio.create_task(await self.create_reader(self.pubsub))
        await self.reader_task

    async def close(self):
        """Close the PubSub connection, then cancel the reader task.

        Safe to call before start() (or after a failed start()). The reader task
        is cancelled even if closing the PubSub connection raises.
        """
        try:
            if self.pubsub is not None:
                await self.pubsub.close()
        finally:
            if self.reader_task is not None:
                self.reader_task.cancel()

class RedisConnection(BaseRedisConnection):
    """PubSub connection whose reader is invoked as ``reader(pubsub, readers)``."""

    def __init__(
        self,
        *,
        channel: str,
        include_wildcard: bool = True,
        subscription_type: Literal["psubscribe", "subscribe"],
        reader: ReaderT,
        readers: ReadersT | ChatReadersT,
    ):
        # Stored so create_reader() can pass it to the reader callable.
        self.readers = readers
        super().__init__(
            reader=reader,
            subscription_type=subscription_type,
            include_wildcard=include_wildcard,
            channel=channel,
        )

    async def create_reader(self, pubsub: PubSub) -> Coroutine:
        """Build, without scheduling, the reader coroutine bound to *pubsub*."""
        reader_fn = self.reader
        return reader_fn(pubsub, self.readers)
nextmat commented 9 months ago

What worked for me is starting the reader coroutine as a background task with `asyncio.create_task`:

# Run the reader as a background task; replace ``...`` with the arguments your reader expects.
reader_task = asyncio.create_task(reader(...))

Then I have a lifespan method to make sure it gets shut down:

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Cancel the Redis reader task when the application shuts down.

    Startup does nothing. At shutdown, the module-level ``reader_task`` is
    cancelled and then awaited, so whatever cleanup the reader performs on
    cancellation finishes before shutdown continues.
    """
    import asyncio

    yield
    reader_task.cancel()
    # Wait for the task to actually finish. return_exceptions=True keeps the
    # task's own CancelledError (or any error it raised earlier) from aborting
    # shutdown, while cancellation of lifespan itself still propagates.
    await asyncio.gather(reader_task, return_exceptions=True)
Register it like this:

app = FastAPI(
    # ...other FastAPI options...
    lifespan=lifespan,  # cancels the Redis reader task on shutdown
)