taskiq-python / taskiq

Distributed task queue with full async support
MIT License
686 stars 44 forks source link

Question: How to cancel a running task #305

Open · realitix opened this issue 3 months ago

realitix commented 3 months ago

Hello, I have a special case to handle and I don't see how to do it. At a given moment, I need to know whether a task (whose ID I have) is actually in progress on a worker. Is that possible?

realitix commented 3 months ago

After further consideration, what I am looking for is the ability to stop an ongoing task. Is that possible?

s3rius commented 2 months ago

Currently there's no such functionality, but I really do want to define an interface for setting up such task interrupters.

I'm open to discussion on that.

realitix commented 2 months ago

I developed a custom receiver for that. If anyone wants to do it with Redis, here is the code:

import asyncio
import uuid
from typing import Any, AsyncGenerator, cast

import anyio
from loguru import logger
from redis.asyncio import Redis
from taskiq.abc.broker import AckableMessage
from taskiq.message import BrokerMessage, TaskiqMessage
from taskiq.receiver.receiver import QUEUE_DONE, Receiver
from taskiq_redis import ListQueueBroker

# ruff: noqa: ANN401,BLE001,C901
# pylint: skip-file

# Kwargs key carried by a cancellation message; its value is the id of the
# task to cancel. Written by CancellableListQueueBroker._prepare_message and
# looked up by CancellableReceiver.runner_canceller.
CANCELLER_KEY: str = "__cancel_task_id__"

class CancellableListQueueBroker(ListQueueBroker):
    """``ListQueueBroker`` that can also broadcast task-cancellation requests.

    Regular tasks still go through the Redis list managed by
    ``ListQueueBroker``. Cancellation requests are published on a separate
    Redis pub/sub channel (``queue_name_cancel``), so every worker subscribed
    to that channel receives each request and can check whether it is the one
    running the targeted task.
    """

    def __init__(
        self,
        *args: Any,
        queue_name_cancel: str = "taskiq_cancel",
        **kwargs: Any,
    ) -> None:
        """Create the broker.

        :param args: positional arguments forwarded to ``ListQueueBroker``.
        :param queue_name_cancel: Redis pub/sub channel used for cancellation
            requests.
        :param kwargs: keyword arguments forwarded to ``ListQueueBroker``.
        """
        super().__init__(*args, **kwargs)
        self.queue_name_cancel = queue_name_cancel

    async def listen_canceller(self) -> AsyncGenerator[bytes, None]:
        """Yield the raw payload of every message published on the cancel channel.

        Iterates until the consumer stops or the surrounding task is cancelled.
        Both the client and the pub/sub object are entered as async context
        managers, so the subscription and its connection are released when the
        generator is closed or cancelled.
        """
        async with Redis(connection_pool=self.connection_pool) as redis_conn:
            async with redis_conn.pubsub() as redis_pubsub_channel:
                await redis_pubsub_channel.subscribe(self.queue_name_cancel)
                async for message in redis_pubsub_channel.listen():
                    if not message:
                        continue
                    # Subscribe confirmations etc. have a type other than "message".
                    if message["type"] != "message":
                        logger.debug("Received non-message from redis: {}", message)
                        continue
                    yield message["data"]

    async def cancel_task(self, task_id: uuid.UUID | str) -> None:
        """Ask every listening worker to cancel the task identified by ``task_id``.

        This is fire-and-forget: only workers currently running the task act on
        it, and nothing is reported back.

        :param task_id: id of the task to cancel. A ``uuid.UUID`` is sent in
            its ``.hex`` form (previous behaviour); a ``str`` - e.g. the
            ``task_id`` of a kicked taskiq task - is sent unchanged.
        """
        taskiq_message: TaskiqMessage = self._prepare_message(task_id)
        broker_message: BrokerMessage = self.formatter.dumps(taskiq_message)
        async with Redis(connection_pool=self.connection_pool) as redis_conn:
            await redis_conn.publish(self.queue_name_cancel, broker_message.message)

    def _prepare_message(self, task_id: uuid.UUID | str) -> TaskiqMessage:
        """Build the ``TaskiqMessage`` that carries a cancellation request.

        The target id is stored under ``CANCELLER_KEY`` in ``kwargs``. UUIDs
        are rendered as ``.hex``. This presumably matches ids produced by
        taskiq's default id generator; verify this if a custom
        ``id_generator`` is used.
        """
        target_id = task_id.hex if isinstance(task_id, uuid.UUID) else str(task_id)
        return TaskiqMessage(
            task_id=self.id_generator(),
            task_name="canceller",
            labels={},
            labels_types={},
            args=[],
            kwargs={CANCELLER_KEY: target_id},
        )

class CancellableReceiver(Receiver):
    """Receiver that can cancel tasks it is currently executing.

    Every task is started as an ``asyncio.Task`` named after its taskiq
    ``task_id`` and kept in ``self.tasks`` until it finishes. A background
    coroutine (``runner_canceller``) listens to the broker's cancel channel
    and calls ``Task.cancel()`` on the matching task. The broker must be a
    ``CancellableListQueueBroker``. Only tasks already running on this worker
    can be cancelled; messages still waiting in the queue are unaffected.
    """

    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Tasks currently running on this worker; removed by the done callback.
        self.tasks: set[asyncio.Task[Any]] = set()

    def parse_message(self, message: bytes | AckableMessage) -> TaskiqMessage | None:
        """Decode a raw broker message into a ``TaskiqMessage``.

        :param message: raw bytes, or an ``AckableMessage`` whose ``data`` is
            decoded.
        :returns: the parsed message, or ``None`` if decoding failed. In that
            case a warning with the traceback is logged.
        """
        message_data = message.data if isinstance(message, AckableMessage) else message
        try:
            taskiq_msg = self.broker.formatter.loads(message=message_data)
            taskiq_msg.parse_labels()
        except Exception as exc:
            # loguru formats with ``{}`` placeholders and attaches tracebacks
            # via ``opt(exception=...)``. The stdlib-style ``%s`` /
            # ``exc_info=True`` used previously logged neither.
            logger.opt(exception=exc).warning(
                "Cannot parse message: {}. Skipping execution.\n {}",
                message_data,
                exc,
            )
            return None
        return taskiq_msg

    async def listen(self) -> None:  # pragma: no cover
        """Fetch and run tasks until the prefetcher signals ``QUEUE_DONE``.

        The prefetcher and the cancellation listener run as background tasks,
        while the runner executes in the current task. When the runner
        returns, the task group's scope is cancelled. The cancellation
        listener waits on pub/sub forever and would otherwise keep
        ``listen()`` - and the ``on_exit`` hook - from ever completing.
        """
        if self.run_startup:
            await self.broker.startup()
        logger.info("Listening started.")
        queue: asyncio.Queue[bytes | AckableMessage] = asyncio.Queue()

        async with anyio.create_task_group() as gr:
            gr.start_soon(self.prefetcher, queue)
            gr.start_soon(self.runner_canceller)
            await self.runner(queue)
            # No new tasks will be started; stop the background listeners.
            gr.cancel_scope.cancel()

        if self.on_exit is not None:
            self.on_exit(self)

    async def runner_canceller(
        self,
    ) -> None:
        """Cancel locally running tasks when a cancellation request arrives.

        Consumes the broker's ``listen_canceller()`` stream. Each message that
        carries ``CANCELLER_KEY`` in its kwargs is matched against the names
        of the tasks in ``self.tasks``. Runs until the stream ends. When this
        coroutine is cancelled, the cancellation propagates instead of being
        swallowed, so the enclosing anyio task group can shut down cleanly.
        """

        def cancel_task(task_id: str) -> None:
            # ``runner`` names each task after its taskiq task_id.
            for task in self.tasks:
                if task.get_name() == task_id:
                    if task.cancel():
                        logger.info("Cancelling task {}", task_id)
                    else:
                        logger.warning("Cannot cancel task {}", task_id)

        broker = cast(CancellableListQueueBroker, self.broker)
        async for message in broker.listen_canceller():
            taskiq_msg = self.parse_message(message)
            if taskiq_msg is None:
                continue
            if CANCELLER_KEY in taskiq_msg.kwargs:
                cancel_task(str(taskiq_msg.kwargs[CANCELLER_KEY]))

    async def runner(
        self,
        queue: asyncio.Queue[bytes | AckableMessage],
    ) -> None:
        """Start one named asyncio task per message taken from ``queue``.

        Each task is named after its taskiq ``task_id`` so that
        ``runner_canceller`` can find it. It is tracked in ``self.tasks`` until
        it completes. Returns on ``QUEUE_DONE``; tasks still running at that
        point are not awaited here.
        """

        def task_cb(task: asyncio.Task[Any]) -> None:
            # Stop tracking the finished task and give back its concurrency slot.
            self.tasks.discard(task)
            if self.sem is not None:
                self.sem.release()

        while True:
            # Reserve a concurrency slot before pulling the next message.
            if self.sem is not None:
                await self.sem.acquire()

            self.sem_prefetch.release()
            message = await queue.get()
            if message is QUEUE_DONE:
                break

            taskiq_msg = self.parse_message(message)
            if taskiq_msg is None:
                # No task is created, so task_cb will never return the slot
                # acquired above. Release it here, or each unparsable message
                # permanently reduces this worker's concurrency.
                if self.sem is not None:
                    self.sem.release()
                continue

            task = asyncio.create_task(
                self.callback(message=message, raise_err=False),
                name=str(taskiq_msg.task_id),
            )
            self.tasks.add(task)
            task.add_done_callback(task_cb)