taskiq-python / taskiq

Distributed task queue with full async support
MIT License
689 stars 44 forks source link

How to cancel sending a task using middleware #327

Open Bohdan-Ilchyshyn opened 1 month ago

Bohdan-Ilchyshyn commented 1 month ago

I've created a singleton middleware. It checks whether the same task already exists and, if so, should cancel sending the new one in the `pre_send` function. What is the correct way to do this: return `None` or raise an exception?

Middleware code


import inspect
import time
from hashlib import md5
from typing import Any, Coroutine, Union

from cashews import cache
from loguru import logger
from orjson import orjson
from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult

class DuplicateTaskError(Exception):
    """Raised by ``SingletonMiddleware.pre_send`` to cancel sending a duplicate task.

    taskiq uses whatever ``pre_send`` returns as the message to send, so
    returning ``None`` does not cancel anything; raising is how a middleware
    aborts a send. The exception propagates to the code that called ``.kiq()``,
    which can catch it to treat a duplicate as a no-op.
    """

    def __init__(self, lock_key: str, existing_task_id: Any) -> None:
        self.lock_key = lock_key
        self.existing_task_id = existing_task_id
        super().__init__(
            f"Task with lock key {lock_key!r} is already queued or running "
            f"(existing task ID {existing_task_id!r})",
        )


class SingletonMiddleware(TaskiqMiddleware):
    """Prevent a second identical copy of a task from being queued.

    Tasks opt in with the ``singleton`` label. Before a message is sent, a lock
    is taken in the ``cashews`` cache. The lock key is derived from the task
    name and its arguments, or only from the arguments named in the
    ``unique_on`` label. The lock value is the task ID. The lock is released in
    ``post_execute`` / ``on_error`` and otherwise expires on its own (see
    ``get_lock_expire``). A duplicate send is cancelled by raising
    :class:`DuplicateTaskError`.
    """

    SINGLETON_LABEL = "singleton"
    UNIQUE_ON_LABEL = "unique_on"
    LOCK_EXPIRE_LABEL = "lock_expire"
    KEY_PREFIX = "TKQ_SINGLETON_LOCK_"
    # Extra seconds added to a task's ``timeout`` label when sizing the lock TTL,
    # so the lock outlives the task's own time limit.
    TIMEOUT_LOCK_MARGIN = 5 * 60

    def __init__(self, default_lock_expire: int = 60) -> None:
        """Create the middleware.

        :param default_lock_expire: lock TTL in seconds, used when the task has
            neither a ``lock_expire`` nor a ``timeout`` label.
        """
        super().__init__()
        self.default_lock_expire = default_lock_expire

    def pre_send(
        self,
        message: "TaskiqMessage",
    ) -> "Union[TaskiqMessage, Coroutine[Any, Any, TaskiqMessage]]":
        """Pass non-singleton messages through; lock singleton ones before sending.

        :returns: *message* itself, or a coroutine resolving to it for
            singleton tasks.
        :raises DuplicateTaskError: (when the coroutine is awaited) if an
            identical task already holds the lock.
        """
        if not self.is_singleton_task(message):
            return message
        return self.lock_and_run(message)

    async def post_execute(
        self,
        message: "TaskiqMessage",
        result: "TaskiqResult[Any]",
    ) -> "Union[None, Coroutine[Any, Any, None]]":
        """Release the singleton lock once the task has finished."""
        if self.is_singleton_task(message):
            await self.release_lock(message)
        return None

    async def on_error(
        self,
        message: "TaskiqMessage",
        result: "TaskiqResult[Any]",
        exception: BaseException,
    ) -> "Union[None, Coroutine[Any, Any, None]]":
        """Release the singleton lock when the task raised.

        NOTE(review): if taskiq also calls ``post_execute`` for a failed task,
        the lock is released twice. The second ``unlock`` presumably just
        returns False — confirm against cashews.
        """
        if self.is_singleton_task(message):
            await self.release_lock(message)
        return None

    def is_singleton_task(self, message: "TaskiqMessage") -> bool:
        """Return True if the message carries the ``singleton`` label.

        Only the label's presence is checked, not its value, so
        ``singleton=False`` also opts in.
        """
        return self.SINGLETON_LABEL in message.labels

    @staticmethod
    async def unlock(lock_key: str, task_id: str) -> bool:
        """Release *lock_key* if it is held with value *task_id*."""
        return await cache.unlock(lock_key, task_id)

    @staticmethod
    async def lock(lock_key: str, task_id: str, expire: int) -> bool:
        """Try to take *lock_key* with value *task_id*; return whether it was acquired."""
        return await cache.set_lock(key=lock_key, value=task_id, expire=expire)

    @staticmethod
    async def locked(lock_key: str) -> bool:
        """Return whether *lock_key* is currently locked."""
        return await cache.is_locked(key=lock_key)

    @staticmethod
    async def get_existing_task_id(lock_key: str) -> "str | None":
        """Return the task ID stored under *lock_key*, or None if no value is stored."""
        return await cache.get(key=lock_key)

    async def lock_and_run(self, message: TaskiqMessage) -> TaskiqMessage:
        """Acquire the singleton lock for *message*, or cancel the send.

        :returns: *message* unchanged when the lock was acquired.
        :raises DuplicateTaskError: when an identical task already holds the
            lock. Returning None here would not cancel the send.
        """
        if await self.acquire_lock(message):
            return message

        lock_key = self.generate_lock(message)
        # get_existing_task_id is a coroutine function and must be awaited;
        # otherwise the log line would show a coroutine object.
        existing_task_id = await self.get_existing_task_id(lock_key)
        logger.warning(f"Attempted to queue a duplicate of task ID {existing_task_id}")
        raise DuplicateTaskError(lock_key, existing_task_id)

    async def get_lock_expire(self, message: "TaskiqMessage") -> int:
        """Pick the lock TTL for *message*.

        Priority:
        1. the ``lock_expire`` label, passed through unchanged;
        2. the ``timeout`` label plus ``TIMEOUT_LOCK_MARGIN``;
        3. ``default_lock_expire``.
        """
        labels = message.labels
        if self.LOCK_EXPIRE_LABEL in labels:
            return labels[self.LOCK_EXPIRE_LABEL]
        if "timeout" in labels:
            # float() first so fractional timeouts such as "30.5" are accepted.
            return int(float(labels["timeout"])) + self.TIMEOUT_LOCK_MARGIN
        return self.default_lock_expire

    async def release_lock(self, message: "TaskiqMessage") -> bool:
        """Release the lock held by this message's task ID; return whether it was released."""
        lock_key = self.generate_lock(message)
        return await self.unlock(lock_key, message.task_id)

    async def acquire_lock(self, message: "TaskiqMessage") -> bool:
        """Try to take the lock for *message* with its task ID as the lock value."""
        lock_key = self.generate_lock(message)
        lock_expire = await self.get_lock_expire(message)
        return await self.lock(lock_key, message.task_id, lock_expire)

    @staticmethod
    def generate_lock_key(task_name: str, task_args: list, task_kwargs: dict, key_prefix: str) -> str:
        """Build ``key_prefix`` + MD5 of the task name and its sorted-key JSON arguments.

        ``str()`` of the orjson bytes yields a ``"b'...'"`` repr rather than
        decoded text. It is kept so existing lock keys stay stable; the MD5
        serves only as a key digest, not for security.
        """
        str_args = str(orjson.dumps(task_args, option=orjson.OPT_SORT_KEYS))
        str_kwargs = str(orjson.dumps(task_kwargs, option=orjson.OPT_SORT_KEYS))
        task_hash = md5((task_name + str_args + str_kwargs).encode()).hexdigest()
        return key_prefix + task_hash

    def generate_lock(self, message: "TaskiqMessage") -> str:
        """Return the lock key for *message*.

        Without a ``unique_on`` label, all positional and keyword arguments
        make up the key. With it (a name or a list of names), only those
        parameters contribute. They are resolved by binding the call to the
        task's signature, so positional and keyword spellings of the same call
        share one lock.

        :raises ValueError: if ``unique_on`` is set but the task is not
            registered with the broker.
        """
        unique_on = message.labels.get(self.UNIQUE_ON_LABEL)

        if unique_on:
            if isinstance(unique_on, str):
                unique_on = [unique_on]

            task = self.broker.find_task(message.task_name)
            if task is None:
                raise ValueError(
                    f"Cannot resolve '{self.UNIQUE_ON_LABEL}' for unknown task "
                    f"{message.task_name!r}",
                )

            bound = inspect.signature(task.original_func).bind(
                *message.args,
                **message.kwargs,
            )
            # Fill in omitted parameters so a unique_on argument left at its
            # default still resolves instead of raising KeyError.
            bound.apply_defaults()

            unique_args: list = []
            unique_kwargs = {key: bound.arguments[key] for key in unique_on}
        else:
            unique_args = message.args
            unique_kwargs = message.kwargs

        return self.generate_lock_key(
            task_name=str(message.task_name),
            task_args=unique_args,
            task_kwargs=unique_kwargs,
            key_prefix=self.KEY_PREFIX,
        )

Task example

@broker.task(
    singleton=True,
    unique_on=["id", "name"],
)
async def my_singleton_task(id: str, name: str) -> None:
    """Example singleton task keyed on its ``id`` and ``name`` arguments.

    With ``SingletonMiddleware`` installed, only one copy per distinct
    (id, name) pair can be queued or running at a time.
    """